2020
2121namespace mlir ::iree_compiler::IREE {
2222
23- using IREE::Codegen::MaterializeEncodingInfo;
24- using IREE::TensorExt::DispatchTensorType;
25-
2623static const char kEncodingInfoAttrName [] = " encoding_info" ;
2724
2825// This class is the base class for the external model of different packed
@@ -33,18 +30,18 @@ static const char kEncodingInfoAttrName[] = "encoding_info";
3330template <typename DeviceEncodingPackedLayoutAttrInterface,
3431 typename EncodingLayoutAttr>
3532struct DevicePackedLayoutAttrExternalModelBase
36- : public Codegen::PackedLayoutAttrInterface::ExternalModel<
33+ : public IREE:: Codegen::PackedLayoutAttrInterface::ExternalModel<
3734 DeviceEncodingPackedLayoutAttrInterface, EncodingLayoutAttr> {
3835public:
39- Codegen::MaterializeEncodingInfo
36+ IREE:: Codegen::MaterializeEncodingInfo
4037 getEncodingInfo (Attribute attr, RankedTensorType type) const {
4138 const DeviceEncodingPackedLayoutAttrInterface *impl =
4239 static_cast <const DeviceEncodingPackedLayoutAttrInterface *>(this );
4340 // If the layout is already resolved, use it directly.
4441 if (auto config = impl->getConfiguration (attr)) {
4542 if (auto namedAttr = config.getNamed (kEncodingInfoAttrName )) {
46- std::optional<Codegen::MaterializeEncodingInfo> info =
47- Codegen::deserializeEncodingInfo (
43+ std::optional<IREE:: Codegen::MaterializeEncodingInfo> info =
44+ IREE:: Codegen::deserializeEncodingInfo (
4845 cast<DictionaryAttr>(namedAttr->getValue ()));
4946 assert (info && " encoding_info is invalid" );
5047 return info.value ();
@@ -60,8 +57,8 @@ struct DeviceEncodingLayoutAttrInterfaceExternalModelBase
6057 : public IREE::Encoding::LayoutAttrInterface::ExternalModel<
6158 DeviceEncodingLayoutAttrInterface, EncodingLayoutAttr> {
6259public:
63- MaterializeEncodingInfo getEncodingInfo (EncodingLayoutAttr layoutAttr,
64- RankedTensorType type) const {
60+ IREE::Codegen:: MaterializeEncodingInfo
61+ getEncodingInfo (EncodingLayoutAttr layoutAttr, RankedTensorType type) const {
6562 return getEncodingInfoFromLayout (
6663 type, cast<IREE::Encoding::LayoutAttrInterface>(layoutAttr));
6764 }
@@ -73,7 +70,7 @@ struct DeviceEncodingLayoutAttrInterfaceExternalModelBase
7370 // For a given tensor type with an encoding, return the materialized
7471 // type to use for it. If no encoding is set, then return the tensor
7572 // type itself.
76- MaterializeEncodingInfo encodingInfo =
73+ IREE::Codegen:: MaterializeEncodingInfo encodingInfo =
7774 getEncodingInfo (layoutAttr, type);
7875 if (IREE::Codegen::isIdentityLayout (encodingInfo)) {
7976 return type.dropEncoding ();
@@ -100,15 +97,16 @@ struct DeviceEncodingLayoutAttrInterfaceExternalModelBase
10097 newShape.append (swizzledTileShape);
10198 return RankedTensorType::get (newShape, packedType.getElementType ());
10299 })
103- .template Case <DispatchTensorType>([&](auto dispatchTensorType) {
104- Type boundType = dispatchTensorType.getBoundType ();
105- Type convertedBoundType = convertType (attr, boundType);
106- if (convertedBoundType == boundType) {
107- return dispatchTensorType;
108- }
109- return DispatchTensorType::get (dispatchTensorType.getAccess (),
110- convertedBoundType);
111- })
100+ .template Case <IREE::TensorExt::DispatchTensorType>(
101+ [&](auto dispatchTensorType) {
102+ Type boundType = dispatchTensorType.getBoundType ();
103+ Type convertedBoundType = convertType (attr, boundType);
104+ if (convertedBoundType == boundType) {
105+ return dispatchTensorType;
106+ }
107+ return IREE::TensorExt::DispatchTensorType::get (
108+ dispatchTensorType.getAccess (), convertedBoundType);
109+ })
112110 .Default ([&](auto concreteType) { return concreteType; });
113111 }
114112
@@ -127,7 +125,7 @@ struct DeviceEncodingLayoutAttrInterfaceExternalModelBase
127125 return failure ();
128126 }
129127 auto boundTensorType = cast<RankedTensorType>(type.getBoundType ());
130- MaterializeEncodingInfo encodingInfo =
128+ IREE::Codegen:: MaterializeEncodingInfo encodingInfo =
131129 getEncodingInfoFromLayout (boundTensorType, layoutAttr);
132130 newSizes = getMixedValues (boundTensorType.getShape (), dynamicDims, builder);
133131 FailureOr<SmallVector<OpFoldResult>> convertedMixedSizes =
0 commit comments