Skip to content

Commit 8cb5ebe

Browse files
committed
Clean up check
1 parent 5e9f3df commit 8cb5ebe

File tree

1 file changed

+14
-34
lines changed

1 file changed

+14
-34
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Lines changed: 14 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -487,15 +487,10 @@ struct WgToSgVectorBroadcastOp
487487
for (auto operand : adaptor.getOperands().front()) {
488488
auto newBroadcast = vector::BroadcastOp::create(rewriter, op.getLoc(),
489489
newResultType, operand);
490-
if (auto sliceAttr = dyn_cast_if_present<xegpu::SliceAttr>(layout)) {
491-
if (sliceAttr.isForSubgroup())
492-
xegpu::setDistributeLayoutAttr(newBroadcast->getResult(0),
493-
sliceAttr.dropSgLayoutAndData());
494-
} else if (auto layoutAttr =
495-
dyn_cast_if_present<xegpu::LayoutAttr>(layout)) {
496-
if (auto newLayout = layoutAttr.dropSgLayoutAndData())
497-
xegpu::setDistributeLayoutAttr(newBroadcast->getResult(0), newLayout);
498-
}
490+
if (!layout.getLaneLayoutAsInt().empty())
491+
xegpu::setDistributeLayoutAttr(newBroadcast->getResult(0),
492+
layout.dropSgLayoutAndData());
493+
499494
newBroadcastOps.push_back(newBroadcast.getResult());
500495
}
501496
rewriter.replaceOpWithMultiple(op, {newBroadcastOps});
@@ -549,13 +544,10 @@ struct WgToSgElementwiseOp : public ConversionPattern {
549544
// Copy all attributes, but update "layout_result_0" to drop
550545
// sgLayout/sgData
551546
for (auto attr : op->getAttrs()) {
552-
if (auto layout = dyn_cast<xegpu::LayoutAttr>(attr.getValue())) {
553-
if (auto newLayout = layout.dropSgLayoutAndData())
554-
state.addAttribute(attr.getName(), newLayout);
555-
} else if (auto sliceAttr =
556-
dyn_cast<xegpu::SliceAttr>(attr.getValue())) {
557-
if (sliceAttr.isForSubgroup())
558-
state.addAttribute(attr.getName(), sliceAttr.dropSgLayoutAndData());
547+
if (auto layout =
548+
dyn_cast<xegpu::DistributeLayoutAttr>(attr.getValue())) {
549+
if (!layout.getLaneLayoutAsInt().empty())
550+
state.addAttribute(attr.getName(), layout.dropSgLayoutAndData());
559551
} else {
560552
state.addAttribute(attr.getName(), attr.getValue());
561553
}
@@ -746,15 +738,9 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
746738
auto sgAttr = DenseElementsAttr::get(newType, singleVal);
747739
auto cstOp =
748740
arith::ConstantOp::create(rewriter, op.getLoc(), newType, sgAttr);
749-
if (auto sliceAttr = dyn_cast_if_present<xegpu::SliceAttr>(layout)) {
750-
if (sliceAttr.isForSubgroup())
751-
xegpu::setDistributeLayoutAttr(cstOp->getResult(0),
752-
sliceAttr.dropSgLayoutAndData());
753-
} else if (auto layoutAttr =
754-
dyn_cast_if_present<xegpu::LayoutAttr>(layout)) {
755-
if (auto newLayout = layoutAttr.dropSgLayoutAndData())
756-
xegpu::setDistributeLayoutAttr(cstOp->getResult(0), newLayout);
757-
}
741+
if (!layout.getLaneLayoutAsInt().empty())
742+
xegpu::setDistributeLayoutAttr(cstOp->getResult(0),
743+
layout.dropSgLayoutAndData());
758744
SmallVector<Value> newConsts(count, cstOp);
759745

760746
rewriter.replaceOpWithMultiple(op, {newConsts});
@@ -983,15 +969,9 @@ struct WgToSgVectorShapeCastOp
983969
for (auto src : adaptor.getSource()) {
984970
auto newShapeCast =
985971
rewriter.create<vector::ShapeCastOp>(op.getLoc(), newResultType, src);
986-
if (auto sliceAttr = dyn_cast_if_present<xegpu::SliceAttr>(layout)) {
987-
if (sliceAttr.isForSubgroup())
988-
xegpu::setDistributeLayoutAttr(newShapeCast->getResult(0),
989-
sliceAttr.dropSgLayoutAndData());
990-
} else if (auto layoutAttr =
991-
dyn_cast_if_present<xegpu::LayoutAttr>(layout)) {
992-
if (auto newLayout = layoutAttr.dropSgLayoutAndData())
993-
xegpu::setDistributeLayoutAttr(newShapeCast->getResult(0), newLayout);
994-
}
972+
if (!layout.getLaneLayoutAsInt().empty())
973+
xegpu::setDistributeLayoutAttr(newShapeCast->getResult(0),
974+
layout.dropSgLayoutAndData());
995975
newShapeCastOps.push_back(newShapeCast.getResult());
996976
}
997977

0 commit comments

Comments
 (0)