File tree Expand file tree Collapse file tree 3 files changed +37
-1
lines changed
compiler/src/iree/compiler/Dialect/Encoding/IR Expand file tree Collapse file tree 3 files changed +37
-1
lines changed Original file line number Diff line number Diff line change @@ -473,6 +473,21 @@ Attribute EncodingAttr::cloneWithLayouts(ArrayRef<Attribute> layouts) const {
473473 return LayoutAttr::get (ctx, ArrayAttr::get (ctx, layouts));
474474}
475475
476+ // ===---------------------------------------------------------------------===//
477+ // iree_encoding.matmul_k
478+ // ===---------------------------------------------------------------------===//
479+
480+ MatmulKAttr MatmulKAttr::get (MLIRContext *ctx, ArrayRef<int32_t > kDims ) {
481+ return get (ctx, DenseI32ArrayAttr::get (ctx, kDims ));
482+ }
483+
484+ bool MatmulKAttr::isSerialized () const { return false ; }
485+
486+ Attribute MatmulKAttr::cloneWithLayouts (ArrayRef<Attribute> layouts) const {
487+ MLIRContext *ctx = getContext ();
488+ return LayoutAttr::get (ctx, ArrayAttr::get (ctx, layouts));
489+ }
490+
476491// ===---------------------------------------------------------------------===//
477492// iree_encoding.pad_encoding_layout
478493// ===---------------------------------------------------------------------===//
Original file line number Diff line number Diff line change @@ -180,7 +180,12 @@ def EncodingAttr :
180180// iree_encoding.matmul_k
181181//===---------------------------------------------------------------------===//
182182
183- def MatmulKAttr : IREEEncoding_Attr<"MatmulK"> {
183+ def MatmulKAttr : IREEEncoding_Attr<"MatmulK", [
184+ DeclareAttrInterfaceMethods<IREEEncoding_SerializableEncodingAttrInterface, [
185+ "isSerialized",
186+ "cloneWithLayouts",
187+ ]>
188+ ]> {
184189 let mnemonic = "matmul_k";
185190 let summary = [{An attribute that tracks reduction dimensions for matmul}];
186191 let description = [{
@@ -223,6 +228,10 @@ def MatmulKAttr : IREEEncoding_Attr<"MatmulK"> {
223228 let parameters = (ins
224229 "DenseI32ArrayAttr":$k_dims
225230 );
231+
232+ let builders = [
233+ AttrBuilder<(ins "ArrayRef<int32_t>":$k_dims)>
234+ ];
226235}
227236
228237//===---------------------------------------------------------------------===//
Original file line number Diff line number Diff line change @@ -42,6 +42,18 @@ TEST_F(EncodingAttrsTest, EncodingAttr) {
4242 EXPECT_TRUE (attr.isIdentityLayout ());
4343}
4444
45+ TEST_F (EncodingAttrsTest, MatulKAttr) {
46+ MLIRContext *ctx = getContext ();
47+ Builder builder (ctx);
48+ auto attr = cast<SerializableEncodingAttrInterface>(
49+ MatmulKAttr::get (ctx, /* k_dims=*/ {1 }));
50+ EXPECT_FALSE (attr.isIdentityLayout ());
51+
52+ attr = cast<SerializableEncodingAttrInterface>(attr.cloneWithLayouts (
53+ PadEncodingLayoutAttr::getIdentityAttr (ctx, /* rank=*/ 2 )));
54+ EXPECT_TRUE (attr.isIdentityLayout ());
55+ }
56+
4557TEST_F (EncodingAttrsTest, PadEncodingLayoutAttr) {
4658 MLIRContext *ctx = getContext ();
4759 auto zeroPaddingAttr =
You can’t perform that action at this time.
0 commit comments