
Commit f8c6ebb

[CPU] Use option to tile with scf.forall in TileRootAndFuseProducerConsumer pass (#21198)
This PR is part of a larger change that moves LLVMCPUTileRootAndFuseProducerConsumer's tiling of parallel dimensions from scf.for to scf.forall, letting us take advantage of canonicalization patterns that are only available on scf.forall. Those patterns prevent the redundant stack allocations seen in #20792, and tiling parallel dimensions with scf.forall rather than scf.for is arguably also semantically cleaner. As part of that change, this PR enables the new TileRootAndFuseProducerConsumer option in the mmt4d pipeline so that it tiles with scf.forall. Because later patterns in the pipeline expect `scf.for`, this PR also adds a pass to the mmt4d pipeline that converts the `scf.forall` loops back to `scf.for`. PR 4/4 addressing #20792
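For context, a minimal sketch of the two loop structures involved (not taken from this PR or its tests): the @tile_with_for / @tile_with_forall functions, the linalg.exp payload, the tile size of 8, and the tensor shapes are illustrative assumptions, and the affine.min bounds a real tiling would emit for partial tiles are omitted (assume the dynamic dimension is a multiple of 8).

// Tiling a parallel dimension with scf.for: the destination tensor is
// threaded through iter_args and written back with tensor.insert_slice.
func.func @tile_with_for(%src: tensor<?x16xf32>, %init: tensor<?x16xf32>) -> tensor<?x16xf32> {
  %c0 = arith.constant 0 : index
  %c8 = arith.constant 8 : index
  %d0 = tensor.dim %src, %c0 : tensor<?x16xf32>
  %r = scf.for %i = %c0 to %d0 step %c8 iter_args(%acc = %init) -> (tensor<?x16xf32>) {
    %in_tile = tensor.extract_slice %src[%i, 0] [8, 16] [1, 1] : tensor<?x16xf32> to tensor<8x16xf32>
    %out_tile = tensor.extract_slice %acc[%i, 0] [8, 16] [1, 1] : tensor<?x16xf32> to tensor<8x16xf32>
    %tile = linalg.exp ins(%in_tile : tensor<8x16xf32>) outs(%out_tile : tensor<8x16xf32>) -> tensor<8x16xf32>
    %upd = tensor.insert_slice %tile into %acc[%i, 0] [8, 16] [1, 1] : tensor<8x16xf32> into tensor<?x16xf32>
    scf.yield %upd : tensor<?x16xf32>
  }
  return %r : tensor<?x16xf32>
}

// The same tiling with scf.forall: the destination is a shared_outs operand and
// the write-back is a tensor.parallel_insert_slice inside scf.forall.in_parallel,
// which is the structure the scf.forall canonicalization patterns operate on.
func.func @tile_with_forall(%src: tensor<?x16xf32>, %init: tensor<?x16xf32>) -> tensor<?x16xf32> {
  %c0 = arith.constant 0 : index
  %d0 = tensor.dim %src, %c0 : tensor<?x16xf32>
  %r = scf.forall (%i) = (0) to (%d0) step (8) shared_outs(%out = %init) -> (tensor<?x16xf32>) {
    %in_tile = tensor.extract_slice %src[%i, 0] [8, 16] [1, 1] : tensor<?x16xf32> to tensor<8x16xf32>
    %out_tile = tensor.extract_slice %out[%i, 0] [8, 16] [1, 1] : tensor<?x16xf32> to tensor<8x16xf32>
    %tile = linalg.exp ins(%in_tile : tensor<8x16xf32>) outs(%out_tile : tensor<8x16xf32>) -> tensor<8x16xf32>
    scf.forall.in_parallel {
      tensor.parallel_insert_slice %tile into %out[%i, 0] [8, 16] [1, 1] : tensor<8x16xf32> into tensor<?x16xf32>
    }
  }
  return %r : tensor<?x16xf32>
}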
1 parent 88a09f9 commit f8c6ebb

File tree

5 files changed: +64 −7 lines changed


compiler/src/iree/compiler/Codegen/Common/GenericVectorization.cpp

Lines changed: 1 addition & 3 deletions
@@ -143,9 +143,7 @@ void GenericVectorizationPass::runOnOperation() {
     std::optional<SizesAndScalableFlags> vectorSizesAndScalableDims =
         getVectorSizes(op, useConfiguredVectorSizes);
     if (vectorSizesAndScalableDims) {
-      auto [sizes, scalableDims] = *vectorSizesAndScalableDims;
-      vectorSizes.append(sizes.begin(), sizes.end());
-      scalableVecDims.append(scalableDims.begin(), scalableDims.end());
+      std::tie(vectorSizes, scalableVecDims) = *vectorSizesAndScalableDims;
     }
   }

compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUTileRootAndFuseProducerConsumer.cpp

Lines changed: 3 additions & 2 deletions
@@ -242,11 +242,12 @@ void LLVMCPUTileRootAndFuseProducerConsumer::runOnOperation() {
 } // namespace

 std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
-createLLVMCPUTileRootAndFuseProducerConsumer(int64_t tilingLevel) {
+createLLVMCPUTileRootAndFuseProducerConsumer(int64_t tilingLevel,
+                                             bool tileUsingForAll) {
   LLVMCPUTileRootAndFuseProducerConsumerPassOptions options;
   options.tilingLevel = tilingLevel;
   options.onlyFuseProducerInputOperands = false;
-  options.tileUsingForall = false;
+  options.tileUsingForall = tileUsingForAll;
   return std::make_unique<LLVMCPUTileRootAndFuseProducerConsumer>(options);
 }
 std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>

compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp

Lines changed: 4 additions & 1 deletion
@@ -532,17 +532,20 @@ void addMmt4dTilingExpertPassPipeline(OpPassManager &funcPassManager,
   addTileAndDistributePasses(funcPassManager);

   funcPassManager.addPass(createLLVMCPUTileRootAndFuseProducerConsumer(
-      static_cast<int64_t>(tilingConfig.getVectorCommonParallelLevel())));
+      static_cast<int64_t>(tilingConfig.getVectorCommonParallelLevel()),
+      /*tileUsingForall=*/true));
   // The below two passes are nop if the "mmt4d" is explicitly excluded in the
   // ukernels attribute.
   funcPassManager.addPass(createCPUPrepareUkernelsPass());
   funcPassManager.addPass(
       createCPULowerToUKernelsPass(clSkipIntermediateRoundings));
   funcPassManager.addPass(createLLVMCPUTileRootAndFuseInputOperands(
       static_cast<int64_t>(tilingConfig.getVectorReductionLevel())));
+  funcPassManager.addPass(iree_compiler::createForallToForPass());

   {
     GenericVectorizationPassOptions options;
+    options.useConfiguredVectorSizes = pipelineOpt.useConfiguredVectorSizes;
     options.enableVectorMasking = pipelineOpt.enableVectorMasking;
     options.vectorizePadding = true;
     options.vectorizeGatherAccesses = true;
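The createForallToForPass() added above rewrites the scf.forall loops produced by tiling back into sequential scf.for nests for the scf.for-based patterns that run later in the pipeline. As a rough sketch of the conversion's shape only: the %buf and %cst values, the static 4x8 memref, and the @before/@after functions are made up for illustration; the loops this pipeline actually converts carry tensor shared_outs from the tiling passes above, which are omitted here to keep the sketch small.

// Hypothetical input: an scf.forall over two normalized dimensions.
func.func @before(%buf: memref<4x8xf32>, %cst: f32) {
  %c4 = arith.constant 4 : index
  %c8 = arith.constant 8 : index
  scf.forall (%i, %j) in (%c4, %c8) {
    memref.store %cst, %buf[%i, %j] : memref<4x8xf32>
  }
  return
}

// Rough shape of the result: the same body wrapped in ordinary nested
// sequential scf.for loops.
func.func @after(%buf: memref<4x8xf32>, %cst: f32) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c4 = arith.constant 4 : index
  %c8 = arith.constant 8 : index
  scf.for %i = %c0 to %c4 step %c1 {
    scf.for %j = %c0 to %c8 step %c1 {
      memref.store %cst, %buf[%i, %j] : memref<4x8xf32>
    }
  }
  return
}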

compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h

Lines changed: 2 additions & 1 deletion
@@ -43,7 +43,8 @@ std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
 createLLVMCPUTileAndFusePass(int64_t tilingLevel);

 std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
-createLLVMCPUTileRootAndFuseProducerConsumer(int64_t tilingLevel);
+createLLVMCPUTileRootAndFuseProducerConsumer(int64_t tilingLevel,
+                                             bool tileUsingForall);

 std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
 createLLVMCPUTileRootAndFuseInputOperands(int64_t tilingLevel);

compiler/src/iree/compiler/Codegen/LLVMCPU/test/pipeline_tests.mlir

Lines changed: 54 additions & 0 deletions
@@ -424,3 +424,57 @@ func.func @fuse_inputs_reduction() attributes {hal.executable.target = #executab
 // CHECK: vector.load
 // CHECK-NOT: scf.for
 // CHECK: arith.addf
+
+// -----
+
+#executable_target_embedded_elf_x86_64 = #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64", {cpu = "generic", cpu_features = "", data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128", max_stack_allocation_size = 32768 : i64, native_vector_size = 16 : i64, target_triple = "x86_64-unknown-unknown-eabi-elf", ukernels = "none"}>
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d2)>
+#pipeline_layout = #hal.pipeline.layout<constants = 5, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>
+module {
+  func.func @mmt4d_bias_relu() attributes {hal.executable.target = #executable_target_embedded_elf_x86_64} {
+    %c0 = arith.constant 0 : index
+    %c32_i64 = arith.constant 32 : i64
+    %cst = arith.constant 0.000000e+00 : f32
+    %0 = hal.interface.constant.load layout(#pipeline_layout) ordinal(0) : i32
+    %1 = hal.interface.constant.load layout(#pipeline_layout) ordinal(1) : i32
+    %2 = hal.interface.constant.load layout(#pipeline_layout) ordinal(2) : i32
+    %3 = hal.interface.constant.load layout(#pipeline_layout) ordinal(3) : i32
+    %4 = hal.interface.constant.load layout(#pipeline_layout) ordinal(4) : i32
+    %5 = arith.index_castui %0 : i32 to index
+    %6 = arith.index_castui %1 : i32 to index
+    %7 = arith.index_castui %2 : i32 to index
+    %8 = arith.index_castui %3 : i32 to index
+    %9 = arith.index_castui %4 : i32 to index
+    %36 = iree_tensor_ext.dispatch.workload.ordinal %5, 0 : index
+    %37 = iree_tensor_ext.dispatch.workload.ordinal %6, 1 : index
+    %38 = iree_tensor_ext.dispatch.workload.ordinal %7, 2 : index
+    %39 = iree_tensor_ext.dispatch.workload.ordinal %8, 3 : index
+    %40 = iree_tensor_ext.dispatch.workload.ordinal %9, 4 : index
+    %41 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : !iree_tensor_ext.dispatch.tensor<readonly:tensor<?x?x16x1xf32>>{%39, %36}
+    %42 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : !iree_tensor_ext.dispatch.tensor<readonly:tensor<?x?x16x1xf32>>{%37, %40}
+    %43 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : !iree_tensor_ext.dispatch.tensor<readonly:tensor<?x16xf32>>{%38}
+    %44 = hal.interface.binding.subspan layout(#pipeline_layout) binding(3) alignment(64) offset(%c0) flags(Indirect) : !iree_tensor_ext.dispatch.tensor<writeonly:tensor<?x?x16x16xf32>>{%39, %40}
+    %45 = iree_tensor_ext.dispatch.tensor.load %41, offsets = [0, 0, 0, 0], sizes = [%39, %36, 16, 1], strides = [1, 1, 1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<?x?x16x1xf32>>{%39, %36} -> tensor<?x?x16x1xf32>
+    %46 = iree_tensor_ext.dispatch.tensor.load %42, offsets = [0, 0, 0, 0], sizes = [%37, %40, 16, 1], strides = [1, 1, 1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<?x?x16x1xf32>>{%37, %40} -> tensor<?x?x16x1xf32>
+    %47 = iree_tensor_ext.dispatch.tensor.load %43, offsets = [0, 0], sizes = [%38, 16], strides = [1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<?x16xf32>>{%38} -> tensor<?x16xf32>
+    %48 = tensor.empty(%39, %40) : tensor<?x?x16x16xf32>
+    %49 = linalg.fill ins(%cst : f32) outs(%48 : tensor<?x?x16x16xf32>) -> tensor<?x?x16x16xf32>
+    %50 = linalg.mmt4d ins(%45, %46 : tensor<?x?x16x1xf32>, tensor<?x?x16x1xf32>) outs(%49 : tensor<?x?x16x16xf32>) -> tensor<?x?x16x16xf32>
+    %51 = linalg.generic {indexing_maps = [#map, #map1, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%50, %47 : tensor<?x?x16x16xf32>, tensor<?x16xf32>) outs(%48 : tensor<?x?x16x16xf32>) {
+    ^bb0(%in: f32, %in_0: f32, %out: f32):
+      %52 = arith.addf %in, %in_0 : f32
+      %53 = arith.maximumf %52, %cst : f32
+      linalg.yield %53 : f32
+    } -> tensor<?x?x16x16xf32>
+    iree_tensor_ext.dispatch.tensor.store %51, %44, offsets = [0, 0, 0, 0], sizes = [%39, %40, 16, 16], strides = [1, 1, 1, 1] : tensor<?x?x16x16xf32> -> !iree_tensor_ext.dispatch.tensor<writeonly:tensor<?x?x16x16xf32>>{%39, %40}
+    return
+  }
+}
+// CHECK-LABEL: func.func @mmt4d_bias_relu
+// CHECK-NOT: memref.alloc
+// CHECK: scf.forall
+// CHECK: scf.for
+// CHECK: vector.fma
+// CHECK: vector.insert
+// CHECK: arith.addf
