Skip to content

Commit 5de0f06

Browse files
authored
[Codegen] Implement serialization for MaterializeEncodingInfo struct. (iree-org#19260)
The revision adds the support for conversion between MaterializeEncodingInfo struct and DictionaryAttr. It also implements the in(equality) operators for the struct to verify if the deserialized result match the original struct. Signed-off-by: hanhanW <[email protected]>
1 parent a67b00b commit 5de0f06

File tree

3 files changed

+161
-0
lines changed

3 files changed

+161
-0
lines changed

compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.cpp

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,18 @@ llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
7272
return os;
7373
}
7474

75+
bool operator==(const MaterializeEncodingInfo &lhs,
76+
const MaterializeEncodingInfo &rhs) {
77+
return lhs.innerDimsPos == rhs.innerDimsPos &&
78+
lhs.innerTileSizes == rhs.innerTileSizes &&
79+
lhs.outerDimsPerm == rhs.outerDimsPerm && lhs.swizzle == rhs.swizzle;
80+
}
81+
82+
bool operator!=(const MaterializeEncodingInfo &lhs,
83+
const MaterializeEncodingInfo &rhs) {
84+
return !(lhs == rhs);
85+
}
86+
7587
//===----------------------------------------------------------------------===//
7688
// Layout Utilities.
7789
//===----------------------------------------------------------------------===//
@@ -188,6 +200,57 @@ std::optional<TileSwizzle> deserializeTileSwizzle(DictionaryAttr attr) {
188200
return swizzle;
189201
}
190202

203+
DictionaryAttr serializeEncodingInfo(MLIRContext *ctx,
204+
const MaterializeEncodingInfo &info) {
205+
Builder b(ctx);
206+
SmallVector<NamedAttribute> items;
207+
items.emplace_back(b.getStringAttr("innerDimsPos"),
208+
b.getI64ArrayAttr(info.innerDimsPos));
209+
items.emplace_back(b.getStringAttr("innerTileSizes"),
210+
b.getI64ArrayAttr(info.innerTileSizes));
211+
items.emplace_back(b.getStringAttr("outerDimsPerm"),
212+
b.getI64ArrayAttr(info.outerDimsPerm));
213+
if (info.swizzle) {
214+
items.emplace_back(b.getStringAttr("swizzle"),
215+
serializeTileSwizzle(ctx, info.swizzle.value()));
216+
}
217+
218+
return b.getDictionaryAttr(items);
219+
}
220+
221+
std::optional<MaterializeEncodingInfo>
222+
deserializeEncodingInfo(DictionaryAttr attr) {
223+
MaterializeEncodingInfo info;
224+
225+
#define extractArrayAttrItem(name) \
226+
{ \
227+
auto value = attr.getNamed(#name); \
228+
if (!value || !isa<ArrayAttr>(value->getValue())) { \
229+
return std::nullopt; \
230+
} \
231+
info.name = extractFromIntegerArrayAttr<int64_t>(value->getValue()); \
232+
}
233+
234+
extractArrayAttrItem(innerDimsPos);
235+
extractArrayAttrItem(innerTileSizes);
236+
extractArrayAttrItem(outerDimsPerm);
237+
#undef extractArrayAttrItem
238+
239+
if (attr.contains("swizzle")) {
240+
auto dictAttr =
241+
dyn_cast<DictionaryAttr>(attr.getNamed("swizzle")->getValue());
242+
if (!dictAttr) {
243+
return std::nullopt;
244+
}
245+
info.swizzle = deserializeTileSwizzle(dictAttr);
246+
if (!info.swizzle) {
247+
return std::nullopt;
248+
}
249+
}
250+
251+
return info;
252+
}
253+
191254
SmallVector<int64_t>
192255
getExpandedTileShape(const TileSwizzle::ExpandShapeType &expandShape) {
193256
SmallVector<int64_t> result;

compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,11 @@ struct MaterializeEncodingInfo {
107107
std::optional<TileSwizzle> swizzle;
108108
};
109109

110+
bool operator==(const MaterializeEncodingInfo &lhs,
111+
const MaterializeEncodingInfo &rhs);
112+
bool operator!=(const MaterializeEncodingInfo &lhs,
113+
const MaterializeEncodingInfo &rhs);
114+
110115
//===----------------------------------------------------------------------===//
111116
// Layout Utilities.
112117
//===----------------------------------------------------------------------===//
@@ -120,6 +125,12 @@ DictionaryAttr serializeTileSwizzle(MLIRContext *ctx,
120125
const TileSwizzle &swizzle);
121126
std::optional<TileSwizzle> deserializeTileSwizzle(DictionaryAttr attr);
122127

128+
/// Conversion between MaterializeEncodingInfo struct and DictionaryAttr.
129+
DictionaryAttr serializeEncodingInfo(MLIRContext *ctx,
130+
const MaterializeEncodingInfo &info);
131+
std::optional<MaterializeEncodingInfo>
132+
deserializeEncodingInfo(DictionaryAttr attr);
133+
123134
/// Concatenates the vectors.
124135
SmallVector<int64_t>
125136
getExpandedTileShape(const TileSwizzle::ExpandShapeType &expandShape);

compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/unittests/UtilsTest.cpp

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,5 +99,92 @@ TEST(TileSwizzle, Deserialization) {
9999
EXPECT_FALSE(deserializeTileSwizzle(b.getDictionaryAttr(items)).has_value());
100100
}
101101

102+
TEST(MaterializeEncodingInfo, RelationalOperator) {
103+
MaterializeEncodingInfo info1;
104+
info1.innerDimsPos = {0, 1};
105+
info1.innerTileSizes = {16, 1};
106+
info1.outerDimsPerm = {0, 2, 1, 3};
107+
108+
MaterializeEncodingInfo info2;
109+
info2.innerDimsPos = {1, 0};
110+
info2.innerTileSizes = {16, 1};
111+
info2.outerDimsPerm = {0, 2, 1, 3};
112+
113+
EXPECT_EQ(info1, info1);
114+
EXPECT_EQ(info2, info2);
115+
EXPECT_NE(info1, info2);
116+
117+
// They mismatch if one has a swizzle, but not the other.
118+
info2 = info1;
119+
info1.swizzle = TileSwizzle();
120+
EXPECT_NE(info1, info2);
121+
122+
// They match because they all have an empty swizzle.
123+
info2.swizzle = TileSwizzle();
124+
EXPECT_EQ(info1, info2);
125+
}
126+
127+
TEST(MaterializeEncodingInfo, Serialization) {
128+
MaterializeEncodingInfo info;
129+
info.innerDimsPos = {0, 1};
130+
info.innerTileSizes = {16, 1};
131+
info.outerDimsPerm = {0, 2, 1, 3};
132+
133+
MLIRContext ctx;
134+
DictionaryAttr dictAttr = serializeEncodingInfo(&ctx, info);
135+
136+
EXPECT_TRUE(dictAttr.contains("innerDimsPos"));
137+
EXPECT_TRUE(dictAttr.contains("innerTileSizes"));
138+
EXPECT_TRUE(dictAttr.contains("outerDimsPerm"));
139+
EXPECT_FALSE(dictAttr.contains("swizzle"));
140+
141+
EXPECT_TRUE(isa<ArrayAttr>(dictAttr.getNamed("innerDimsPos")->getValue()));
142+
EXPECT_TRUE(isa<ArrayAttr>(dictAttr.getNamed("innerTileSizes")->getValue()));
143+
EXPECT_TRUE(isa<ArrayAttr>(dictAttr.getNamed("outerDimsPerm")->getValue()));
144+
145+
auto extractedInnerDimsPos = extractFromIntegerArrayAttr<int64_t>(
146+
dictAttr.getNamed("innerDimsPos")->getValue());
147+
EXPECT_EQ(extractedInnerDimsPos, info.innerDimsPos);
148+
auto extractedInnerTileSizes = extractFromIntegerArrayAttr<int64_t>(
149+
dictAttr.getNamed("innerTileSizes")->getValue());
150+
EXPECT_EQ(extractedInnerTileSizes, info.innerTileSizes);
151+
auto extractedOuterDimsPerm = extractFromIntegerArrayAttr<int64_t>(
152+
dictAttr.getNamed("outerDimsPerm")->getValue());
153+
EXPECT_EQ(extractedOuterDimsPerm, info.outerDimsPerm);
154+
155+
std::optional<MaterializeEncodingInfo> deserializedInfo =
156+
deserializeEncodingInfo(dictAttr);
157+
EXPECT_THAT(deserializedInfo, Optional(info));
158+
}
159+
160+
TEST(MaterializeEncodingInfo, Deserialization) {
161+
MLIRContext ctx;
162+
Builder b(&ctx);
163+
164+
auto emptyDictAttr = b.getDictionaryAttr({});
165+
EXPECT_FALSE(deserializeEncodingInfo(emptyDictAttr).has_value());
166+
167+
SmallVector<NamedAttribute> items;
168+
items.emplace_back(b.getStringAttr("innerDimsPos"),
169+
b.getI64ArrayAttr({0, 1}));
170+
EXPECT_FALSE(deserializeEncodingInfo(b.getDictionaryAttr(items)).has_value());
171+
172+
items.emplace_back(b.getStringAttr("innerTileSizes"),
173+
b.getI64ArrayAttr({16, 1}));
174+
EXPECT_FALSE(deserializeEncodingInfo(b.getDictionaryAttr(items)).has_value());
175+
176+
items.emplace_back(b.getStringAttr("outerDimsPerm"),
177+
b.getI64ArrayAttr({0, 2, 1, 3}));
178+
EXPECT_TRUE(deserializeEncodingInfo(b.getDictionaryAttr(items)).has_value());
179+
180+
// If the swizzle presents, it needs to be deserializable to TileSwizzle.
181+
items.emplace_back(b.getStringAttr("swizzle"), b.getUnitAttr());
182+
EXPECT_FALSE(deserializeEncodingInfo(b.getDictionaryAttr(items)).has_value());
183+
184+
TileSwizzle swizzle;
185+
items.back().setValue(serializeTileSwizzle(&ctx, swizzle));
186+
EXPECT_TRUE(deserializeEncodingInfo(b.getDictionaryAttr(items)).has_value());
187+
}
188+
102189
} // namespace
103190
} // namespace mlir::iree_compiler::IREE::Codegen

0 commit comments

Comments
 (0)