
Commit a802470

[DataTiling] Add matmul_k option to SetEncoding pass. (iree-org#20529)
1 parent 53e15a4 commit a802470
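
In short: the SetEncoding pass gains an encoding-strategy option. The default generic strategy keeps the existing EncodingAttr, which records operand index, op type, element types, indexing maps, and iteration sizes; the new matmulk strategy instead attaches the lighter #iree_encoding.matmul_k attribute, which records only which dimensions of each operand are reduction (k) dimensions. As the updated RUN lines in the test show, the per-pass option is selected with iree-dispatch-creation-set-encoding{encoding-option=matmulk}; the pipelines expose the same choice through the new --iree-dispatch-creation-set-encoding-strategy and --iree-global-opt-set-encoding-strategy flags.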

File tree

6 files changed: +143 -63 lines changed


compiler/src/iree/compiler/DispatchCreation/Passes.cpp

Lines changed: 15 additions & 2 deletions
@@ -16,6 +16,7 @@
 #include "mlir/Pass/PassRegistry.h"
 #include "mlir/Transforms/Passes.h"
 
+namespace mlir::iree_compiler::DispatchCreation {
 //===----------------------------------------------------------------------===//
 // Command Line Options
 //===----------------------------------------------------------------------===//
@@ -79,7 +80,16 @@ static llvm::cl::opt<bool> clHoistEncodingsForConstExpr(
         "--iree-opt-data-tiling=false must be set as well"),
     llvm::cl::init(true));
 
-namespace mlir::iree_compiler::DispatchCreation {
+static llvm::cl::opt<DispatchCreation::EncodingOptions> clSetEncodingStrategy(
+    "iree-dispatch-creation-set-encoding-strategy",
+    llvm::cl::desc("Set the encoding strategy for operations."),
+    llvm::cl::values(
+        clEnumValN(
+            DispatchCreation::EncodingOptions::Generic, "generic",
+            "Using EncodingAttr which encodes as much information as possible"),
+        clEnumValN(DispatchCreation::EncodingOptions::MatmulK, "matmulk",
+                   "Only encodes the reduction dimensions in the encoding.")),
+    llvm::cl::init(DispatchCreation::EncodingOptions::Generic));
 
 //===----------------------------------------------------------------------===//
 // Utilities
@@ -244,7 +254,10 @@ addDispatchRegionCreationPasses(OpPassManager &passManager,
       // Set encodings on all eligible ops. All ops should be in compiler
       // formed dispatch regions, so encodings will be placed inside of the
       // dispatch regions with the data-tiled op.
-      .addPass(createSetEncodingPass)
+      .addPass([&]() {
+        return DispatchCreation::createSetEncodingPass(
+            DispatchCreation::SetEncodingPassOptions{clSetEncodingStrategy});
+      })
       // SetEncodingOps should not be in the same dispatch as the data-tiled
      // op, so hoist them out of their current dispatch regions. Also, bubble
      // SetEncodingOps through special operations like bit-extending ops and
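
Outside the flag-driven pipeline above, the pass can now also be constructed programmatically with an explicit strategy. A minimal sketch, assuming only that the tablegen-generated SetEncodingPassOptions struct exposes the encodingOption field declared in Passes.td below; the helper name here is hypothetical:

#include "iree/compiler/DispatchCreation/Passes.h"
#include "mlir/Pass/PassManager.h"

// Adds SetEncodingPass with the matmul_k strategy to a function-level pass
// manager, bypassing the command-line flag entirely.
void addMatmulKEncodingPass(mlir::OpPassManager &funcPassManager) {
  namespace DC = mlir::iree_compiler::DispatchCreation;
  DC::SetEncodingPassOptions options;
  options.encodingOption = DC::EncodingOptions::MatmulK;
  funcPassManager.addPass(DC::createSetEncodingPass(options));
}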

compiler/src/iree/compiler/DispatchCreation/Passes.h

Lines changed: 2 additions & 0 deletions
@@ -17,6 +17,8 @@
 
 namespace mlir::iree_compiler::DispatchCreation {
 
+enum class EncodingOptions { MatmulK, Generic };
+
 //===----------------------------------------------------------------------===//
 // Pipelines
 //===----------------------------------------------------------------------===//

compiler/src/iree/compiler/DispatchCreation/Passes.td

Lines changed: 17 additions & 2 deletions
@@ -306,14 +306,29 @@ def PropagateEncodingsPass :
   ];
 }
 
-def SetEncodingPass :
-    InterfacePass<"iree-dispatch-creation-set-encoding", "mlir::FunctionOpInterface"> {
+def SetEncodingPass : InterfacePass<"iree-dispatch-creation-set-encoding",
+                                    "mlir::FunctionOpInterface"> {
   let summary = "Introduces tensor encoding for flow dispatch regions.";
   let dependentDialects = [
     "mlir::linalg::LinalgDialect",
     "IREE::Flow::FlowDialect",
     "IREE::Encoding::IREEEncodingDialect",
   ];
+  let options = [
+    Option<
+      "encodingOption", "encoding-option",
+      "mlir::iree_compiler::DispatchCreation::EncodingOptions",
+      /*default=*/
+      "mlir::iree_compiler::DispatchCreation::EncodingOptions::Generic",
+      "Select the type of encoding options to add.",
+      [{::llvm::cl::values(
+          clEnumValN(
+              mlir::iree_compiler::DispatchCreation::EncodingOptions::MatmulK,
+              "matmulk", "Only encodes reduction dimensions in the encoding."),
+          clEnumValN(
+              mlir::iree_compiler::DispatchCreation::EncodingOptions::Generic,
+              "default", "Uses EncodingAttr which encodes as much information as possible."))}]>,
+  ];
 }
 
 def ConvertEncodingToFlowPass :
compiler/src/iree/compiler/DispatchCreation/SetEncoding.cpp

Lines changed: 44 additions & 14 deletions
@@ -31,16 +31,16 @@ namespace mlir::iree_compiler::DispatchCreation {
 #include "iree/compiler/DispatchCreation/Passes.h.inc"
 
 using IREE::Encoding::EncodingAttr;
+using IREE::Encoding::MatmulKAttr;
 
 //===---------------------------------------------------------------------===//
 // Utility functions
 //===---------------------------------------------------------------------===//
 
-Value setEncoding(OpBuilder &builder, Location loc, Value source,
-                  EncodingAttr encodingAttr) {
-  auto sourceType = cast<RankedTensorType>(source.getType());
-  auto resultType = RankedTensorType::get(
-      sourceType.getShape(), sourceType.getElementType(), encodingAttr);
+static Value setEncoding(OpBuilder &builder, Location loc, Value source,
+                         Attribute encodingAttr) {
+  auto resultType =
+      cast<RankedTensorType>(source.getType()).cloneWithEncoding(encodingAttr);
   return builder.create<IREE::Encoding::SetEncodingOp>(loc, resultType, source);
 };
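
This rewrite of the helper is behavior-preserving for shape and element type: cloneWithEncoding copies the ranked tensor type and swaps only the encoding attribute. Widening the parameter from EncodingAttr to Attribute is what allows the same helper to attach either the generic EncodingAttr or the new MatmulKAttr.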

@@ -163,11 +163,13 @@ class SetContractionOpEncoding final
     : public OpInterfaceRewritePattern<linalg::LinalgOp> {
 public:
   using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
-  explicit SetContractionOpEncoding(MLIRContext *ctx)
-      : OpInterfaceRewritePattern<linalg::LinalgOp>(ctx) {}
+  explicit SetContractionOpEncoding(MLIRContext *ctx, EncodingOptions &option)
+      : OpInterfaceRewritePattern<linalg::LinalgOp>(ctx),
+        encodingOption(option) {}
 
   LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
                                 PatternRewriter &rewriter) const override {
+
     if (!linalgOp.hasPureTensorSemantics()) {
       return failure();
     }
@@ -228,14 +230,39 @@ class SetContractionOpEncoding final
 
     auto opType = IREE::Encoding::EncodingOpType::matmul;
     auto setEncodingWrapper = [&](Value src, int64_t operandIndex) -> Value {
-      auto encoding =
-          EncodingAttr::get(linalgOp.getContext(), operandIndex, opType,
-                            elemTypes, maps, iterationSizes);
+      MLIRContext *ctx = linalgOp.getContext();
+      Attribute encoding;
+      switch (encodingOption) {
+      case EncodingOptions::Generic: {
+        encoding = EncodingAttr::get(ctx, operandIndex, opType, elemTypes, maps,
+                                     iterationSizes);
+        break;
+      }
+      case EncodingOptions::MatmulK: {
+        SmallVector<int32_t> kDims;
+        AffineMap indexingMap = maps[operandIndex];
+        auto cDims = linalg::inferContractionDims(linalgOp);
+        for (auto k : cDims->k) {
+          std::optional<unsigned> dimIdx =
+              indexingMap.getResultPosition(rewriter.getAffineDimExpr(k));
+          if (!dimIdx) {
+            continue;
+          }
+          kDims.push_back(dimIdx.value());
+        }
+        encoding = MatmulKAttr::get(ctx, kDims);
+        break;
+      }
+      default: {
+        assert(false && "Unsupported encoding option");
+        return Value();
+      }
+      }
       return setEncoding(rewriter, loc, src, encoding);
     };
-    Value encodedLhs = setEncodingWrapper(lhs, IREE::Encoding::MATMUL_LHS);
-    Value encodedRhs = setEncodingWrapper(rhs, IREE::Encoding::MATMUL_RHS);
-    Value encodedOut = setEncodingWrapper(out, IREE::Encoding::MATMUL_RESULT);
+    auto encodedLhs = setEncodingWrapper(lhs, IREE::Encoding::MATMUL_LHS);
+    auto encodedRhs = setEncodingWrapper(rhs, IREE::Encoding::MATMUL_RHS);
+    auto encodedOut = setEncodingWrapper(out, IREE::Encoding::MATMUL_RESULT);
     Value opTiled = clone(rewriter, linalgOp, encodedOut.getType(),
                           ValueRange{encodedLhs, encodedRhs, encodedOut})
                         ->getResult(0);
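
To make the MatmulK branch concrete: for a plain matmul, inferContractionDims reports k = {d2}, and the loop maps that iteration-space dimension to a result position in each operand's indexing map: position 1 in the LHS map (d0, d1, d2) -> (d0, d2), position 0 in the RHS map (d2, d1), and no position in the output map (d0, d1). This is exactly the k_dims = [1], [0], [] expected by the MATMULK checks in the updated test below. A standalone sketch of the position lookup, illustrative only, using the same upstream MLIR APIs the pattern calls:

#include <optional>

#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/Support/raw_ostream.h"

int main() {
  mlir::MLIRContext ctx;
  mlir::Builder b(&ctx);
  // LHS indexing map of a plain matmul: (d0, d1, d2) -> (d0, d2).
  mlir::AffineMap lhsMap = mlir::AffineMap::get(
      /*dimCount=*/3, /*symbolCount=*/0,
      {b.getAffineDimExpr(0), b.getAffineDimExpr(2)}, &ctx);
  // Position of the reduction dimension d2 among the LHS results.
  std::optional<unsigned> kPos =
      lhsMap.getResultPosition(b.getAffineDimExpr(2));
  llvm::outs() << "k position in LHS: " << *kPos << "\n"; // prints 1
  return 0;
}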
@@ -248,6 +275,9 @@ class SetContractionOpEncoding final
     rewriter.replaceOp(linalgOp, result);
     return success();
   }
+
+private:
+  EncodingOptions encodingOption;
 };
 
 /// Pattern to fold a `linalg.fill` -> `iree_encoding.set_encoding`
@@ -281,7 +311,7 @@ struct SetEncodingPass final : impl::SetEncodingPassBase<SetEncodingPass> {
   void runOnOperation() override {
     MLIRContext *context = &getContext();
     RewritePatternSet patterns(context);
-    patterns.add<SetContractionOpEncoding>(context);
+    patterns.add<SetContractionOpEncoding>(context, encodingOption.getValue());
     linalg::FillOp::getCanonicalizationPatterns(patterns, context);
     patterns.add<FoldFillWithSetEncoding>(context);
     memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);

compiler/src/iree/compiler/DispatchCreation/test/set_encoding.mlir

Lines changed: 50 additions & 43 deletions
@@ -1,32 +1,36 @@
-// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(util.func(iree-dispatch-creation-set-encoding))" %s | FileCheck %s
+// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(util.func(iree-dispatch-creation-set-encoding))" %s | FileCheck %s --check-prefixes=CHECK-ALL,CHECK
+// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(util.func(iree-dispatch-creation-set-encoding{encoding-option=matmulk}))" %s | FileCheck %s --check-prefixes=CHECK-ALL,MATMULK
 
 util.func public @matmul_f32f32f32(%arg0 : tensor<100x250xf32>, %arg1 : tensor<250x500xf32>,
     %arg2 : tensor<100x500xf32>) -> tensor<100x500xf32> {
   %0 = linalg.matmul ins(%arg0, %arg1 : tensor<100x250xf32>, tensor<250x500xf32>)
       outs(%arg2 : tensor<100x500xf32>) -> tensor<100x500xf32>
   util.return %0 : tensor<100x500xf32>
 }
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
-// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
-// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
-// CHECK-DAG: #[[LHS_ENCODING:.+]] = #iree_encoding.encoding<operand_index = 0 : index, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#[[MAP1]], #[[MAP2]], #[[MAP3]]], iteration_sizes = [100, 500, 250]>
-// CHECK-DAG: #[[RHS_ENCODING:.+]] = #iree_encoding.encoding<operand_index = 1 : index, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#[[MAP1]], #[[MAP2]], #[[MAP3]]], iteration_sizes = [100, 500, 250]>
-// CHECK-DAG: #[[OUT_ENCODING:.+]] = #iree_encoding.encoding<operand_index = 2 : index, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#[[MAP1]], #[[MAP2]], #[[MAP3]]], iteration_sizes = [100, 500, 250]>
-// CHECK: util.func public @matmul_f32f32f32(
-// CHECK-SAME: %[[ARG0:.+]]: tensor<100x250xf32>
-// CHECK-SAME: %[[ARG1:.+]]: tensor<250x500xf32>
-// CHECK-SAME: %[[ARG2:.+]]: tensor<100x500xf32>
-// CHECK: %[[LHS:.+]] = iree_encoding.set_encoding %[[ARG0]]
-// CHECK-SAME: tensor<100x250xf32, #[[LHS_ENCODING]]>
-// CHECK: %[[RHS:.+]] = iree_encoding.set_encoding %[[ARG1]]
-// CHECK-SAME: tensor<250x500xf32, #[[RHS_ENCODING]]>
-// CHECK: %[[OUTS:.+]] = iree_encoding.set_encoding %[[ARG2]]
-// CHECK-SAME: tensor<100x500xf32, #[[OUT_ENCODING]]>
-// CHECK: %[[MATMUL:.+]] = linalg.matmul
-// CHECK-SAME: ins(%[[LHS]], %[[RHS]] :
-// CHECK-SAME: outs(%[[OUTS]] :
-// CHECK: %[[RESULT:.+]] = iree_encoding.unset_encoding %[[MATMUL]] : tensor<100x500xf32, #[[OUT_ENCODING]]> -> tensor<100x500xf32>
-// CHECK: util.return %[[RESULT]]
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-DAG: #[[LHS_ENCODING:.+]] = #iree_encoding.encoding<operand_index = 0 : index, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#[[MAP1]], #[[MAP2]], #[[MAP3]]], iteration_sizes = [100, 500, 250]>
+// CHECK-DAG: #[[RHS_ENCODING:.+]] = #iree_encoding.encoding<operand_index = 1 : index, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#[[MAP1]], #[[MAP2]], #[[MAP3]]], iteration_sizes = [100, 500, 250]>
+// CHECK-DAG: #[[OUT_ENCODING:.+]] = #iree_encoding.encoding<operand_index = 2 : index, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#[[MAP1]], #[[MAP2]], #[[MAP3]]], iteration_sizes = [100, 500, 250]>
+// MATMULK-DAG: #[[LHS_ENCODING:.+]] = #iree_encoding.matmul_k<k_dims = [1]>
+// MATMULK-DAG: #[[RHS_ENCODING:.+]] = #iree_encoding.matmul_k<k_dims = [0]>
+// MATMULK-DAG: #[[OUT_ENCODING:.+]] = #iree_encoding.matmul_k<k_dims = []>
+// CHECK-ALL: util.func public @matmul_f32f32f32(
+// CHECK-ALL-SAME: %[[ARG0:.+]]: tensor<100x250xf32>
+// CHECK-ALL-SAME: %[[ARG1:.+]]: tensor<250x500xf32>
+// CHECK-ALL-SAME: %[[ARG2:.+]]: tensor<100x500xf32>
+// CHECK-ALL: %[[LHS:.+]] = iree_encoding.set_encoding %[[ARG0]]
+// CHECK-ALL-SAME: tensor<100x250xf32, #[[LHS_ENCODING]]>
+// CHECK-ALL: %[[RHS:.+]] = iree_encoding.set_encoding %[[ARG1]]
+// CHECK-ALL-SAME: tensor<250x500xf32, #[[RHS_ENCODING]]>
+// CHECK-ALL: %[[OUTS:.+]] = iree_encoding.set_encoding %[[ARG2]]
+// CHECK-ALL-SAME: tensor<100x500xf32, #[[OUT_ENCODING]]>
+// CHECK-ALL: %[[MATMUL:.+]] = linalg.matmul
+// CHECK-ALL-SAME: ins(%[[LHS]], %[[RHS]] :
+// CHECK-ALL-SAME: outs(%[[OUTS]] :
+// CHECK-ALL: %[[RESULT:.+]] = iree_encoding.unset_encoding %[[MATMUL]] : tensor<100x500xf32, #[[OUT_ENCODING]]> -> tensor<100x500xf32>
+// CHECK-ALL: util.return %[[RESULT]]
 
 // -----
 
@@ -72,27 +76,30 @@ util.func public @matmul_f32f32f32_parallel_reduce_parallel(%arg0 : tensor<32x12
   } -> tensor<4096x32xf32>
   util.return %0 : tensor<4096x32xf32>
 }
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
-// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
-// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
-// CHECK-DAG: #[[LHS_ENCODING:.+]] = #iree_encoding.encoding<operand_index = 0 : index, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#[[MAP1]], #[[MAP2]], #[[MAP3]]], iteration_sizes = [32, 128, 4096]>
-// CHECK-DAG: #[[RHS_ENCODING:.+]] = #iree_encoding.encoding<operand_index = 1 : index, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#[[MAP1]], #[[MAP2]], #[[MAP3]]], iteration_sizes = [32, 128, 4096]>
-// CHECK-DAG: #[[OUT_ENCODING:.+]] = #iree_encoding.encoding<operand_index = 2 : index, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#[[MAP1]], #[[MAP2]], #[[MAP3]]], iteration_sizes = [32, 128, 4096]>
-// CHECK: util.func public @matmul_f32f32f32_parallel_reduce_parallel(
-// CHECK-SAME: %[[ARG0:.+]]: tensor<32x128xf32>
-// CHECK-SAME: %[[ARG1:.+]]: tensor<128x4096xf32>
-// CHECK-SAME: %[[ARG2:.+]]: tensor<4096x32xf32>
-// CHECK: %[[LHS:.+]] = iree_encoding.set_encoding %[[ARG0]]
-// CHECK-SAME: tensor<32x128xf32, #[[LHS_ENCODING]]>
-// CHECK: %[[RHS:.+]] = iree_encoding.set_encoding %[[ARG1]]
-// CHECK-SAME: tensor<128x4096xf32, #[[RHS_ENCODING]]>
-// CHECK: %[[OUTS:.+]] = iree_encoding.set_encoding %[[ARG2]]
-// CHECK-SAME: tensor<4096x32xf32, #[[OUT_ENCODING]]>
-// CHECK: %[[MATMUL:.+]] = linalg.generic
-// CHECK-SAME: ins(%[[LHS]], %[[RHS]] :
-// CHECK-SAME: outs(%[[OUTS]] :
-// CHECK: %[[RESULT:.+]] = iree_encoding.unset_encoding %[[MATMUL]] : tensor<4096x32xf32, #[[OUT_ENCODING]]> -> tensor<4096x32xf32>
-// CHECK: util.return %[[RESULT]]
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
+// CHECK-DAG: #[[LHS_ENCODING:.+]] = #iree_encoding.encoding<operand_index = 0 : index, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#[[MAP1]], #[[MAP2]], #[[MAP3]]], iteration_sizes = [32, 128, 4096]>
+// CHECK-DAG: #[[RHS_ENCODING:.+]] = #iree_encoding.encoding<operand_index = 1 : index, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#[[MAP1]], #[[MAP2]], #[[MAP3]]], iteration_sizes = [32, 128, 4096]>
+// CHECK-DAG: #[[OUT_ENCODING:.+]] = #iree_encoding.encoding<operand_index = 2 : index, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#[[MAP1]], #[[MAP2]], #[[MAP3]]], iteration_sizes = [32, 128, 4096]>
+// MATMULK-DAG: #[[LHS_ENCODING:.+]] = #iree_encoding.matmul_k<k_dims = [1]>
+// MATMULK-DAG: #[[RHS_ENCODING:.+]] = #iree_encoding.matmul_k<k_dims = [0]>
+// MATMULK-DAG: #[[OUT_ENCODING:.+]] = #iree_encoding.matmul_k<k_dims = []>
+// CHECK-ALL: util.func public @matmul_f32f32f32_parallel_reduce_parallel(
+// CHECK-ALL-SAME: %[[ARG0:.+]]: tensor<32x128xf32>
+// CHECK-ALL-SAME: %[[ARG1:.+]]: tensor<128x4096xf32>
+// CHECK-ALL-SAME: %[[ARG2:.+]]: tensor<4096x32xf32>
+// CHECK-ALL: %[[LHS:.+]] = iree_encoding.set_encoding %[[ARG0]]
+// CHECK-ALL-SAME: tensor<32x128xf32, #[[LHS_ENCODING]]>
+// CHECK-ALL: %[[RHS:.+]] = iree_encoding.set_encoding %[[ARG1]]
+// CHECK-ALL-SAME: tensor<128x4096xf32, #[[RHS_ENCODING]]>
+// CHECK-ALL: %[[OUTS:.+]] = iree_encoding.set_encoding %[[ARG2]]
+// CHECK-ALL-SAME: tensor<4096x32xf32, #[[OUT_ENCODING]]>
+// CHECK-ALL: %[[MATMUL:.+]] = linalg.generic
+// CHECK-ALL-SAME: ins(%[[LHS]], %[[RHS]] :
+// CHECK-ALL-SAME: outs(%[[OUTS]] :
+// CHECK-ALL: %[[RESULT:.+]] = iree_encoding.unset_encoding %[[MATMUL]] : tensor<4096x32xf32, #[[OUT_ENCODING]]> -> tensor<4096x32xf32>
+// CHECK-ALL: util.return %[[RESULT]]
 
 // -----
 
compiler/src/iree/compiler/GlobalOptimization/Passes.cpp

Lines changed: 15 additions & 2 deletions
@@ -61,6 +61,17 @@ static llvm::cl::opt<DemotionOption> clDemoteContractionInputsToBF16Strategy(
         clEnumValN(DemotionOption::None, "none", "Demote no contraction ops.")),
     llvm::cl::init(DemotionOption::None));
 
+static llvm::cl::opt<DispatchCreation::EncodingOptions> clSetEncodingStrategy(
+    "iree-global-opt-set-encoding-strategy",
+    llvm::cl::desc("Set the encoding strategy for operations."),
+    llvm::cl::values(
+        clEnumValN(
+            DispatchCreation::EncodingOptions::Generic, "generic",
+            "Using EncodingAttr which encodes as much information as possible"),
+        clEnumValN(DispatchCreation::EncodingOptions::MatmulK, "matmulk",
+                   "Only encodes the reduction dimensions in the encoding.")),
+    llvm::cl::init(DispatchCreation::EncodingOptions::Generic));
+
 static llvm::cl::opt<bool> clWarnOnUninitializedValues(
     "iree-global-opt-enable-warn-on-uninitialized-values",
     llvm::cl::desc("Warn on some classes of uses of uninitialized values."),
@@ -175,8 +186,10 @@ void buildGlobalOptimizationPassPipeline(
 
   // Enable data tiling after they are in a canonical form.
   if (transformOptions.options.dataTiling) {
-    FunctionLikeNest(mainPassManager)
-        .addPass(DispatchCreation::createSetEncodingPass);
+    FunctionLikeNest(mainPassManager).addPass([&]() {
+      return DispatchCreation::createSetEncodingPass(
+          DispatchCreation::SetEncodingPassOptions{clSetEncodingStrategy});
+    });
     // TODO(hanchung): Make data-tiling passes be FunctionOpInterface passes, so
     // we can use `FunctionLikeNest` here.
     if (clEnableEarlyMaterialization) {
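
With this flag, the same strategy choice is available when data tiling runs in the global optimization pipeline: passing --iree-global-opt-set-encoding-strategy=matmulk (with data tiling enabled) selects the MatmulKAttr encodings, while leaving the flag unset preserves the existing EncodingAttr behavior.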
