@@ -262,6 +262,7 @@ void addGPUVectorizationPassPipeline(OpPassManager &funcPassManager) {
262262 funcPassManager.addPass (createGPUDistributePass ());
263263
264264 // Post bufferization optimizations.
265+ funcPassManager.addPass (createPropagateDispatchSizeBoundsPass ());
265266 funcPassManager.addPass (createIREELoopInvariantCodeMotionPass ());
266267 funcPassManager.addPass (memref::createFoldMemRefAliasOpsPass ());
267268 funcPassManager.addPass (createCanonicalizerPass ());
@@ -439,6 +440,7 @@ void addGPUTileAndFusePassPipeline(OpPassManager &funcPassManager,
439440 funcPassManager.addPass (createTileLargeTensorsPass ());
440441 funcPassManager.addPass (createCanonicalizerPass ());
441442 funcPassManager.addPass (createCSEPass ());
443+ funcPassManager.addPass (createPropagateDispatchSizeBoundsPass ());
442444 funcPassManager.addPass (createIREELoopInvariantCodeMotionPass ());
443445 funcPassManager.addPass (IREE::GPU::createCombineBarrierRegionsPass ());
444446
@@ -468,6 +470,7 @@ void addGPUTileAndFusePassPipeline(OpPassManager &funcPassManager,
468470 funcPassManager.addPass (createCSEPass ());
469471
470472 // Step 9. Remaining post-bufferization optimizations/lowerings.
473+ funcPassManager.addPass (createPropagateDispatchSizeBoundsPass ());
471474 funcPassManager.addPass (IREE::GPU::createLowerIREEGPUOpsPass ());
472475 funcPassManager.addPass (createUnrollAnnotatedLoopsPass ());
473476 funcPassManager.addPass (createIREELoopInvariantCodeMotionPass ());
@@ -524,6 +527,7 @@ void addGPUWinogradVectorizePassPipeline(OpPassManager &funcPassManager) {
524527 funcPassManager.addPass (createGPUDistributeScfForPass (options));
525528
526529 // Post bufferization optimizations.
530+ funcPassManager.addPass (createPropagateDispatchSizeBoundsPass ());
527531 funcPassManager.addPass (createIREELoopInvariantCodeMotionPass ());
528532 funcPassManager.addPass (memref::createFoldMemRefAliasOpsPass ());
529533 funcPassManager.addPass (createConfigTrackingCanonicalizerPass ());
@@ -544,6 +548,7 @@ void addGPUMatmulTensorCorePassPipeline(OpPassManager &funcPassManager,
544548 // Distribute linalg onto warps within the workgroup.
545549 funcPassManager.addPass (
546550 createLLVMGPUTileAndDistributePass (/* distributeToWarp=*/ true ));
551+ funcPassManager.addPass (createPropagateDispatchSizeBoundsPass ());
547552 funcPassManager.addPass (createRemoveSingleIterationLoopPass ());
548553 if (pipelineDepth > 1 ) {
549554 funcPassManager.addPass (createGPUMultiBufferingPass (
@@ -589,6 +594,7 @@ void addGPUMatmulTensorCorePassPipeline(OpPassManager &funcPassManager,
589594 funcPassManager.addPass (createCSEPass ());
590595
591596 // Hoist loop invariant code to avoid pipelining it.
597+ funcPassManager.addPass (createPropagateDispatchSizeBoundsPass ());
592598 funcPassManager.addPass (createIREELoopInvariantCodeMotionPass ());
593599 // Pipeline memory operations.
594600 GPUPipeliningPassOptions pipelieningOptions = {};
@@ -613,6 +619,7 @@ void addGPUMatmulTensorCoreMmaSyncPassPipeline(
613619 // Distribute linalg onto warps within the workgroup.
614620 funcPassManager.addPass (
615621 createLLVMGPUTileAndDistributePass (/* distributeToWarp=*/ true ));
622+ funcPassManager.addPass (createPropagateDispatchSizeBoundsPass ());
616623 funcPassManager.addPass (createRemoveSingleIterationLoopPass ());
617624 if (pipelineDepth > 1 ) {
618625 funcPassManager.addPass (createGPUMultiBufferingPass (
@@ -655,6 +662,7 @@ void addGPUMatmulTensorCoreMmaSyncPassPipeline(
655662 funcPassManager.addPass (createCSEPass ());
656663
657664 // Hoist loop invariant code to avoid pipelining it.
665+ funcPassManager.addPass (createPropagateDispatchSizeBoundsPass ());
658666 funcPassManager.addPass (createIREELoopInvariantCodeMotionPass ());
659667 // Pipeline memory operations.
660668 GPUPipeliningPassOptions pipelieningOptions = {};
@@ -882,6 +890,7 @@ void addGPUWarpReductionPassPipeline(OpPassManager &funcPassManager) {
882890 funcPassManager.addPass (createGPUTileReductionPass ());
883891 funcPassManager.addPass (createConfigTrackingCanonicalizerPass ());
884892 funcPassManager.addPass (createCSEPass ());
893+ funcPassManager.addPass (createPropagateDispatchSizeBoundsPass ());
885894
886895 // Linalg -> vector
887896 {
@@ -949,6 +958,7 @@ void addGPUSimpleDistributePassPipeline(OpPassManager &funcPassManager) {
949958 funcPassManager.addPass (createCanonicalizerPass ());
950959 funcPassManager.addPass (createCSEPass ());
951960
961+ funcPassManager.addPass (createPropagateDispatchSizeBoundsPass ());
952962 funcPassManager.addPass (createRemoveSingleIterationLoopPass ());
953963}
954964
@@ -965,6 +975,7 @@ void addGPUDefaultPassPipeline(OpPassManager &funcPassManager,
965975 funcPassManager.addPass (createCSEPass ());
966976
967977 addBufferizePasses (funcPassManager);
978+ funcPassManager.addPass (createPropagateDispatchSizeBoundsPass ());
968979 funcPassManager.addPass (createRemoveSingleIterationLoopPass ());
969980}
970981
@@ -981,6 +992,7 @@ void addGPUBaseLoweringPassPipeline(OpPassManager &funcPassManager) {
981992 funcPassManager.addPass (IREE::LinalgExt::createLinalgExtToLoopsPass ());
982993 funcPassManager.addPass (createMemrefCopyToLinalgPass ());
983994 funcPassManager.addPass (createConvertLinalgToLoopsPass ());
995+ funcPassManager.addPass (createPropagateDispatchSizeBoundsPass ());
984996 funcPassManager.addPass (createRemoveSingleIterationLoopPass ());
985997 funcPassManager.addPass (createCanonicalizerPass ());
986998 funcPassManager.addPass (createCSEPass ());
@@ -999,6 +1011,7 @@ addLowerAndOptimizeAddressComputationPasses(FunctionLikeNest &funcPassManager) {
9991011 .addPass (memref::createExpandOpsPass)
10001012 .addPass (memref::createFoldMemRefAliasOpsPass)
10011013 .addPass (memref::createExpandStridedMetadataPass)
1014+ .addPass (createPropagateDispatchSizeBoundsPass)
10021015 // Hoist loop invariant variables to give affine decomposition pass the
10031016 // right loop dependencies.
10041017 .addPass (createIREELoopInvariantCodeMotionPass)
@@ -1055,9 +1068,7 @@ static void addLowerToLLVMGPUPasses(OpPassManager &modulePassManager,
10551068 FunctionLikeNest funcPassManager (modulePassManager);
10561069 funcPassManager.addPass (createFoldTensorExtractOpPass)
10571070 .addPass (createLLVMGPUVectorLoweringPass)
1058- .addPass (createExpandGPUOpsPass)
1059- // Expose workitem and workgroup counts to range inference later.
1060- .addPass (createPropagateDispatchSizeBoundsPass);
1071+ .addPass (createExpandGPUOpsPass);
10611072
10621073 // This pass needs to run before SCF -> CF.
10631074 addLowerAndOptimizeAddressComputationPasses (funcPassManager);
0 commit comments