
Commit ef0d6dd

hanhanW and keshavvinayak01 authored and committed
[CPU] Tile reduction dimensions for non-root reduction ops. (iree-org#21500)
The revision adds an option to skip root op in LLVMCPUTile pass, and uses it in multi level tiling pipeline. In softmax dispatch, there are two reduction ops. Only the root op is tiled for reduction dimensions when we switched to LLVMCPUTileRootAndFuseInputOperandsPass. It results in large vector sizes in the other reduction op when `util.assume.hint` ops are present. We did not hit the issue in e2e tests because AnnotateDispatchAssumptions pass behaves differently. The value range is [0, 0] if the input is from `flow.tensor.dynamic_constant`. Fixes iree-org#21359 --------- Signed-off-by: hanhanW <[email protected]> Signed-off-by: keshavvinayak01 <[email protected]>
1 parent a7bab8c commit ef0d6dd
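
For quick orientation, the core pipeline change condenses to the steps below. This is a sketch mirroring the Passes.cpp hunk further down, not the full source: the enclosing loop over tiling levels and the switch statement are elided, and `level`/`i` come from that surrounding context.

  // Reduction-level handling in addMultiTilingExpertPassPipeline (condensed).
  funcPassManager.addPass(
      createLLVMCPUSplitReductionPass(clEnableReassociateFpReductions));
  funcPassManager.addPass(createLLVMCPUTileRootAndFuseInputOperandsPass(level));
  // New in this commit: tile the remaining (non-root) reduction ops at the
  // same level; the root op is already tiled above, so the pass skips it.
  funcPassManager.addPass(createLLVMCPUTilePass(
      static_cast<IREE::CPU::TilingLevel>(i), /*skipRootOp=*/true));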

6 files changed: 144 additions & 13 deletions

compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUTile.cpp

Lines changed: 15 additions & 5 deletions
@@ -36,9 +36,7 @@ namespace {
 /// lowering_config.
 struct LLVMCPUTilePass : impl::LLVMCPUTilePassBase<LLVMCPUTilePass> {
   using impl::LLVMCPUTilePassBase<LLVMCPUTilePass>::LLVMCPUTilePassBase;
-  explicit LLVMCPUTilePass(int64_t tilingLevel) {
-    this->tilingLevel = tilingLevel;
-  }
+
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<arith::ArithDialect, affine::AffineDialect,
                     linalg::LinalgDialect, scf::SCFDialect,
@@ -75,8 +73,17 @@ void LLVMCPUTilePass::runOnOperation() {
       LDBG("can't find lowering_config, skip tiling");
      continue;
    }
+    if (!maybeLoweringConfig.hasTilingLevel(tilingLevel)) {
+      LDBG("target tiling level does not exist");
+      continue;
+    }

     LDBG("candidate: " << op);
+    if (skipRootOp && maybeLoweringConfig.hasWorkgroupTilingLevel()) {
+      LDBG("skip tiling on the root op");
+      continue;
+    }
+
     auto tileSizesAttr = dyn_cast<IREE::Codegen::LoweringConfigTilingLevelAttr>(
         getLoweringConfig(op).getTilingLevelAttr(tilingLevel));
     SmallVector<int64_t> tileSizes(tileSizesAttr.getSizes());
@@ -115,8 +122,11 @@ void LLVMCPUTilePass::runOnOperation() {
 } // namespace

 std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
-createLLVMCPUTilePass(int64_t tilingLevel) {
-  return std::make_unique<LLVMCPUTilePass>(tilingLevel);
+createLLVMCPUTilePass(int64_t tilingLevel, bool skipRootOp) {
+  LLVMCPUTilePassOptions options;
+  options.tilingLevel = tilingLevel;
+  options.skipRootOp = skipRootOp;
+  return std::make_unique<LLVMCPUTilePass>(options);
 }

 } // namespace mlir::iree_compiler

compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp

Lines changed: 9 additions & 4 deletions
@@ -365,8 +365,8 @@ void addCPUBufferOpsTileAndVectorizePipeline(

   // Skip tiling reduction loops because this is expected to apply on copy ops
   // only.
-  funcPassManager.addPass(
-      createLLVMCPUTilePass(tilingConfig.getVectorCommonParallelLevel()));
+  funcPassManager.addPass(createLLVMCPUTilePass(
+      tilingConfig.getVectorCommonParallelLevel(), /*skipRootOp=*/false));
   funcPassManager.addPass(createLLVMCPUPeelPass());
   {
     GenericVectorizationPassOptions options;
@@ -422,6 +422,11 @@ void addMultiTilingExpertPassPipeline(OpPassManager &funcPassManager,
           createLLVMCPUSplitReductionPass(clEnableReassociateFpReductions));
       funcPassManager.addPass(
           createLLVMCPUTileRootAndFuseInputOperandsPass(level));
+      // Tile all the reduction ops for target vector sizes, which ensures
+      // that all the dimensions are tiled in all the reduction ops. The root
+      // op is already tiled, so it is skipped in the pass.
+      funcPassManager.addPass(createLLVMCPUTilePass(
+          static_cast<IREE::CPU::TilingLevel>(i), /*skipRootOp=*/true));
       break;
     case IREE::CPU::TilingLevel::VectorInnerParallelTiles:
       funcPassManager.addPass(createLLVMCPUTileAndFusePass(
@@ -603,8 +608,8 @@ void addCPUDataTilingPipeline(OpPassManager &funcPassManager,
   funcPassManager.addPass(
       createCPULowerToUKernelsPass(clSkipIntermediateRoundings));

-  funcPassManager.addPass(
-      createLLVMCPUTilePass(tilingConfig.getVectorCommonParallelLevel()));
+  funcPassManager.addPass(createLLVMCPUTilePass(
+      tilingConfig.getVectorCommonParallelLevel(), /*skipRootOp=*/false));
   if (pipelineOpt.decomposePackUnPackOps) {
     funcPassManager.addPass(createDecomposePackUnPackOpsPass());
   }

compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h

Lines changed: 1 addition & 1 deletion
@@ -38,7 +38,7 @@ std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
 createLLVMCPUSplitReductionPass(bool enableReassociateFpReductions);

 std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
-createLLVMCPUTilePass(int64_t tilingLevel);
+createLLVMCPUTilePass(int64_t tilingLevel, bool skipRootOp);

 std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
 createLLVMCPUTileAndFusePass(int64_t tilingLevel);

compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.td

Lines changed: 3 additions & 1 deletion
@@ -138,7 +138,9 @@ def LLVMCPUTilePass :
   }];
   let options = [
     Option<"tilingLevel", "tiling-level", "int64_t", /*default=*/"-1",
-           "Tiling level used to retrieve the configuration from lowering_config">
+           "Tiling level used to retrieve the configuration from lowering_config.">,
+    Option<"skipRootOp", "skip-root-op", "bool", /*default=*/"false",
+           "Do not tile the root op if the option is true.">
   ];
 }
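
For reference, MLIR's pass tablegen turns the options above into plain fields on an LLVMCPUTilePassOptions struct, which is what createLLVMCPUTilePass populates in LLVMCPUTile.cpp. A rough sketch of its shape, assuming the usual generated layout (the real generated code also contains registration and parsing plumbing):

  // Approximate shape of the generated options struct (sketch only).
  struct LLVMCPUTilePassOptions {
    int64_t tilingLevel = -1; // set via iree-llvmcpu-tile{tiling-level=N}
    bool skipRootOp = false;  // set via iree-llvmcpu-tile{skip-root-op=true}
  };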

compiler/src/iree/compiler/Codegen/LLVMCPU/test/pipeline_tests.mlir

Lines changed: 76 additions & 0 deletions
@@ -579,3 +579,79 @@ func.func @pooling_nchw_max_pack_with_padding_issue_20723() attributes {hal.exec
 // CHECK: iree_linalg_ext.map_scatter
 // CHECK: } {mapping = [#iree_codegen.workgroup_mapping<y>, #iree_codegen.workgroup_mapping<x>]}
 // CHECK: scf.forall
+
+// -----
+
+// Verify that the dispatch can be compiled without creating large vectors.
+
+#executable_target_embedded_elf_x86_64 = #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64", {cpu = "znver4", cpu_features = "", max_stack_allocation_size = 32768 : i64, native_vector_size = 64 : i64, target_triple = "x86_64-unknown-unknown-eabi-elf"}>
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1)>
+#pipeline_layout = #hal.pipeline.layout<constants = 6, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>
+func.func @softmax_dynamic_with_assume_int_hints() attributes {hal.executable.target = #executable_target_embedded_elf_x86_64} {
+  %cst = arith.constant 0.000000e+00 : f32
+  %cst_0 = arith.constant 0xFFC00000 : f32
+  %c1 = arith.constant 1 : index
+  %c32_i64 = arith.constant 32 : i64
+  %c0 = arith.constant 0 : index
+  %0 = hal.interface.constant.load layout(#pipeline_layout) ordinal(0) : i32
+  %1 = hal.interface.constant.load layout(#pipeline_layout) ordinal(1) : i32
+  %2 = hal.interface.constant.load layout(#pipeline_layout) ordinal(2) : i32
+  %3 = hal.interface.constant.load layout(#pipeline_layout) ordinal(3) : i32
+  %4 = hal.interface.constant.load layout(#pipeline_layout) ordinal(4) : i32
+  %5 = hal.interface.constant.load layout(#pipeline_layout) ordinal(5) : i32
+  %6 = arith.extui %0 : i32 to i64
+  %7 = arith.extui %1 : i32 to i64
+  %8 = arith.shli %7, %c32_i64 : i64
+  %9 = arith.ori %6, %8 : i64
+  %10 = arith.index_castui %9 : i64 to index
+  %11 = arith.extui %2 : i32 to i64
+  %12 = arith.extui %3 : i32 to i64
+  %13 = arith.shli %12, %c32_i64 : i64
+  %14 = arith.ori %11, %13 : i64
+  %15 = arith.index_castui %14 : i64 to index
+  %16 = arith.extui %4 : i32 to i64
+  %17 = arith.extui %5 : i32 to i64
+  %18 = arith.shli %17, %c32_i64 : i64
+  %19 = arith.ori %16, %18 : i64
+  %20 = arith.index_castui %19 : i64 to index
+  %21:3 = util.assume.int
+      %10<umin = 0, umax = 9007199254740991>,
+      %15<umin = 0, umax = 9007199254740991>,
+      %20<umin = 0, umax = 9007199254740991>
+    : index, index, index
+  %22 = iree_tensor_ext.dispatch.workload.ordinal %21#0, 0 : index
+  %23 = iree_tensor_ext.dispatch.workload.ordinal %21#1, 1 : index
+  %24 = iree_tensor_ext.dispatch.workload.ordinal %21#2, 2 : index
+  %25 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : !iree_tensor_ext.dispatch.tensor<readonly:tensor<?x?x?xf32>>{%22, %23, %24}
+  %26 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) flags(Indirect) : !iree_tensor_ext.dispatch.tensor<writeonly:tensor<?x?x?xf32>>{%22, %23, %24}
+  %27 = iree_tensor_ext.dispatch.tensor.load %25, offsets = [0, 0, 0], sizes = [%22, %23, %24], strides = [1, 1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<?x?x?xf32>>{%22, %23, %24} -> tensor<?x?x?xf32>
+  %28 = tensor.empty(%22, %23, %24) : tensor<?x?x?xf32>
+  %dim = tensor.dim %27, %c0 : tensor<?x?x?xf32>
+  %dim_1 = tensor.dim %27, %c1 : tensor<?x?x?xf32>
+  %29 = tensor.empty(%dim, %dim_1) : tensor<?x?xf32>
+  %30 = linalg.fill ins(%cst_0 : f32) outs(%29 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %31 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%27 : tensor<?x?x?xf32>) outs(%30 : tensor<?x?xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %35 = arith.maxnumf %in, %out : f32
+    linalg.yield %35 : f32
+  } -> tensor<?x?xf32>
+  %32 = linalg.fill ins(%cst : f32) outs(%29 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %33 = linalg.generic {indexing_maps = [#map, #map1, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%27, %31 : tensor<?x?x?xf32>, tensor<?x?xf32>) outs(%32 : tensor<?x?xf32>) {
+  ^bb0(%in: f32, %in_2: f32, %out: f32):
+    %35 = arith.subf %in, %in_2 : f32
+    %36 = math.exp %35 : f32
+    %37 = arith.addf %36, %out : f32
+    linalg.yield %37 : f32
+  } -> tensor<?x?xf32>
+  %34 = linalg.generic {indexing_maps = [#map, #map1, #map1, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%27, %31, %33 : tensor<?x?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>) outs(%28 : tensor<?x?x?xf32>) {
+  ^bb0(%in: f32, %in_2: f32, %in_3: f32, %out: f32):
+    %35 = arith.subf %in, %in_2 : f32
+    %36 = math.exp %35 : f32
+    %37 = arith.divf %36, %in_3 : f32
+    linalg.yield %37 : f32
+  } -> tensor<?x?x?xf32>
+  iree_tensor_ext.dispatch.tensor.store %34, %26, offsets = [0, 0, 0], sizes = [%22, %23, %24], strides = [1, 1, 1] : tensor<?x?x?xf32> -> !iree_tensor_ext.dispatch.tensor<writeonly:tensor<?x?x?xf32>>{%22, %23, %24}
+  return
+}
+// CHECK-LABEL: func.func @softmax_dynamic_with_assume_int_hints(

compiler/src/iree/compiler/Codegen/LLVMCPU/test/tile.mlir

Lines changed: 40 additions & 2 deletions
@@ -1,7 +1,8 @@
+// `TilingLevel=0` indicates DistributionTiles in IREE::CPU::LoweringConfigAttr.
 // RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-llvmcpu-tile{tiling-level=0}))" --split-input-file %s | FileCheck %s
+// `TilingLevel=4` indicates VectorCommonParallelTiles in IREE::CPU::LoweringConfigAttr.
+// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-llvmcpu-tile{tiling-level=3 skip-root-op=true}))" --split-input-file %s | FileCheck %s --check-prefix=SKIP-ROOT

-// `tiling-level=0`, which is the testing value of the pass option, indicates
-// distribution level tiling.
 #config0 = #iree_cpu.lowering_config<distribution = [10, 20]>
 #config1 = #iree_codegen.lowering_config<tile_sizes = [[10, 20, 30]]>
 func.func @matmul_bias_add(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<?xf32>) -> tensor<?x?xf32> {
@@ -98,3 +99,40 @@ func.func @do_not_tile_ukernel(%arg0: tensor<?x?x16x1xf32>, %arg1: tensor<?x?x16
 // CHECK-LABEL: func.func @do_not_tile_ukernel
 // CHECK-NOT: scf.for
 // CHECK: iree_codegen.ukernel.generic
+
+// -----
+
+#config0 = #iree_cpu.lowering_config<vector_common_parallel = [10, 20]>
+#config1 = #iree_cpu.lowering_config<distribution = [10, 20, 30]>
+func.func @matmul_bias_add_skip_matmul(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<?xf32>) -> tensor<?x?xf32> {
+  %cst = arith.constant 0.0 : f32
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+  %d1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
+  %init = tensor.empty(%d0, %d1) : tensor<?x?xf32>
+  %0 = linalg.fill {lowering_config = #config0} ins(%cst : f32) outs(%init : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %1 = linalg.matmul {lowering_config = #config1}
+      ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+      outs(%0 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %2 = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1)-> (d0, d1)>],
+      iterator_types = ["parallel", "parallel"]}
+      ins(%1, %arg2 : tensor<?x?xf32>, tensor<?xf32>)
+      outs(%init : tensor<?x?xf32>) attrs = {lowering_config = #config0} {
+  ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
+    %3 = arith.addf %arg3, %arg4 : f32
+    linalg.yield %3 : f32
+  } -> tensor<?x?xf32>
+  return %2 : tensor<?x?xf32>
+}
+// SKIP-ROOT: func.func @matmul_bias_add_skip_matmul
+// SKIP-ROOT: scf.for
+// SKIP-ROOT: scf.for
+// SKIP-ROOT: linalg.fill
+// SKIP-ROOT: scf.yield
+// SKIP-ROOT: scf.yield
+// SKIP-ROOT: linalg.matmul
+// SKIP-ROOT: scf.for
+// SKIP-ROOT: scf.for
+// SKIP-ROOT: linalg.generic
