Skip to content

Commit 7029025

Browse files
committed
save work
1 parent 519d02a commit 7029025

File tree

1 file changed

+9
-5
lines changed

1 file changed

+9
-5
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1008,11 +1008,6 @@ struct MoveFuncBodyToWarpExecuteOnLane0
10081008
rewriter.setInsertionPointAfter(warpOp);
10091009
rewriter.create<gpu::ReturnOp>(newGpuFunc.getLoc(), warpOp.getResults());
10101010
rewriter.replaceOp(gpuFuncOp, newGpuFunc);
1011-
// At this point, we have moved the entire function body inside the warpOp.
1012-
// Now move any scalar uniform code outside of the warpOp (like GPU index
1013-
// ops, scalar constants, etc.). This will simplify the later lowering and
1014-
// avoid custom patterns for these ops.
1015-
vector::moveScalarUniformCode(warpOp);
10161011
return success();
10171012
}
10181013
};
@@ -1468,6 +1463,15 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
14681463
signalPassFailure();
14691464
return;
14701465
}
1466+
// At this point, we have moved the entire function body inside the warpOp.
1467+
// Now move any scalar uniform code outside of the warpOp (like GPU index
1468+
// ops, scalar constants, etc.). This will simplify the later lowering and
1469+
// avoid custom patterns for these ops.
1470+
getOperation()->walk([&](Operation *op) {
1471+
if (auto warpOp = dyn_cast<gpu::WarpExecuteOnLane0Op>(op)) {
1472+
vector::moveScalarUniformCode(warpOp);
1473+
}
1474+
});
14711475
}
14721476
// Finally, do the SIMD to SIMT distribution.
14731477
RewritePatternSet patterns(&getContext());

0 commit comments

Comments
 (0)