Skip to content

Commit a7b780d

Browse files
committed
Feedback
1 parent c93090f commit a7b780d

File tree

1 file changed

+12
-6
lines changed

1 file changed

+12
-6
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -786,9 +786,12 @@ struct WgToSgLoadGatherOpWithOffset
786786
SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
787787

788788
// The offsets need to be distributed
789-
if (dyn_cast<VectorType>(adaptor.getOffsets().front().getType())
790-
.getShape() !=
791-
dyn_cast<VectorType>(adaptor.getMask().front().getType()).getShape()) {
789+
auto offsetsVecType =
790+
dyn_cast<VectorType>(adaptor.getOffsets().front().getType());
791+
auto maskVecType =
792+
dyn_cast<VectorType>(adaptor.getMask().front().getType());
793+
if (!offsetsVecType || !maskVecType ||
794+
offsetsVecType.getShape() != maskVecType.getShape()) {
792795
return rewriter.notifyMatchFailure(op,
793796
"offsets have not been distributed");
794797
}
@@ -833,9 +836,12 @@ struct WgToSgStoreScatterOpWithOffset
833836
return failure();
834837

835838
// The offsets need to be distributed
836-
if (dyn_cast<VectorType>(adaptor.getOffsets().front().getType())
837-
.getShape() !=
838-
dyn_cast<VectorType>(adaptor.getMask().front().getType()).getShape()) {
839+
auto offsetsVecType =
840+
dyn_cast<VectorType>(adaptor.getOffsets().front().getType());
841+
auto maskVecType =
842+
dyn_cast<VectorType>(adaptor.getMask().front().getType());
843+
if (!offsetsVecType || !maskVecType ||
844+
offsetsVecType.getShape() != maskVecType.getShape()) {
839845
return rewriter.notifyMatchFailure(op,
840846
"offsets have not been distributed");
841847
}

0 commit comments

Comments
 (0)