Skip to content

Commit 9aae362

Browse files
authored
[Codegen] Sprinkle in PropagateDispatchSizeBounds passes (iree-org#19677)
Since the various tiling and distribution don't know how to set the upper bounds on workitem or workgroup IDs - even if that information is known from context, we use the PropagateDispatchSizeBounds pass to add that information before passes that use it. The mani passes that use this information are those that use the ValueBoundsOpInterface - that is, loop invariant code motion, some vectorization code, and, in an upcoming commit, RemoveSingleIterationLoop. These calls can be removed in the future, but they'll do for now.
1 parent a64d713 commit 9aae362

File tree

7 files changed

+40
-15
lines changed

7 files changed

+40
-15
lines changed

compiler/src/iree/compiler/Codegen/Common/GPU/WorkgroupReordering.cpp

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,8 @@ static std::pair<Value, Value> makeTransposedIds(Location loc, OpBuilder b,
4444
/// Returns the workgroup counts along the X and Y dimensions. These will be
4545
/// constants when static in the corresponding `hal.executable.export` op.
4646
static std::pair<Value, Value>
47-
getWorkgroupCountsXY(OpBuilder &builder, FunctionOpInterface funcOp) {
47+
getWorkgroupCountsXY(OpBuilder &builder, FunctionOpInterface funcOp,
48+
std::optional<APInt> xBound, std::optional<APInt> yBound) {
4849
Location loc = funcOp.getLoc();
4950
SmallVector<int64_t> workgroupCounts = getStaticNumWorkgroups(funcOp);
5051
bool isStaticWgCount = llvm::none_of(workgroupCounts, ShapedType::isDynamic);
@@ -62,9 +63,9 @@ getWorkgroupCountsXY(OpBuilder &builder, FunctionOpInterface funcOp) {
6263

6364
LLVM_DEBUG(llvm::dbgs() << "Using dynamic workgroup counts\n");
6465
Value dynamicCountX =
65-
builder.create<IREE::HAL::InterfaceWorkgroupCountOp>(loc, 0);
66+
builder.create<IREE::HAL::InterfaceWorkgroupCountOp>(loc, 0, xBound);
6667
Value dynamicCountY =
67-
builder.create<IREE::HAL::InterfaceWorkgroupCountOp>(loc, 1);
68+
builder.create<IREE::HAL::InterfaceWorkgroupCountOp>(loc, 1, yBound);
6869
return {dynamicCountX, dynamicCountY};
6970
}
7071

@@ -100,11 +101,12 @@ reorderWorkgroupsInFunc(FunctionOpInterface funcOp,
100101
// that to RAUW the old ones. This way we don't have to worry about the
101102
// picking the exact insertion points that do not violate dominance between
102103
// their defs and users.
103-
Value workgroupIdX =
104-
builder.create<IREE::HAL::InterfaceWorkgroupIDOp>(funcOp.getLoc(), 0);
105-
Value workgroupIdY =
106-
builder.create<IREE::HAL::InterfaceWorkgroupIDOp>(funcOp.getLoc(), 1);
107-
auto [workgroupCntX, workgroupCntY] = getWorkgroupCountsXY(builder, funcOp);
104+
Value workgroupIdX = builder.create<IREE::HAL::InterfaceWorkgroupIDOp>(
105+
funcOp.getLoc(), 0, oldXId.getUpperBound());
106+
Value workgroupIdY = builder.create<IREE::HAL::InterfaceWorkgroupIDOp>(
107+
funcOp.getLoc(), 1, oldYId.getUpperBound());
108+
auto [workgroupCntX, workgroupCntY] = getWorkgroupCountsXY(
109+
builder, funcOp, oldXId.getUpperBound(), oldYId.getUpperBound());
108110
Value newWorkgroupIdX;
109111
Value newWorkgroupIdY;
110112
assert(strategy == ReorderWorkgroupsStrategy::Transpose &&

compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ static void addTileAndDistributePasses(OpPassManager &funcPassManager) {
126126
funcPassManager.addPass(createCSEPass());
127127
funcPassManager.addPass(createFuseTensorPadWithConsumerPass());
128128
funcPassManager.addPass(createConcretizePadResultShapePass());
129+
funcPassManager.addPass(createPropagateDispatchSizeBoundsPass());
129130
}
130131

131132
//===---------------------------------------------------------------------===//
@@ -447,6 +448,7 @@ void addMultiTilingExpertPassPipeline(OpPassManager &funcPassManager,
447448
addCPUBufferizePasses(funcPassManager);
448449

449450
// Run IREE specific passes before vector lowering expert.
451+
funcPassManager.addPass(createPropagateDispatchSizeBoundsPass());
450452
funcPassManager.addPass(createRemoveSingleIterationLoopPass());
451453

452454
{
@@ -510,6 +512,7 @@ void addConvTileAndDecomposeExpertPassPipeline(
510512
addCPUBufferizePasses(funcPassManager);
511513

512514
// Run IREE specific passes before vector lowering expert.
515+
funcPassManager.addPass(createPropagateDispatchSizeBoundsPass());
513516
funcPassManager.addPass(createRemoveSingleIterationLoopPass());
514517

515518
{

compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,7 @@ void addGPUVectorizationPassPipeline(OpPassManager &funcPassManager) {
262262
funcPassManager.addPass(createGPUDistributePass());
263263

264264
// Post bufferization optimizations.
265+
funcPassManager.addPass(createPropagateDispatchSizeBoundsPass());
265266
funcPassManager.addPass(createIREELoopInvariantCodeMotionPass());
266267
funcPassManager.addPass(memref::createFoldMemRefAliasOpsPass());
267268
funcPassManager.addPass(createCanonicalizerPass());
@@ -439,6 +440,7 @@ void addGPUTileAndFusePassPipeline(OpPassManager &funcPassManager,
439440
funcPassManager.addPass(createTileLargeTensorsPass());
440441
funcPassManager.addPass(createCanonicalizerPass());
441442
funcPassManager.addPass(createCSEPass());
443+
funcPassManager.addPass(createPropagateDispatchSizeBoundsPass());
442444
funcPassManager.addPass(createIREELoopInvariantCodeMotionPass());
443445
funcPassManager.addPass(IREE::GPU::createCombineBarrierRegionsPass());
444446

@@ -468,6 +470,7 @@ void addGPUTileAndFusePassPipeline(OpPassManager &funcPassManager,
468470
funcPassManager.addPass(createCSEPass());
469471

470472
// Step 9. Remaining post-bufferization optimizations/lowerings.
473+
funcPassManager.addPass(createPropagateDispatchSizeBoundsPass());
471474
funcPassManager.addPass(IREE::GPU::createLowerIREEGPUOpsPass());
472475
funcPassManager.addPass(createUnrollAnnotatedLoopsPass());
473476
funcPassManager.addPass(createIREELoopInvariantCodeMotionPass());
@@ -524,6 +527,7 @@ void addGPUWinogradVectorizePassPipeline(OpPassManager &funcPassManager) {
524527
funcPassManager.addPass(createGPUDistributeScfForPass(options));
525528

526529
// Post bufferization optimizations.
530+
funcPassManager.addPass(createPropagateDispatchSizeBoundsPass());
527531
funcPassManager.addPass(createIREELoopInvariantCodeMotionPass());
528532
funcPassManager.addPass(memref::createFoldMemRefAliasOpsPass());
529533
funcPassManager.addPass(createConfigTrackingCanonicalizerPass());
@@ -544,6 +548,7 @@ void addGPUMatmulTensorCorePassPipeline(OpPassManager &funcPassManager,
544548
// Distribute linalg onto warps within the workgroup.
545549
funcPassManager.addPass(
546550
createLLVMGPUTileAndDistributePass(/*distributeToWarp=*/true));
551+
funcPassManager.addPass(createPropagateDispatchSizeBoundsPass());
547552
funcPassManager.addPass(createRemoveSingleIterationLoopPass());
548553
if (pipelineDepth > 1) {
549554
funcPassManager.addPass(createGPUMultiBufferingPass(
@@ -589,6 +594,7 @@ void addGPUMatmulTensorCorePassPipeline(OpPassManager &funcPassManager,
589594
funcPassManager.addPass(createCSEPass());
590595

591596
// Hoist loop invariant code to avoid pipelining it.
597+
funcPassManager.addPass(createPropagateDispatchSizeBoundsPass());
592598
funcPassManager.addPass(createIREELoopInvariantCodeMotionPass());
593599
// Pipeline memory operations.
594600
GPUPipeliningPassOptions pipelieningOptions = {};
@@ -613,6 +619,7 @@ void addGPUMatmulTensorCoreMmaSyncPassPipeline(
613619
// Distribute linalg onto warps within the workgroup.
614620
funcPassManager.addPass(
615621
createLLVMGPUTileAndDistributePass(/*distributeToWarp=*/true));
622+
funcPassManager.addPass(createPropagateDispatchSizeBoundsPass());
616623
funcPassManager.addPass(createRemoveSingleIterationLoopPass());
617624
if (pipelineDepth > 1) {
618625
funcPassManager.addPass(createGPUMultiBufferingPass(
@@ -655,6 +662,7 @@ void addGPUMatmulTensorCoreMmaSyncPassPipeline(
655662
funcPassManager.addPass(createCSEPass());
656663

657664
// Hoist loop invariant code to avoid pipelining it.
665+
funcPassManager.addPass(createPropagateDispatchSizeBoundsPass());
658666
funcPassManager.addPass(createIREELoopInvariantCodeMotionPass());
659667
// Pipeline memory operations.
660668
GPUPipeliningPassOptions pipelieningOptions = {};
@@ -882,6 +890,7 @@ void addGPUWarpReductionPassPipeline(OpPassManager &funcPassManager) {
882890
funcPassManager.addPass(createGPUTileReductionPass());
883891
funcPassManager.addPass(createConfigTrackingCanonicalizerPass());
884892
funcPassManager.addPass(createCSEPass());
893+
funcPassManager.addPass(createPropagateDispatchSizeBoundsPass());
885894

886895
// Linalg -> vector
887896
{
@@ -949,6 +958,7 @@ void addGPUSimpleDistributePassPipeline(OpPassManager &funcPassManager) {
949958
funcPassManager.addPass(createCanonicalizerPass());
950959
funcPassManager.addPass(createCSEPass());
951960

961+
funcPassManager.addPass(createPropagateDispatchSizeBoundsPass());
952962
funcPassManager.addPass(createRemoveSingleIterationLoopPass());
953963
}
954964

@@ -965,6 +975,7 @@ void addGPUDefaultPassPipeline(OpPassManager &funcPassManager,
965975
funcPassManager.addPass(createCSEPass());
966976

967977
addBufferizePasses(funcPassManager);
978+
funcPassManager.addPass(createPropagateDispatchSizeBoundsPass());
968979
funcPassManager.addPass(createRemoveSingleIterationLoopPass());
969980
}
970981

@@ -981,6 +992,7 @@ void addGPUBaseLoweringPassPipeline(OpPassManager &funcPassManager) {
981992
funcPassManager.addPass(IREE::LinalgExt::createLinalgExtToLoopsPass());
982993
funcPassManager.addPass(createMemrefCopyToLinalgPass());
983994
funcPassManager.addPass(createConvertLinalgToLoopsPass());
995+
funcPassManager.addPass(createPropagateDispatchSizeBoundsPass());
984996
funcPassManager.addPass(createRemoveSingleIterationLoopPass());
985997
funcPassManager.addPass(createCanonicalizerPass());
986998
funcPassManager.addPass(createCSEPass());
@@ -999,6 +1011,7 @@ addLowerAndOptimizeAddressComputationPasses(FunctionLikeNest &funcPassManager) {
9991011
.addPass(memref::createExpandOpsPass)
10001012
.addPass(memref::createFoldMemRefAliasOpsPass)
10011013
.addPass(memref::createExpandStridedMetadataPass)
1014+
.addPass(createPropagateDispatchSizeBoundsPass)
10021015
// Hoist loop invariant variables to give affine decomposition pass the
10031016
// right loop dependencies.
10041017
.addPass(createIREELoopInvariantCodeMotionPass)
@@ -1055,9 +1068,7 @@ static void addLowerToLLVMGPUPasses(OpPassManager &modulePassManager,
10551068
FunctionLikeNest funcPassManager(modulePassManager);
10561069
funcPassManager.addPass(createFoldTensorExtractOpPass)
10571070
.addPass(createLLVMGPUVectorLoweringPass)
1058-
.addPass(createExpandGPUOpsPass)
1059-
// Expose workitem and workgroup counts to range inference later.
1060-
.addPass(createPropagateDispatchSizeBoundsPass);
1071+
.addPass(createExpandGPUOpsPass);
10611072

10621073
// This pass needs to run before SCF -> CF.
10631074
addLowerAndOptimizeAddressComputationPasses(funcPassManager);

compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,7 @@ static void addLoopMaterializationPasses(OpPassManager &funcPassManager) {
163163
funcPassManager.addPass(IREE::LinalgExt::createLinalgExtToLoopsPass());
164164
funcPassManager.addPass(createMemrefCopyToLinalgPass());
165165
funcPassManager.addPass(createConvertLinalgToLoopsPass());
166+
funcPassManager.addPass(createPropagateDispatchSizeBoundsPass());
166167
funcPassManager.addPass(createRemoveSingleIterationLoopPass());
167168
funcPassManager.addPass(createCanonicalizerPass());
168169
funcPassManager.addPass(createCSEPass());
@@ -394,6 +395,7 @@ void addSPIRVCooperativeMatrixVectorizePassPipeline(
394395
funcPassManager.addPass(
395396
createSPIRVTileAndPromotePass(SPIRVTileAndPromotePassOptions{
396397
/*promoteCMatrix=*/true, /*skipThreadLevel=*/true}));
398+
funcPassManager.addPass(createPropagateDispatchSizeBoundsPass());
397399
funcPassManager.addPass(createRemoveSingleIterationLoopPass());
398400
// Run canonicalization patterns to propagate constant shape sizes after
399401
// removing trip-one loops.
@@ -421,6 +423,7 @@ void addSPIRVCooperativeMatrixVectorizePassPipeline(
421423
funcPassManager.addPass(createGPUReduceBankConflictsPass(options));
422424
}
423425

426+
funcPassManager.addPass(createPropagateDispatchSizeBoundsPass());
424427
// Performs high-level n-D mechanical vectorization. This does not perform
425428
// unrolling or lowering, which is done later.
426429
{
@@ -513,6 +516,7 @@ void addSPIRVMatmulPromoteVectorizePassPipeline(OpPassManager &funcPassManager,
513516
funcPassManager.addPass(createGPUDistributeSharedMemoryCopyPass());
514517
funcPassManager.addPass(createCanonicalizerPass());
515518
funcPassManager.addPass(createCSEPass());
519+
funcPassManager.addPass(createPropagateDispatchSizeBoundsPass());
516520

517521
{
518522
GPUReduceBankConflictsPassOptions options = {};
@@ -532,6 +536,7 @@ void addSPIRVMatmulPromoteVectorizePassPipeline(OpPassManager &funcPassManager,
532536
funcPassManager.addPass(createForOpCanonicalizationPass());
533537
funcPassManager.addPass(createCanonicalizerPass());
534538
funcPassManager.addPass(createCSEPass());
539+
funcPassManager.addPass(createPropagateDispatchSizeBoundsPass());
535540
funcPassManager.addPass(createOptimizeVectorTransferPass());
536541

537542
// Hoist loop invariant code to avoid pipelining it.
@@ -560,6 +565,7 @@ void addSPIRVSubgroupReducePassPipeline(OpPassManager &funcPassManager) {
560565
funcPassManager.addPass(createGPUTileReductionPass());
561566
funcPassManager.addPass(createCanonicalizerPass());
562567
funcPassManager.addPass(createCSEPass());
568+
funcPassManager.addPass(createPropagateDispatchSizeBoundsPass());
563569

564570
// Performs high-level n-D mechanical vectorization. This does not perform
565571
// unrolling or lowering, which is done later.
@@ -588,6 +594,7 @@ void addSPIRVSubgroupReducePassPipeline(OpPassManager &funcPassManager) {
588594

589595
// Perform various vector-level cross-op optimizations like load-store
590596
// forwarding, shape casting and casting op cancelling.
597+
funcPassManager.addPass(createPropagateDispatchSizeBoundsPass());
591598
funcPassManager.addPass(createOptimizeVectorTransferPass());
592599

593600
// Simplify the IR for vector distribution.

compiler/src/iree/compiler/Codegen/SPIRV/test/lowering_reduction.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -142,8 +142,8 @@ func.func @warp_reduction_dispatch() attributes {hal.executable.target = #execut
142142

143143
// CHECK-DAG: %[[F0:.+]] = arith.constant 0.000000e+00 : f16
144144

145-
// CHECK-DAG: %[[WGIDX:.+]] = hal.interface.workgroup.id[0] : index
146-
// CHECK-DAG: %[[WGIDY:.+]] = hal.interface.workgroup.id[1] : index
145+
// CHECK-DAG: %[[WGIDX:.+]] = hal.interface.workgroup.id[0] upper_bound 65535 : index
146+
// CHECK-DAG: %[[WGIDY:.+]] = hal.interface.workgroup.id[1] upper_bound 65535 : index
147147
// CHECK-DAG: %[[TIDX:.+]] = gpu.thread_id x
148148

149149
// CHECK-DAG: %[[SPAN0:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(0)

compiler/src/iree/compiler/Codegen/VMVX/Passes.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ void addVMVXDefaultPassPipeline(OpPassManager &funcPassManager,
7171
addCPUBufferizePasses(funcPassManager);
7272

7373
// Cleanup the IR that may now have unused loops.
74+
funcPassManager.addPass(createPropagateDispatchSizeBoundsPass());
7475
funcPassManager.addPass(createRemoveSingleIterationLoopPass());
7576

7677
// Convert buffer-level microkernels.

compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3039,9 +3039,10 @@ class HAL_InterfaceWorkgroupOp<string mnemonic, list<Trait> traits = []>
30393039
let results = (outs HAL_Dim:$result);
30403040

30413041
let builders = [
3042-
OpBuilder<(ins "unsigned":$dim),
3042+
OpBuilder<(ins "unsigned":$dim, CArg<"std::optional<::llvm::APInt>", "std::nullopt">:$upper_bound),
30433043
[{
3044-
build($_builder, $_state, $_builder.getIndexType(), $_builder.getIndexAttr(dim), ::mlir::IntegerAttr{});
3044+
build($_builder, $_state, $_builder.getIndexType(), $_builder.getIndexAttr(dim),
3045+
upper_bound.has_value() ? $_builder.getIndexAttr(upper_bound->getSExtValue()) : ::mlir::IntegerAttr{});
30453046
}]>,
30463047
];
30473048

0 commit comments

Comments
 (0)