Skip to content

Commit 62f8d7e

Browse files
authored
Revert "[CPU] Enable tileDispatchUsingForall for mmt4d and convolution pipelines. " (#18707)
Reverts #18618. It breaks `PkgCI / Regression Test / test_models :: cpu_llvm_task (push)`. I think we should address the issue before landing the patch. Sample log: https://github.com/iree-org/iree/actions/runs/11199607624/job/31132512036
1 parent 8c9f2cb commit 62f8d7e

File tree

4 files changed

+38
-53
lines changed

4 files changed

+38
-53
lines changed

compiler/src/iree/compiler/Codegen/Common/ReconcileTranslationInfo.cpp

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
2121
#include "mlir/Dialect/Affine/Utils.h"
2222
#include "mlir/Dialect/Arith/Utils/Utils.h"
23-
2423
namespace mlir::iree_compiler {
2524

2625
#define GEN_PASS_DEF_RECONCILETRANSLATIONINFOPASS
@@ -264,6 +263,10 @@ static LogicalResult resolveWorkgroupForAll(RewriterBase &rewriter,
264263
return success();
265264
}
266265

266+
if (!llvm::hasSingleElement(body)) {
267+
return funcOp.emitOpError("unhandled function with multiple blocks");
268+
}
269+
267270
auto forAllOps = body.getOps<scf::ForallOp>();
268271
SmallVector<scf::ForallOp> workgroupForAllOps = llvm::to_vector(
269272
llvm::make_filter_range(forAllOps, [&](scf::ForallOp forAllOp) {
@@ -292,10 +295,6 @@ static LogicalResult resolveWorkgroupForAll(RewriterBase &rewriter,
292295
"scf.forall ops withing the function");
293296
}
294297

295-
if (!llvm::hasSingleElement(body)) {
296-
return funcOp.emitOpError("unhandled function with multiple blocks");
297-
}
298-
299298
scf::ForallOp forallOp = *forAllOps.begin();
300299
if (failed(resolveWorkgroupCount(rewriter, funcOp, forallOp))) {
301300
return failure();
@@ -360,10 +359,9 @@ void ReconcileTranslationInfoPass::runOnOperation() {
360359
auto innerModuleOp = variantOp.getInnerModule();
361360

362361
auto exportOps = variantOp.getOps<IREE::HAL::ExecutableExportOp>();
363-
364-
// reconciliation for multiple export ops is unsupported.
365362
if (!llvm::hasSingleElement(exportOps)) {
366-
return;
363+
variantOp.emitOpError("reconciliation for multiple export ops unsupported");
364+
return signalPassFailure();
367365
}
368366
auto exportOp = *exportOps.begin();
369367
IRRewriter rewriter(&getContext());

compiler/src/iree/compiler/Codegen/Common/test/reconcile_translation_info.mlir

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,18 @@
11
// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(iree-codegen-reconcile-translation-info, canonicalize)))" %s --verify-diagnostics --allow-unregistered-dialect | FileCheck %s
22

3+
#pipeline_layout = #hal.pipeline.layout<bindings = [
4+
#hal.pipeline.binding<storage_buffer>
5+
]>
6+
hal.executable private @err_multiple_entry_point {
7+
// expected-error @+1 {{reconciliation for multiple export ops unsupported}}
8+
hal.executable.variant public @reconcile_workgroup_size target(#hal.executable.target<"", "", {}>) {
9+
hal.executable.export public @entry_point1 layout(#pipeline_layout)
10+
hal.executable.export public @entry_point2 layout(#pipeline_layout)
11+
}
12+
}
13+
14+
// -----
15+
316
#pipeline_layout = #hal.pipeline.layout<bindings = [
417
#hal.pipeline.binding<storage_buffer>
518
]>

compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp

Lines changed: 17 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -91,11 +91,6 @@ static llvm::cl::opt<bool> clEnableVectorContractCustomKernels(
9191
"LLVMCPUMmt4dVectorLowering pass."),
9292
llvm::cl::init(false));
9393

94-
static llvm::cl::opt<bool> clTileDispatchUsingForall(
95-
"iree-llvmcpu-tile-dispatch-using-forall",
96-
llvm::cl::desc("Enable tile and distribute to workgroups using scf.forall"),
97-
llvm::cl::init(false));
98-
9994
// By default, IREE does not enable the Armv9-A streaming SVE mode in the
10095
// presence of scalable vectors (even when using `+sme`), as currently there's
10196
// no cost model of when it could be beneficial. This flag will effectively make
@@ -109,18 +104,11 @@ static llvm::cl::opt<bool> clForceArmStreaming(
109104
"than SVE). Requires the +sme feature flag."),
110105
llvm::cl::init(false));
111106

112-
// TODO: Enable `TileDispatchUsingForall` for every pipeline.
113-
static void addTileAndDistributePasses(OpPassManager &funcPassManager,
114-
bool enableTileDispatchUsingForall) {
115-
if (enableTileDispatchUsingForall || clTileDispatchUsingForall) {
116-
funcPassManager.addPass(
117-
createTileAndDistributeToWorkgroupsUsingForallOpPass());
118-
} else {
119-
funcPassManager.addPass(createTileAndDistributeToWorkgroupsPass());
120-
funcPassManager.addPass(createCSEPass());
121-
funcPassManager.addPass(createConvertToDestinationPassingStylePass());
122-
funcPassManager.addPass(createFoldAffineMinInDistributedLoopsPass());
123-
}
107+
static void addTileAndDistributePasses(OpPassManager &funcPassManager) {
108+
funcPassManager.addPass(createTileAndDistributeToWorkgroupsPass());
109+
funcPassManager.addPass(createCSEPass());
110+
funcPassManager.addPass(createConvertToDestinationPassingStylePass());
111+
funcPassManager.addPass(createFoldAffineMinInDistributedLoopsPass());
124112
funcPassManager.addPass(createCanonicalizerPass());
125113
funcPassManager.addPass(createCSEPass());
126114
funcPassManager.addPass(createFuseTensorPadWithConsumerPass());
@@ -345,8 +333,7 @@ void buildLLVMCPUVectorLoweringPipeline(
345333
void addCPUBufferOpsTileAndVectorizePipeline(
346334
OpPassManager &funcPassManager, TilingConfig &tilingConfig,
347335
LLVMCPUPipelineOptions &pipelineOpt) {
348-
addTileAndDistributePasses(funcPassManager,
349-
/*enableTileDispatchUsingForall=*/true);
336+
addTileAndDistributePasses(funcPassManager);
350337

351338
// Skip tiling reduction loops because this is expected to apply on copy ops
352339
// only.
@@ -383,8 +370,7 @@ void addCPUBufferOpsTileAndVectorizePipeline(
383370
void addMultiTilingExpertPassPipeline(OpPassManager &funcPassManager,
384371
TilingConfig &tilingConfig,
385372
LLVMCPUPipelineOptions &pipelineOpt) {
386-
addTileAndDistributePasses(funcPassManager,
387-
/*enableTileDispatchUsingForall=*/false);
373+
addTileAndDistributePasses(funcPassManager);
388374

389375
SmallVector<int64_t> allFusableLevels(tilingConfig.getFusableLevels());
390376
// Apply tile and fuse to all the non-distribution fusable levels. Skip
@@ -463,8 +449,7 @@ void addMultiTilingExpertPassPipeline(OpPassManager &funcPassManager,
463449
void addConvTileAndDecomposeExpertPassPipeline(
464450
OpPassManager &funcPassManager, TilingConfig &tilingConfig,
465451
LLVMCPUPipelineOptions &pipelineOpt) {
466-
addTileAndDistributePasses(funcPassManager,
467-
/*enableTileDispatchUsingForall=*/true);
452+
addTileAndDistributePasses(funcPassManager);
468453

469454
// Run LLVMTileAndFuse firstly in case that we have fill + conv + generic
470455
// ops. At this stage, we do not apply vectorization. The reduction dim won't
@@ -527,8 +512,7 @@ void addConvTileAndDecomposeExpertPassPipeline(
527512
void addMmt4dTilingExpertPassPipeline(OpPassManager &funcPassManager,
528513
TilingConfig &tilingConfig,
529514
LLVMCPUPipelineOptions &pipelineOpt) {
530-
addTileAndDistributePasses(funcPassManager,
531-
/*enableTileDispatchUsingForall=*/true);
515+
addTileAndDistributePasses(funcPassManager);
532516

533517
funcPassManager.addPass(createLLVMCPUTileAndFusePass(
534518
static_cast<int64_t>(tilingConfig.getVectorCommonParallelLevel())));
@@ -576,8 +560,7 @@ void addMmt4dTilingExpertPassPipeline(OpPassManager &funcPassManager,
576560
void addCPUDataTilingPipeline(OpPassManager &funcPassManager,
577561
TilingConfig &tilingConfig,
578562
LLVMCPUPipelineOptions &pipelineOpt) {
579-
addTileAndDistributePasses(funcPassManager,
580-
/*enableTileDispatchUsingForall=*/true);
563+
addTileAndDistributePasses(funcPassManager);
581564

582565
// The below two passes are nop if pack/unpack is not specified in ukernels
583566
// attribute. By default, they are disabled.
@@ -620,8 +603,7 @@ void addCPUDataTilingPipeline(OpPassManager &funcPassManager,
620603
void addCPULinalgExtTileAndVectorizePipeline(
621604
OpPassManager &funcPassManager, TilingConfig &tilingConfig,
622605
LLVMCPUPipelineOptions &pipelineOpt) {
623-
addTileAndDistributePasses(funcPassManager,
624-
/*enableTileDispatchUsingForall=*/false);
606+
addTileAndDistributePasses(funcPassManager);
625607
funcPassManager.addPass(
626608
createLLVMCPUTilePass(tilingConfig.getVectorCommonParallelLevel()));
627609
// TODO: Remove the pass once we have PartialReductionOpInterface implemented
@@ -660,8 +642,7 @@ void addCPULinalgExtTileAndVectorizePipeline(
660642
}
661643

662644
void addCPUDefaultPassPipeline(OpPassManager &funcPassManager) {
663-
addTileAndDistributePasses(funcPassManager,
664-
/*enableTileDispatchUsingForall=*/false);
645+
addTileAndDistributePasses(funcPassManager);
665646
addCPUBufferizePasses(funcPassManager);
666647
}
667648

@@ -809,21 +790,13 @@ void buildLLVMCPUCodegenConfigurationPassPipeline(
809790

810791
void buildLLVMCPUCodegenPassPipeline(OpPassManager &variantPassManager,
811792
bool enableAArch64SME) {
812-
813-
{
814-
OpPassManager &modulePassManager = variantPassManager.nest<ModuleOp>();
815-
modulePassManager.addPass(createLowerExecutableUsingTransformDialectPass());
816-
FunctionLikeNest(modulePassManager)
817-
.addPass(createLLVMCPULowerExecutableTargetPass);
818-
}
819-
820-
variantPassManager.addPass(createReconcileTranslationInfoPass());
793+
OpPassManager &modulePassManager = variantPassManager.nest<ModuleOp>();
794+
modulePassManager.addPass(createLowerExecutableUsingTransformDialectPass());
795+
FunctionLikeNest(modulePassManager)
796+
.addPass(createLLVMCPULowerExecutableTargetPass);
821797

822798
// Run conversion to LLVM at `ModuleOp` granularity.
823-
{
824-
OpPassManager &modulePassManager = variantPassManager.nest<ModuleOp>();
825-
addLowerToLLVMPasses(modulePassManager, enableAArch64SME);
826-
}
799+
addLowerToLLVMPasses(modulePassManager, enableAArch64SME);
827800
LLVM_DEBUG({
828801
llvm::dbgs() << "LLVMCPU codegen pass pipeline:\n";
829802
variantPassManager.printAsTextualPipeline(llvm::dbgs());

compiler/src/iree/compiler/Codegen/LLVMCPU/test/pipeline_tests.mlir

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,8 @@ func.func @ukernel_dispatch() attributes {hal.executable.target = #executable_ta
276276
}
277277
// CHECK-LABEL: func @ukernel_dispatch()
278278
// Checks scf.for for distribution loops.
279-
// CHECK: scf.forall
279+
// CHECK: scf.for
280+
// CHECK: scf.for
280281
// Checks scf.for for outer and inner parallel loops.
281282
// CHECK: scf.for
282283
// CHECK: scf.for

0 commit comments

Comments
 (0)