Skip to content

Commit c3ad3e5

Browse files
authored
[Encoding] Implement SerializableEncodingAttrInterface for MatmulK. (iree-org#20521)
It also adds a builder which takes `ArrayRef<int32_t>` as an input. It is a step towards iree-org#20493 Signed-off-by: hanhanW <hanhan0912@gmail.com>
1 parent e72c3bf commit c3ad3e5

File tree

3 files changed

+37
-1
lines changed

3 files changed

+37
-1
lines changed

compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -473,6 +473,21 @@ Attribute EncodingAttr::cloneWithLayouts(ArrayRef<Attribute> layouts) const {
473473
return LayoutAttr::get(ctx, ArrayAttr::get(ctx, layouts));
474474
}
475475

476+
//===---------------------------------------------------------------------===//
477+
// iree_encoding.matmul_k
478+
//===---------------------------------------------------------------------===//
479+
480+
MatmulKAttr MatmulKAttr::get(MLIRContext *ctx, ArrayRef<int32_t> kDims) {
481+
return get(ctx, DenseI32ArrayAttr::get(ctx, kDims));
482+
}
483+
484+
bool MatmulKAttr::isSerialized() const { return false; }
485+
486+
Attribute MatmulKAttr::cloneWithLayouts(ArrayRef<Attribute> layouts) const {
487+
MLIRContext *ctx = getContext();
488+
return LayoutAttr::get(ctx, ArrayAttr::get(ctx, layouts));
489+
}
490+
476491
//===---------------------------------------------------------------------===//
477492
// iree_encoding.pad_encoding_layout
478493
//===---------------------------------------------------------------------===//

compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.td

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,12 @@ def EncodingAttr :
180180
// iree_encoding.matmul_k
181181
//===---------------------------------------------------------------------===//
182182

183-
def MatmulKAttr : IREEEncoding_Attr<"MatmulK"> {
183+
def MatmulKAttr : IREEEncoding_Attr<"MatmulK", [
184+
DeclareAttrInterfaceMethods<IREEEncoding_SerializableEncodingAttrInterface, [
185+
"isSerialized",
186+
"cloneWithLayouts",
187+
]>
188+
]> {
184189
let mnemonic = "matmul_k";
185190
let summary = [{An attribute that tracks reduction dimensions for matmul}];
186191
let description = [{
@@ -223,6 +228,10 @@ def MatmulKAttr : IREEEncoding_Attr<"MatmulK"> {
223228
let parameters = (ins
224229
"DenseI32ArrayAttr":$k_dims
225230
);
231+
232+
let builders = [
233+
AttrBuilder<(ins "ArrayRef<int32_t>":$k_dims)>
234+
];
226235
}
227236

228237
//===---------------------------------------------------------------------===//

compiler/src/iree/compiler/Dialect/Encoding/IR/unittests/EncodingAttrTest.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,18 @@ TEST_F(EncodingAttrsTest, EncodingAttr) {
4242
EXPECT_TRUE(attr.isIdentityLayout());
4343
}
4444

45+
TEST_F(EncodingAttrsTest, MatulKAttr) {
46+
MLIRContext *ctx = getContext();
47+
Builder builder(ctx);
48+
auto attr = cast<SerializableEncodingAttrInterface>(
49+
MatmulKAttr::get(ctx, /*k_dims=*/{1}));
50+
EXPECT_FALSE(attr.isIdentityLayout());
51+
52+
attr = cast<SerializableEncodingAttrInterface>(attr.cloneWithLayouts(
53+
PadEncodingLayoutAttr::getIdentityAttr(ctx, /*rank=*/2)));
54+
EXPECT_TRUE(attr.isIdentityLayout());
55+
}
56+
4557
TEST_F(EncodingAttrsTest, PadEncodingLayoutAttr) {
4658
MLIRContext *ctx = getContext();
4759
auto zeroPaddingAttr =

0 commit comments

Comments
 (0)