Skip to content

Commit d0546b2

Browse files
committed
Add check
1 parent 77f3261 commit d0546b2

File tree

1 file changed

+12
-2
lines changed

1 file changed

+12
-2
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1008,12 +1008,22 @@ struct WgToSgVectorShapeCastOp
10081008
if (dstRank > srcRank) {
10091009
// Expanding dims: input operand's layout must be a SliceAttr
10101010
auto srcLayout = xegpu::getDistributeLayoutAttr(op.getSource());
1011-
if (!srcLayout || !isa<xegpu::SliceAttr>(srcLayout))
1011+
auto srcSliceAttr = cast<xegpu::SliceAttr>(srcLayout);
1012+
if (!srcLayout || !srcSliceAttr)
1013+
return failure();
1014+
auto resLayout = xegpu::getDistributeLayoutAttr(op.getResult());
1015+
// Check srcLayout is a slice attr on top of resLayout
1016+
if (srcSliceAttr.getParent() != resLayout)
10121017
return failure();
10131018
} else if (dstRank < srcRank) {
10141019
// Reducing dims: result's layout must be a SliceAttr
10151020
auto resLayout = xegpu::getDistributeLayoutAttr(op.getResult());
1016-
if (!resLayout || !isa<xegpu::SliceAttr>(resLayout))
1021+
auto resSliceAttr = cast<xegpu::SliceAttr>(resLayout);
1022+
auto srcLayout = xegpu::getDistributeLayoutAttr(op.getSource());
1023+
if (!resSliceAttr || !srcLayout)
1024+
return failure();
1025+
// Check resLayout is a sliced attr from srcLayout
1026+
if (resSliceAttr.getParent() != srcLayout)
10171027
return failure();
10181028
}
10191029

0 commit comments

Comments
 (0)