@@ -1008,11 +1008,6 @@ struct MoveFuncBodyToWarpExecuteOnLane0
10081008 rewriter.setInsertionPointAfter (warpOp);
10091009 rewriter.create <gpu::ReturnOp>(newGpuFunc.getLoc (), warpOp.getResults ());
10101010 rewriter.replaceOp (gpuFuncOp, newGpuFunc);
1011- // At this point, we have moved the entire function body inside the warpOp.
1012- // Now move any scalar uniform code outside of the warpOp (like GPU index
1013- // ops, scalar constants, etc.). This will simplify the later lowering and
1014- // avoid custom patterns for these ops.
1015- vector::moveScalarUniformCode (warpOp);
10161011 return success ();
10171012 }
10181013};
@@ -1468,6 +1463,15 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
14681463 signalPassFailure ();
14691464 return ;
14701465 }
1466+ // At this point, we have moved the entire function body inside the warpOp.
1467+ // Now move any scalar uniform code outside of the warpOp (like GPU index
1468+ // ops, scalar constants, etc.). This will simplify the later lowering and
1469+ // avoid custom patterns for these ops.
1470+ getOperation ()->walk ([&](Operation *op) {
1471+ if (auto warpOp = dyn_cast<gpu::WarpExecuteOnLane0Op>(op)) {
1472+ vector::moveScalarUniformCode (warpOp);
1473+ }
1474+ });
14711475 }
14721476 // Finally, do the SIMD to SIMT distribution.
14731477 RewritePatternSet patterns (&getContext ());
0 commit comments