Skip to content

Commit 8b490f6

Browse files
committed
Refine verifier for gather and scatter ops
1 parent f5d7875 commit 8b490f6

File tree

1 file changed

+12
-2
lines changed

1 file changed

+12
-2
lines changed

mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -541,9 +541,9 @@ LogicalResult LoadGatherOp::verify() {
541541
if (tdescShape[0] != maskShape[0])
542542
return emitOpError("dim-0 of the Mask and TensorDesc should be the same.");
543543

544+
auto chunkSize = tdescTy.getChunkSize();
544545
// For SIMT code, the value should be a 1D vector of size chunkSize.
545546
if (valueTy.getRank() == 1 && valueTy.getNumElements() != tdescShape[0]) {
546-
auto chunkSize = tdescTy.getChunkSize();
547547
if (valueTy.getNumElements() != chunkSize) {
548548
return emitOpError()
549549
<< "Result shape " << makeString(valueShape)
@@ -557,6 +557,11 @@ LogicalResult LoadGatherOp::verify() {
557557
return emitOpError() << "doesn't need TransposeAttr for SIMT code";
558558
}
559559
return success();
560+
} else if (valueTy.getRank() == 1 && tdescShape[0] == chunkSize) {
561+
// For a 1D vector where valueTy.getNumElements() == tdescShape[0],
562+
// this is still valid SIMT code if chunkSize happens to be the same as
563+
// the subgroup size, e.g., tensor_desc<16x16xf16, chunkSize = 16>.
564+
return success();
560565
}
561566

562567
// For SIMD code verification.
@@ -602,9 +607,9 @@ LogicalResult StoreScatterOp::verify() {
602607
if (tdescShape[0] != maskShape[0])
603608
return emitOpError("dim-0 of the Mask and TensorDesc should be the same.");
604609

610+
auto chunkSize = tdescTy.getChunkSize();
605611
// For SIMT code, the value should be a 1D vector of size chunkSize.
606612
if (valueTy.getRank() == 1 && valueTy.getNumElements() != tdescShape[0]) {
607-
auto chunkSize = tdescTy.getChunkSize();
608613
if (valueTy.getNumElements() != chunkSize) {
609614
return emitOpError()
610615
<< "Value shape " << makeString(valueShape)
@@ -618,6 +623,11 @@ LogicalResult StoreScatterOp::verify() {
618623
return emitOpError() << "doesn't need TransposeAttr for SIMT code";
619624
}
620625
return success();
626+
} else if (valueTy.getRank() == 1 && tdescShape[0] == chunkSize) {
627+
// For a 1D vector where valueTy.getNumElements() == tdescShape[0],
628+
// this is still valid SIMT code if chunkSize happens to be the same as
629+
// the subgroup size, e.g., tensor_desc<16x16xf16, chunkSize = 16>.
630+
return success();
621631
}
622632

623633
// for SIMD code verification.

0 commit comments

Comments
 (0)