@@ -812,16 +812,16 @@ void xegpu::populateXeGPUSubgroupDistributePatterns(
812812}
813813
814814void XeGPUSubgroupDistributePass::runOnOperation () {
815- // Attach layout to operands.
815+ // Step 1: Attach layout to op operands.
816+ // TODO: The following assumptions are made:
817+ // 1) There are no layout conflicts.
818+ // 2) Any existing layout attributes attached to the operands are ignored.
816819 Operation *op = getOperation ();
817820 op->walk ([&](Operation *op) {
818821 for (OpOperand &operand : op->getOpOperands ()) {
819822 // Layouts are needed for vector type only.
820823 if (!isa<VectorType>(operand.get ().getType ()))
821824 continue ;
822- // If the operand already has a layout, skip it.
823- if (xegpu::getLayoutAttr (operand))
824- continue ;
825825
826826 xegpu::LayoutAttr layout = xegpu::getLayoutAttr (operand);
827827 if (!layout) {
@@ -833,8 +833,8 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
833833 xegpu::setLayoutAttr (operand, layout);
834834 }
835835 });
836- // Move all operations of a GPU function inside gpu.warp_execute_on_lane_0
837- // operation.
836+ // Step 2: Move all operations of a GPU function inside
837+ // gpu.warp_execute_on_lane_0 operation.
838838 {
839839 RewritePatternSet patterns (&getContext ());
840840 patterns.add <MoveFuncBodyToWarpExecuteOnLane0>(&getContext ());
@@ -853,7 +853,7 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
853853 }
854854 });
855855 }
856- // Apply subgroup to workitem distribution patterns.
856+ // Step 3: Apply subgroup to workitem distribution patterns.
857857 RewritePatternSet patterns (&getContext ());
858858 xegpu::populateXeGPUSubgroupDistributePatterns (patterns);
859859 // TODO: distributionFn and shuffleFn are not used at this point.
@@ -874,8 +874,8 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
874874 return ;
875875 }
876876
877- // Clean up UnrealizedConversionCastOps that were inserted due to tensor
878- // desc type mismatches created by using upstream distribution patterns
877+ // Step 4: Clean up UnrealizedConversionCastOps that were inserted due to
878+ // tensor desc type mismatches created by using upstream distribution patterns
879879 // (scf.for)
880880 getOperation ()->walk ([&](mlir::UnrealizedConversionCastOp op) {
881881 // We are only interested in UnrealizedConversionCastOps that were added
0 commit comments