@@ -487,15 +487,10 @@ struct WgToSgVectorBroadcastOp
487487 for (auto operand : adaptor.getOperands ().front ()) {
488488 auto newBroadcast = vector::BroadcastOp::create (rewriter, op.getLoc (),
489489 newResultType, operand);
490- if (auto sliceAttr = dyn_cast_if_present<xegpu::SliceAttr>(layout)) {
491- if (sliceAttr.isForSubgroup ())
492- xegpu::setDistributeLayoutAttr (newBroadcast->getResult (0 ),
493- sliceAttr.dropSgLayoutAndData ());
494- } else if (auto layoutAttr =
495- dyn_cast_if_present<xegpu::LayoutAttr>(layout)) {
496- if (auto newLayout = layoutAttr.dropSgLayoutAndData ())
497- xegpu::setDistributeLayoutAttr (newBroadcast->getResult (0 ), newLayout);
498- }
490+ if (!layout.getLaneLayoutAsInt ().empty ())
491+ xegpu::setDistributeLayoutAttr (newBroadcast->getResult (0 ),
492+ layout.dropSgLayoutAndData ());
493+
499494 newBroadcastOps.push_back (newBroadcast.getResult ());
500495 }
501496 rewriter.replaceOpWithMultiple (op, {newBroadcastOps});
@@ -549,13 +544,10 @@ struct WgToSgElementwiseOp : public ConversionPattern {
549544 // Copy all attributes, but update "layout_result_0" to drop
550545 // sgLayout/sgData
551546 for (auto attr : op->getAttrs ()) {
552- if (auto layout = dyn_cast<xegpu::LayoutAttr>(attr.getValue ())) {
553- if (auto newLayout = layout.dropSgLayoutAndData ())
554- state.addAttribute (attr.getName (), newLayout);
555- } else if (auto sliceAttr =
556- dyn_cast<xegpu::SliceAttr>(attr.getValue ())) {
557- if (sliceAttr.isForSubgroup ())
558- state.addAttribute (attr.getName (), sliceAttr.dropSgLayoutAndData ());
547+ if (auto layout =
548+ dyn_cast<xegpu::DistributeLayoutAttr>(attr.getValue ())) {
549+ if (!layout.getLaneLayoutAsInt ().empty ())
550+ state.addAttribute (attr.getName (), layout.dropSgLayoutAndData ());
559551 } else {
560552 state.addAttribute (attr.getName (), attr.getValue ());
561553 }
@@ -746,15 +738,9 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
746738 auto sgAttr = DenseElementsAttr::get (newType, singleVal);
747739 auto cstOp =
748740 arith::ConstantOp::create (rewriter, op.getLoc (), newType, sgAttr);
749- if (auto sliceAttr = dyn_cast_if_present<xegpu::SliceAttr>(layout)) {
750- if (sliceAttr.isForSubgroup ())
751- xegpu::setDistributeLayoutAttr (cstOp->getResult (0 ),
752- sliceAttr.dropSgLayoutAndData ());
753- } else if (auto layoutAttr =
754- dyn_cast_if_present<xegpu::LayoutAttr>(layout)) {
755- if (auto newLayout = layoutAttr.dropSgLayoutAndData ())
756- xegpu::setDistributeLayoutAttr (cstOp->getResult (0 ), newLayout);
757- }
741+ if (!layout.getLaneLayoutAsInt ().empty ())
742+ xegpu::setDistributeLayoutAttr (cstOp->getResult (0 ),
743+ layout.dropSgLayoutAndData ());
758744 SmallVector<Value> newConsts (count, cstOp);
759745
760746 rewriter.replaceOpWithMultiple (op, {newConsts});
@@ -983,15 +969,9 @@ struct WgToSgVectorShapeCastOp
983969 for (auto src : adaptor.getSource ()) {
984970 auto newShapeCast =
985971 rewriter.create <vector::ShapeCastOp>(op.getLoc (), newResultType, src);
986- if (auto sliceAttr = dyn_cast_if_present<xegpu::SliceAttr>(layout)) {
987- if (sliceAttr.isForSubgroup ())
988- xegpu::setDistributeLayoutAttr (newShapeCast->getResult (0 ),
989- sliceAttr.dropSgLayoutAndData ());
990- } else if (auto layoutAttr =
991- dyn_cast_if_present<xegpu::LayoutAttr>(layout)) {
992- if (auto newLayout = layoutAttr.dropSgLayoutAndData ())
993- xegpu::setDistributeLayoutAttr (newShapeCast->getResult (0 ), newLayout);
994- }
972+ if (!layout.getLaneLayoutAsInt ().empty ())
973+ xegpu::setDistributeLayoutAttr (newShapeCast->getResult (0 ),
974+ layout.dropSgLayoutAndData ());
995975 newShapeCastOps.push_back (newShapeCast.getResult ());
996976 }
997977
0 commit comments