Skip to content

Commit 775d039

Browse files
committed
refine verifier for gather/scatter
1 parent 2159119 commit 775d039

File tree

2 files changed

+22
-44
lines changed

2 files changed

+22
-44
lines changed

mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp

Lines changed: 20 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -547,38 +547,27 @@ LogicalResult LoadGatherOp::verify() {
547547
return emitOpError("dim-0 of the Mask and TensorDesc should be the same.");
548548

549549
auto chunkSize = tdescTy.getChunkSize();
550-
// for SIMT code, the value should be 1D vector with size of chunkSize.
551-
if (valueTy.getRank() == 1 && valueTy.getNumElements() != tdescShape[0]) {
552-
if (valueTy.getNumElements() != chunkSize) {
550+
551+
// a valid shape for SIMT case
552+
if (valueTy.getRank() == 1 && valueTy.getNumElements() == chunkSize) {
553+
if (tdescTy.getLayoutAttr())
553554
return emitOpError()
554-
<< "Result shape " << makeString(valueShape)
555-
<< " is not a valid distribution for tensor descriptor "
556-
<< tdescTy;
557-
} else { // valid SIMT code doesn't need LayoutAttr and TransposeAttr.
558-
if (tdescTy.getLayoutAttr())
559-
return emitOpError()
560-
<< "TensorDesc doesn't need LayoutAttr for SIMT code";
561-
if (getTransposeAttr())
562-
return emitOpError() << "doesn't need TransposeAttr for SIMT code";
563-
}
564-
return success();
565-
} else if (valueTy.getRank() == 1 && tdescShape[0] == chunkSize) {
566-
// for 1D vector and valueTy.getNumElements() == tdescShape[0] case,
567-
// it is a valid SIMT code if chunkSize happens to be the same as
568-
// subgroup size, e.g., tensor_desc<16x16xf16, chunkSize = 16>
555+
<< "TensorDesc doesn't need LayoutAttr for SIMT code";
556+
if (getTransposeAttr())
557+
return emitOpError() << "doesn't need TransposeAttr for SIMT code";
569558
return success();
570559
}
571560

572-
// For SIMD code verification.
573-
if (tdescTy.getRank() == 2) {
561+
if (tdescTy.getRank() == 2 && valueTy.getRank() == 2) {
574562
if (!getTransposeAttr())
575563
return emitOpError("load of rank-2 tensor has to be transposed.");
576564
transpose({1, 0}, tdescShape);
577565
}
578566

579567
if (tdescShape != valueShape)
580568
return emitOpError() << "Result shape " << makeString(valueShape)
581-
<< " is not consistent with tensor descriptor "
569+
<< " is neither a valid distribution for SIMT nor "
570+
"consistent with the tensor descriptor for SIMD "
582571
<< tdescTy;
583572
return success();
584573
}
@@ -613,38 +602,27 @@ LogicalResult StoreScatterOp::verify() {
613602
return emitOpError("dim-0 of the Mask and TensorDesc should be the same.");
614603

615604
auto chunkSize = tdescTy.getChunkSize();
616-
// for SIMT code, the value should be 1D vector with size of chunkSize.
617-
if (valueTy.getRank() == 1 && valueTy.getNumElements() != tdescShape[0]) {
618-
if (valueTy.getNumElements() != chunkSize) {
605+
606+
// a valid shape for SIMT case
607+
if (valueTy.getRank() == 1 && valueTy.getNumElements() == chunkSize) {
608+
if (tdescTy.getLayoutAttr())
619609
return emitOpError()
620-
<< "Value shape " << makeString(valueShape)
621-
<< " is not a valid distribution for tensor descriptor "
622-
<< tdescTy;
623-
} else { // valid SIMT code doesn't need LayoutAttr and TransposeAttr.
624-
if (tdescTy.getLayoutAttr())
625-
return emitOpError()
626-
<< "TensorDesc doesn't need LayoutAttr for SIMT code";
627-
if (getTransposeAttr())
628-
return emitOpError() << "doesn't need TransposeAttr for SIMT code";
629-
}
630-
return success();
631-
} else if (valueTy.getRank() == 1 && tdescShape[0] == chunkSize) {
632-
// for 1D vector and valueTy.getNumElements() == tdescShape[0] case,
633-
// it is a valid SIMT code if chunkSize happens to be the same as
634-
// subgroup size, e.g., tensor_desc<16x16xf16, chunkSize = 16>
610+
<< "TensorDesc doesn't need LayoutAttr for SIMT code";
611+
if (getTransposeAttr())
612+
return emitOpError() << "doesn't need TransposeAttr for SIMT code";
635613
return success();
636614
}
637615

638-
// for SIMD code verification.
639-
if (tdescTy.getRank() == 2) {
616+
if (tdescTy.getRank() == 2 && valueTy.getRank() == 2) {
640617
if (!getTransposeAttr())
641618
return emitOpError("Store of a rank-2 tensor has to be transposed.");
642619
transpose({1, 0}, tdescShape);
643620
}
644621

645622
if (tdescShape != valueShape)
646623
return emitOpError() << "Value shape " << makeString(valueShape)
647-
<< " is not consistent with tensor descriptor "
624+
<< " is neither a valid distribution for SIMT nor "
625+
"consistent with the tensor descriptor for SIMD "
648626
<< tdescTy;
649627

650628
return success();

mlir/test/Dialect/XeGPU/invalid.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ func.func @test_load_gather_simt_1(%src: ui64) {
255255
%0 = arith.constant dense<1>: vector<4xi1>
256256
%cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
257257
%1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
258-
// expected-error@+1 {{Result shape [6] is not a valid distribution for tensor descriptor}}
258+
// expected-error@+1 {{Result shape [6] is neither a valid distribution for SIMT nor consistent with the tensor descriptor for SIMD}}
259259
%2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint<cached>}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<4xi1> -> vector<6xf32>
260260
return
261261
}
@@ -266,7 +266,7 @@ func.func @test_store_scatter_simt_1(%src: ui64) {
266266
%cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
267267
%val = arith.constant dense<2.9>: vector<6xf32>
268268
%1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
269-
// expected-error@+1 {{Value shape [6] is not a valid distribution for tensor descriptor}}
269+
// expected-error@+1 {{Value shape [6] is neither a valid distribution for SIMT nor consistent with the tensor descriptor for SIMD}}
270270
xegpu.store %val, %1, %0 <{l1_hint = #xegpu.cache_hint<cached>}> : vector<6xf32>, !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<4xi1>
271271
return
272272
}

0 commit comments

Comments
 (0)