Skip to content

Commit de56220

Browse files
authored
[Encoding] Introduce matmul_k encoding. (#20484)
The attribute is specialized for tracking the matmul reduction dimensions only. The encoded dimensions are the reduction dimensions that will be used by matmuls. The below message will be removed before we land the PR. Please check #20485, if you want to get the full picture about how the new encoding is set up. --------- Signed-off-by: hanhanW <[email protected]>
1 parent b91405d commit de56220

File tree

3 files changed

+67
-7
lines changed

3 files changed

+67
-7
lines changed

compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.td

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,55 @@ def EncodingAttr :
138138
let genVerifyDecl = 1;
139139
}
140140

141+
//===---------------------------------------------------------------------===//
142+
// encoding.matmul_k
143+
//===---------------------------------------------------------------------===//
144+
145+
def MatmulKAttr : IREEEncoding_Attr<"MatmulK"> {
146+
let mnemonic = "matmul_k";
147+
let summary = [{An attribute that tracks reduction dimensions for matmul}];
148+
let description = [{
149+
The attribute is specialized for tracking the matmul reduction dimensions
150+
only. The encoded dimensions are the reduction dimensions that will be
151+
further used in a matmul.
152+
153+
Example: linalg.matmul ins(%lhs, %rhs) outs(%acc), where the indexing maps
154+
are:
155+
156+
lhs: (M, N, K) -> (M, K)
157+
rhs: (M, N, K) -> (K, N)
158+
acc: (M, N, K) -> (M, N)
159+
160+
The encoding for the lhs is iree_encoding.matmul_k<k_dims = [1]>.
161+
The encoding for the rhs is iree_encoding.matmul_k<k_dims = [0]>.
162+
The encoding for the acc is iree_encoding.matmul_k<k_dims = []>.
163+
164+
Example: linalg.matmul_transpose_b ins(%lhs, %rhs) outs(%acc), where the
165+
indexing maps are:
166+
167+
lhs: (M, N, K) -> (M, K)
168+
rhs: (M, N, K) -> (N, K)
169+
acc: (M, N, K) -> (M, N)
170+
171+
The encoding for the lhs is iree_encoding.matmul_k<k_dims = [1]>.
172+
The encoding for the rhs is iree_encoding.matmul_k<k_dims = [1]>.
173+
The encoding for the acc is iree_encoding.matmul_k<k_dims = []>.
174+
175+
There is no difference between static cases and dynamic cases because the
176+
dimension sizes are not encoded.
177+
178+
Any encoding propagation that transforms reduction dimensions could result
179+
in undefined behavior, because it does not encode the transformations. The
180+
information is missing in this context.
181+
}];
182+
183+
let assemblyFormat = "`<` struct(params) `>`";
184+
185+
let parameters = (ins
186+
"DenseI32ArrayAttr":$k_dims
187+
);
188+
}
189+
141190
//===---------------------------------------------------------------------===//
142191
// encoding.pad_encoding_layout
143192
//===---------------------------------------------------------------------===//

compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingDialect.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,14 @@ namespace {
3030
// Used for custom printing support.
3131
struct EncodingOpAsmInterface : public OpAsmDialectInterface {
3232
using OpAsmDialectInterface::OpAsmDialectInterface;
33-
/// Hooks for getting an alias identifier alias for a given symbol, that is
34-
/// not necessarily a part of this dialect. The identifier is used in place
35-
/// of the symbol when printing textual IR. These aliases must not contain
36-
/// `.` or end with a numeric digit([0-9]+). Returns success if an alias was
37-
/// provided, failure otherwise.
33+
// Hooks for getting an alias identifier alias for a given symbol, that is
34+
// not necessarily a part of this dialect. The identifier is used in place
35+
// of the symbol when printing textual IR. These aliases must not contain
36+
// `.` or end with a numeric digit([0-9]+). Returns success if an alias was
37+
// provided, failure otherwise.
3838
AliasResult getAlias(Attribute attr, raw_ostream &os) const override {
39-
if (llvm::isa<EncodingAttr, TestingEncodingAttr, UnknownEncodingAttr>(
40-
attr)) {
39+
if (llvm::isa<EncodingAttr, MatmulKAttr, TestingEncodingAttr,
40+
UnknownEncodingAttr>(attr)) {
4141
os << "encoding";
4242
return AliasResult::OverridableAlias;
4343
}

compiler/src/iree/compiler/Dialect/Encoding/IR/test/roundtrip.mlir

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,3 +243,14 @@ func.func @testing_encoding_with_layouts(%arg0: tensor<?x?xf32, #encoding>) -> t
243243
// CHECK: func.func @testing_encoding_with_layouts(
244244
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32, #[[ENCODING]]>
245245
// CHECK return %[[ARG0]]
246+
247+
// -----
248+
249+
#encoding = #iree_encoding.matmul_k<k_dims = [1]>
250+
func.func @matmul_k_encoding(%arg0: tensor<?x?xf32, #encoding>) -> tensor<?x?xf32, #encoding> {
251+
return %arg0 : tensor<?x?xf32, #encoding>
252+
}
253+
// CHECK: #[[ENCODING:.+]] = #iree_encoding.matmul_k<k_dims = [1]>
254+
// CHECK: func.func @matmul_k_encoding(
255+
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32, #[[ENCODING]]>
256+
// CHECK return %[[ARG0]]

0 commit comments

Comments
 (0)