Skip to content

Commit c93090f

Browse files
committed
Add check
1 parent bdbf14f commit c93090f

File tree

1 file changed

+16
-0
lines changed

1 file changed

+16
-0
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -785,6 +785,14 @@ struct WgToSgLoadGatherOpWithOffset
785785

786786
SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
787787

788+
// The offsets need to be distributed
789+
if (dyn_cast<VectorType>(adaptor.getOffsets().front().getType())
790+
.getShape() !=
791+
dyn_cast<VectorType>(adaptor.getMask().front().getType()).getShape()) {
792+
return rewriter.notifyMatchFailure(op,
793+
"offsets have not been distributed");
794+
}
795+
788796
SmallVector<Value> newLoadOps;
789797
auto chunkSizeAttr =
790798
rewriter.getI64IntegerAttr(op.getChunkSize().value_or(1));
@@ -824,6 +832,14 @@ struct WgToSgStoreScatterOpWithOffset
824832
if (!layout || !layout.isForWorkgroup())
825833
return failure();
826834

835+
// The offsets need to be distributed
836+
if (dyn_cast<VectorType>(adaptor.getOffsets().front().getType())
837+
.getShape() !=
838+
dyn_cast<VectorType>(adaptor.getMask().front().getType()).getShape()) {
839+
return rewriter.notifyMatchFailure(op,
840+
"offsets have not been distributed");
841+
}
842+
827843
auto chunkSizeOpt = op.getChunkSize();
828844
int64_t chunkSize = chunkSizeOpt ? static_cast<int64_t>(*chunkSizeOpt) : 1;
829845
auto chunkSizeAttr = rewriter.getI64IntegerAttr(chunkSize);

0 commit comments

Comments
 (0)