File tree Expand file tree Collapse file tree 1 file changed +12
-2
lines changed
mlir/lib/Dialect/XeGPU/Transforms Expand file tree Collapse file tree 1 file changed +12
-2
lines changed Original file line number Diff line number Diff line change @@ -1008,12 +1008,22 @@ struct WgToSgVectorShapeCastOp
10081008 if (dstRank > srcRank) {
10091009 // Expanding dims: input operand's layout must be a SliceAttr
10101010 auto srcLayout = xegpu::getDistributeLayoutAttr (op.getSource ());
1011- if (!srcLayout || !isa<xegpu::SliceAttr>(srcLayout))
1011+ auto srcSliceAttr = cast<xegpu::SliceAttr>(srcLayout);
1012+ if (!srcLayout || !srcSliceAttr)
1013+ return failure ();
1014+ auto resLayout = xegpu::getDistributeLayoutAttr (op.getResult ());
1015+ // Check srcLayout is a slice attr on top of resLayout
1016+ if (srcSliceAttr.getParent () != resLayout)
10121017 return failure ();
10131018 } else if (dstRank < srcRank) {
10141019 // Reducing dims: result's layout must be a SliceAttr
10151020 auto resLayout = xegpu::getDistributeLayoutAttr (op.getResult ());
1016- if (!resLayout || !isa<xegpu::SliceAttr>(resLayout))
1021+ auto resSliceAttr = cast<xegpu::SliceAttr>(resLayout);
1022+ auto srcLayout = xegpu::getDistributeLayoutAttr (op.getSource ());
1023+ if (!resSliceAttr || !srcLayout)
1024+ return failure ();
1025+ // Check resLayout is a sliced attr from srcLayout
1026+ if (resSliceAttr.getParent () != srcLayout)
10171027 return failure ();
10181028 }
10191029
You can’t perform that action at this time.
0 commit comments