Skip to content

Commit 76671e2

Browse files
committed
address comments
1 parent d43f0ec commit 76671e2

File tree

1 file changed

+9
-9
lines changed

1 file changed

+9
-9
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -812,16 +812,16 @@ void xegpu::populateXeGPUSubgroupDistributePatterns(
812812
}
813813

814814
void XeGPUSubgroupDistributePass::runOnOperation() {
815-
// Attach layout to operands.
815+
// Step 1: Attach layout to op operands.
816+
// TODO: Following assumptions are made:
817+
// 1) It is assumed that there are no layout conflicts.
818+
// 2) Any existing layout attributes attached to the operands are ignored.
816819
Operation *op = getOperation();
817820
op->walk([&](Operation *op) {
818821
for (OpOperand &operand : op->getOpOperands()) {
819822
// Layouts are needed for vector type only.
820823
if (!isa<VectorType>(operand.get().getType()))
821824
continue;
822-
// If the operand already has a layout, skip it.
823-
if (xegpu::getLayoutAttr(operand))
824-
continue;
825825

826826
xegpu::LayoutAttr layout = xegpu::getLayoutAttr(operand);
827827
if (!layout) {
@@ -833,8 +833,8 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
833833
xegpu::setLayoutAttr(operand, layout);
834834
}
835835
});
836-
// Move all operations of a GPU function inside gpu.warp_execute_on_lane_0
837-
// operation.
836+
// Step 2: Move all operations of a GPU function inside
837+
// gpu.warp_execute_on_lane_0 operation.
838838
{
839839
RewritePatternSet patterns(&getContext());
840840
patterns.add<MoveFuncBodyToWarpExecuteOnLane0>(&getContext());
@@ -853,7 +853,7 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
853853
}
854854
});
855855
}
856-
// Apply subgroup to workitem distribution patterns.
856+
// Step 3: Finally, Apply subgroup to workitem distribution patterns.
857857
RewritePatternSet patterns(&getContext());
858858
xegpu::populateXeGPUSubgroupDistributePatterns(patterns);
859859
// TODO: distributionFn and shuffleFn are not used at this point.
@@ -874,8 +874,8 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
874874
return;
875875
}
876876

877-
// Clean up UnrealizedConversionCastOps that were inserted due to tensor
878-
// desc type mismatches created by using upstream distribution patterns
877+
// Step 4: Clean up UnrealizedConversionCastOps that were inserted due to
878+
// tensor desc type mismatches created by using upstream distribution patterns
879879
// (scf.for)
880880
getOperation()->walk([&](mlir::UnrealizedConversionCastOp op) {
881881
// We are only interested in UnrealizedConversionCastOps there were added

0 commit comments

Comments
 (0)