@@ -136,19 +136,6 @@ static Value resolveDistributedTy(Value orig, T expected,
136136 return orig;
137137}
138138
139- // / Helper function to filter out the temporary layout attributes attached
140- // / during the layout assignment process. These are not needed after going to
141- // / SIMT.
142- static SmallVector<NamedAttribute>
143- removeTemporaryLayoutAttributes (ArrayRef<NamedAttribute> attrs) {
144- SmallVector<NamedAttribute> newAttrs;
145- for (NamedAttribute attr : attrs) {
146- if (!isa<xegpu::LayoutAttr>(attr.getValue ()))
147- newAttrs.push_back (attr);
148- }
149- return newAttrs;
150- }
151-
152139// / Helper function to check if the layout is packed. Layout is packed if it is
153140// / 2D and lane_data[0] != 1 (data packed from col dimension).
154141static bool hasPackedLayout (xegpu::LayoutAttr layout) {
@@ -412,9 +399,9 @@ struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
412399 resolveDistributedTy (newWarpOp.getResult (newRetIndices[1 ]),
413400 distributedTensorDescTy, rewriter));
414401
415- rewriter.create <xegpu::StoreNdOp>(
416- newWarpOp.getLoc (), TypeRange{}, newStoreOperands,
417- removeTemporaryLayoutAttributes (storeOp-> getAttrs ()) );
402+ auto newStoreOp = rewriter.create <xegpu::StoreNdOp>(
403+ newWarpOp.getLoc (), TypeRange{}, newStoreOperands, storeOp-> getAttrs ());
404+ xegpu::removeLayoutAttrs (newStoreOp );
418405 rewriter.eraseOp (storeOp);
419406 return success ();
420407 }
@@ -508,7 +495,8 @@ struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
508495 newWarpOp.getLoc (), loadNdDistValueTyOrFailure.value (),
509496 resolveDistributedTy (newWarpOp->getResult (newRetIndices[0 ]),
510497 distributedTensorDescTy, rewriter),
511- removeTemporaryLayoutAttributes (loadOp->getAttrs ()));
498+ loadOp->getAttrs ());
499+ xegpu::removeLayoutAttrs (newLoadOp);
512500 // Set the packed attribute if the layout requires it.
513501 newLoadOp.setPacked (hasPackedLayout (layout));
514502 Value distributedVal = newWarpOp.getResult (operandIdx);
@@ -639,14 +627,16 @@ struct DpasDistribution final : public gpu::WarpDistributionPattern {
639627 resolveDistributedTy (newWarpOp.getResult (newRetIndices[i]),
640628 newDpasOperandExpectedTypes[i], rewriter));
641629 }
642- Value newDpasOp = rewriter.create <xegpu::DpasOp>(
643- newWarpOp->getLoc (), distributedResultTy, newDpasOperands,
644- removeTemporaryLayoutAttributes (dpasOp->getAttrs ()));
630+ auto newDpasOp =
631+ rewriter.create <xegpu::DpasOp>(newWarpOp->getLoc (), distributedResultTy,
632+ newDpasOperands, dpasOp->getAttrs ());
633+ xegpu::removeLayoutAttrs (newDpasOp);
645634 Value distributedVal = newWarpOp.getResult (operandIdx);
646635 // Resolve the output type.
647- newDpasOp = resolveDistributedTy (
648- newDpasOp, distResultTypeByWarpOpOrFailure.value (), rewriter);
649- rewriter.replaceAllUsesWith (distributedVal, newDpasOp);
636+ Value typeResolved =
637+ resolveDistributedTy (newDpasOp.getResult (),
638+ distResultTypeByWarpOpOrFailure.value (), rewriter);
639+ rewriter.replaceAllUsesWith (distributedVal, typeResolved);
650640 return success ();
651641 }
652642};
@@ -726,14 +716,15 @@ struct UpdateNdOffsetDistribution final : public gpu::WarpDistributionPattern {
726716 }
727717 }
728718 // Create a new update op outside the warp op.
729- Value newUpdateOp = rewriter.create <xegpu::UpdateNdOffsetOp>(
719+ auto newUpdateOp = rewriter.create <xegpu::UpdateNdOffsetOp>(
730720 newWarpOp.getLoc (), newTensorDescTy, newUpdateOperands,
731- removeTemporaryLayoutAttributes (updateOp->getAttrs ()));
721+ updateOp->getAttrs ());
722+ xegpu::removeLayoutAttrs (newUpdateOp);
732723 Value distributedVal = newWarpOp.getResult (operandIdx);
733724 // Resolve the distributed type with the original type.
734- newUpdateOp =
735- resolveDistributedTy ( newUpdateOp, distributedVal.getType (), rewriter);
736- rewriter.replaceAllUsesWith (distributedVal, newUpdateOp );
725+ Value typeResolved = resolveDistributedTy (
726+ newUpdateOp. getResult () , distributedVal.getType (), rewriter);
727+ rewriter.replaceAllUsesWith (distributedVal, typeResolved );
737728 return success ();
738729 }
739730};
@@ -792,9 +783,10 @@ struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
792783 rewriter.setInsertionPointAfter (newWarpOp);
793784 SmallVector<Value> newPrefetchOperands = {resolveDistributedTy (
794785 newWarpOp.getResult (newRetIndices[0 ]), newTensorDescTy, rewriter)};
795- rewriter.create <xegpu::PrefetchNdOp>(
796- newWarpOp.getLoc (), TypeRange{}, newPrefetchOperands,
797- removeTemporaryLayoutAttributes (prefetchOp->getAttrs ()));
786+ rewriter.create <xegpu::PrefetchNdOp>(newWarpOp.getLoc (), TypeRange{},
787+ newPrefetchOperands,
788+ prefetchOp->getAttrs ());
789+ xegpu::removeLayoutAttrs (prefetchOp);
798790 rewriter.eraseOp (prefetchOp);
799791 return success ();
800792 }
0 commit comments