
Commit cd29f38

[CPU] Use scf.forall for TileRootAndFuseProducerConsumer by default. (iree-org#21260)
The revision drops the option and switches to scf.forall by default when tiling and fusing the parallel dimensions. To finish the migration, it updates the LinalgExt pipeline and adds the ForallToFor pass before vectorization. Signed-off-by: hanhanW <[email protected]>
1 parent e0b184c commit cd29f38
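For readers less familiar with the two loop forms the commit message refers to: below is a minimal hand-written sketch (not taken from this patch; the op, shapes, and tile sizes are made up) of the structure that parallel tile-and-fuse now emits by default. The results of each tile are committed through scf.forall.in_parallel and tensor.parallel_insert_slice rather than through scf.for plus scf.yield.

// Illustrative only: a 64x64 fill tiled by 16x16 into an scf.forall.
func.func @tiled_fill_sketch(%init: tensor<64x64xf32>) -> tensor<64x64xf32> {
  %cst = arith.constant 0.0 : f32
  // Parallel tiling: one forall iteration per 16x16 tile; %init is the shared output.
  %0 = scf.forall (%i, %j) = (0, 0) to (64, 64) step (16, 16)
      shared_outs(%out = %init) -> (tensor<64x64xf32>) {
    %slice = tensor.extract_slice %out[%i, %j] [16, 16] [1, 1]
        : tensor<64x64xf32> to tensor<16x16xf32>
    %fill = linalg.fill ins(%cst : f32) outs(%slice : tensor<16x16xf32>) -> tensor<16x16xf32>
    // Tile results are written back in parallel instead of being yielded.
    scf.forall.in_parallel {
      tensor.parallel_insert_slice %fill into %out[%i, %j] [16, 16] [1, 1]
          : tensor<16x16xf32> into tensor<64x64xf32>
    }
  }
  return %0 : tensor<64x64xf32>
}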

File tree

5 files changed: +24 −52 lines

compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUTileRootAndFuseProducerConsumer.cpp

Lines changed: 8 additions & 11 deletions
@@ -37,11 +37,11 @@ namespace mlir::iree_compiler {
 /// the root operation and fuse the producers of the root operation then
 /// consumers (finds any missing fusion opportunities, then apply producer
 /// fusion). If `onlyFuseProducerInputOperands` is set, only fuse producer input
-/// operands. If `tileUsingForall` is set, creates `scf.forall`, rather than
-/// `scf.for` loops during tiling.
-static FailureOr<Operation *> tileRootAndFuseProducerConsumer(
-    IRRewriter &rewriter, TilingInterface rootOp, int64_t tilingLevel,
-    bool onlyFuseProducerInputOperands, bool tileUsingForall) {
+/// operands.
+static FailureOr<Operation *>
+tileRootAndFuseProducerConsumer(IRRewriter &rewriter, TilingInterface rootOp,
+                                int64_t tilingLevel,
+                                bool onlyFuseProducerInputOperands) {
   auto *context = rewriter.getContext();
   mlir::DominanceInfo dominanceInfo(rootOp);
   llvm::SmallDenseSet<Operation *> tiledAndFusedOps;

@@ -88,7 +88,7 @@ static FailureOr<Operation *> tileRootAndFuseProducerConsumer(
   tilingOptions.setTileSizes(tileSizes);

   // onlyFuseProducerInputOperands implies reduction tiling.
-  if (tileUsingForall && !onlyFuseProducerInputOperands) {
+  if (!onlyFuseProducerInputOperands) {
     tilingOptions.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
   }

@@ -218,7 +218,7 @@ void LLVMCPUTileRootAndFuseProducerConsumer::runOnOperation() {

   if (failed(tileRootAndFuseProducerConsumer(
           rewriter, cast<TilingInterface>(rootOp.value()), tilingLevel,
-          onlyFuseProducerInputOperands, tileUsingForall))) {
+          onlyFuseProducerInputOperands))) {
     funcOp.emitError() << "tiling of level " << tilingLevel.getValue()
                        << " failed\n";
     return signalPassFailure();

@@ -242,20 +242,17 @@ void LLVMCPUTileRootAndFuseProducerConsumer::runOnOperation() {
 } // namespace

 std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
-createLLVMCPUTileRootAndFuseProducerConsumer(int64_t tilingLevel,
-                                             bool tileUsingForAll) {
+createLLVMCPUTileRootAndFuseProducerConsumer(int64_t tilingLevel) {
   LLVMCPUTileRootAndFuseProducerConsumerPassOptions options;
   options.tilingLevel = tilingLevel;
   options.onlyFuseProducerInputOperands = false;
-  options.tileUsingForall = tileUsingForAll;
   return std::make_unique<LLVMCPUTileRootAndFuseProducerConsumer>(options);
 }
 std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
 createLLVMCPUTileRootAndFuseInputOperands(int64_t tilingLevel) {
   LLVMCPUTileRootAndFuseProducerConsumerPassOptions options;
   options.tilingLevel = tilingLevel;
   options.onlyFuseProducerInputOperands = true;
-  options.tileUsingForall = false;
   return std::make_unique<LLVMCPUTileRootAndFuseProducerConsumer>(options);
 }
 } // namespace mlir::iree_compiler

compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp

Lines changed: 2 additions & 2 deletions
@@ -532,8 +532,7 @@ void addMmt4dTilingExpertPassPipeline(OpPassManager &funcPassManager,
   addTileAndDistributePasses(funcPassManager);

   funcPassManager.addPass(createLLVMCPUTileRootAndFuseProducerConsumer(
-      static_cast<int64_t>(tilingConfig.getVectorCommonParallelLevel()),
-      /*tileUsingForall=*/true));
+      static_cast<int64_t>(tilingConfig.getVectorCommonParallelLevel())));
   // The below two passes are nop if the "mmt4d" is explicitly excluded in the
   // ukernels attribute.
   funcPassManager.addPass(createCPUPrepareUkernelsPass());

@@ -647,6 +646,7 @@ void addCPULinalgExtTileAndVectorizePipeline(
   funcPassManager.addPass(
       IREE::LinalgExt::createDecomposeWinogradTransformPass());
   funcPassManager.addPass(IREE::LinalgExt::createDecomposeAttentionPass());
+  funcPassManager.addPass(iree_compiler::createForallToForPass());

   {
     GenericVectorizationPassOptions options;
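The second hunk above adds ForallToFor to the LinalgExt pipeline right before vectorization. As a rough, hand-written sketch (my illustration under assumed shapes, not output copied from the pass), ForallToFor rewrites an scf.forall like the one sketched near the top of this page into a sequential scf.for nest, with tensor.parallel_insert_slice replaced by tensor.insert_slice carried through iter_args:

// Hypothetical result of converting the earlier scf.forall sketch to scf.for.
func.func @tiled_fill_after_forall_to_for(%init: tensor<64x64xf32>) -> tensor<64x64xf32> {
  %c0 = arith.constant 0 : index
  %c16 = arith.constant 16 : index
  %c64 = arith.constant 64 : index
  %cst = arith.constant 0.0 : f32
  // One scf.for per forall induction variable; the shared output becomes an iter_arg.
  %0 = scf.for %i = %c0 to %c64 step %c16 iter_args(%outer = %init) -> (tensor<64x64xf32>) {
    %1 = scf.for %j = %c0 to %c64 step %c16 iter_args(%inner = %outer) -> (tensor<64x64xf32>) {
      %slice = tensor.extract_slice %inner[%i, %j] [16, 16] [1, 1]
          : tensor<64x64xf32> to tensor<16x16xf32>
      %fill = linalg.fill ins(%cst : f32) outs(%slice : tensor<16x16xf32>) -> tensor<16x16xf32>
      // The parallel insert becomes a plain insert_slice yielded to the next iteration.
      %updated = tensor.insert_slice %fill into %inner[%i, %j] [16, 16] [1, 1]
          : tensor<16x16xf32> into tensor<64x64xf32>
      scf.yield %updated : tensor<64x64xf32>
    }
    scf.yield %1 : tensor<64x64xf32>
  }
  return %0 : tensor<64x64xf32>
}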

compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h

Lines changed: 1 addition & 2 deletions
@@ -43,8 +43,7 @@ std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
 createLLVMCPUTileAndFusePass(int64_t tilingLevel);

 std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
-createLLVMCPUTileRootAndFuseProducerConsumer(int64_t tilingLevel,
-                                             bool tileUsingForall);
+createLLVMCPUTileRootAndFuseProducerConsumer(int64_t tilingLevel);

 std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
 createLLVMCPUTileRootAndFuseInputOperands(int64_t tilingLevel);

compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.td

Lines changed: 1 addition & 4 deletions
@@ -160,10 +160,7 @@ def LLVMCPUTileRootAndFuseProducerConsumerPass
            "only-fuse-producer-input-operands", "bool",
            /*default=*/"false",
            "Specifies if we only want to fuse producer's input operands. "
-           "This is helpful to tile&fuse in case of reduction dimensions.">,
-    Option<"tileUsingForall", "tile-using-forall", "bool",
-           /*default=*/"false",
-           "Tile parallel dimensions using `scf.forall` instead of `scf.for`. Reduction dimension defaults to `scf.for`.">];
+           "This is helpful to tile&fuse in case of reduction dimensions.">];
 }

 def LLVMCPUVerifyVectorSizeLegalityPass :

compiler/src/iree/compiler/Codegen/LLVMCPU/test/tile_root_and_fuse_producer_consumer.mlir

Lines changed: 12 additions & 33 deletions
@@ -1,5 +1,4 @@
 // RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-llvmcpu-tile-root-and-fuse-producer-consumer{tiling-level=0}), cse)" --split-input-file %s | FileCheck %s
-// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-llvmcpu-tile-root-and-fuse-producer-consumer{tiling-level=0 tile-using-forall=true}), cse)" --split-input-file %s | FileCheck %s --check-prefix=CHECK-FORALL
 // RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-llvmcpu-tile-root-and-fuse-producer-consumer{tiling-level=2 only-fuse-producer-input-operands=true}), cse)" --split-input-file %s | FileCheck %s --check-prefix=CHECK-REDUCTION

 #config = #iree_codegen.lowering_config<tile_sizes = [[1, 0, 0, 0, 0, 0], [1, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [1, 0, 0, 16, 16, 0], [0, 0, 1, 0, 0, 1], [0, 0, 0, 0, 0, 0]]>

@@ -24,21 +23,12 @@ func.func @mmt4d_bias_relu(%arg0: tensor<?x?x16x1xf32>, %arg1: tensor<?x?x16x1xf
   return %4 : tensor<?x?x16x16xf32>
 }
 // CHECK-LABEL: func.func @mmt4d_bias_relu(
-// CHECK: scf.for
+// CHECK: scf.forall
 // CHECK: linalg.fill
 // CHECK-NEXT: %[[MMT4D:.+]] = linalg.mmt4d
 // CHECK: %[[ELEM:.+]] = linalg.generic
-// CHECK: %[[RES0:.+]] = tensor.insert_slice %[[MMT4D]]
-// CHECK: %[[RES1:.+]] = tensor.insert_slice %[[ELEM]]
-// CHECK: scf.yield %[[RES0]], %[[RES1]]
-
-// CHECK-FORALL-LABEL: func.func @mmt4d_bias_relu(
-// CHECK-FORALL: scf.forall
-// CHECK-FORALL: linalg.fill
-// CHECK-FORALL-NEXT: %[[MMT4D:.+]] = linalg.mmt4d
-// CHECK-FORALL: %[[ELEM:.+]] = linalg.generic
-// CHECK-FORALL: scf.forall.in_parallel
-// CHECK-FORALL: tensor.parallel_insert_slice %[[ELEM]]
+// CHECK: scf.forall.in_parallel
+// CHECK: tensor.parallel_insert_slice %[[ELEM]]

 // -----

@@ -72,26 +62,15 @@ func.func @quantized_matmul(%arg0: tensor<2x4x128x16x1xi8>, %arg1: tensor<2x4x16
   %unpack = linalg.unpack %6 outer_dims_perm = [0, 2, 1] inner_dims_pos = [2, 1] inner_tiles = [16, 16] into %7 : tensor<2x4x688x16x16xf32> -> tensor<2x11008x64xf32>
   return %unpack : tensor<2x11008x64xf32>
 }
-// CHECK: func.func @quantized_matmul(
-// CHECK: scf.for
-// CHECK: linalg.generic
-// CHECK: linalg.generic
-// CHECK: linalg.fill
-// CHECK: %[[MMT4D:.+]] = linalg.batch_mmt4d
-// CHECK: %[[UNPACK:.+]] = linalg.unpack
-// CHECK: %[[RES0:.+]] = tensor.insert_slice %[[MMT4D]]
-// CHECK: %[[RES1:.+]] = tensor.insert_slice %[[UNPACK]]
-// CHECK: scf.yield %[[RES0]], %[[RES1]]
-
-// CHECK-FORALL-LABEL: func.func @quantized_matmul(
-// CHECK-FORALL: scf.forall
-// CHECK-FORALL: linalg.generic
-// CHECK-FORALL: linalg.generic
-// CHECK-FORALL: linalg.fill
-// CHECK-FORALL: %[[MMT4D:.+]] = linalg.batch_mmt4d
-// CHECK-FORALL: %[[UNPACK:.+]] = linalg.unpack
-// CHECK-FORALL: scf.forall.in_parallel
-// CHECK-FORALL: tensor.parallel_insert_slice %[[UNPACK]]
+// CHECK-LABEL: func.func @quantized_matmul(
+// CHECK: scf.forall
+// CHECK: linalg.generic
+// CHECK: linalg.generic
+// CHECK: linalg.fill
+// CHECK: %[[MMT4D:.+]] = linalg.batch_mmt4d
+// CHECK: %[[UNPACK:.+]] = linalg.unpack
+// CHECK: scf.forall.in_parallel
+// CHECK: tensor.parallel_insert_slice %[[UNPACK]]

 // -----