Skip to content

Commit 7072bc1

Browse files
committed
refactor verifiers for load_gather, store_scatter and dpas
1 parent 5520ce1 commit 7072bc1

File tree

2 files changed

+66
-89
lines changed

2 files changed

+66
-89
lines changed

mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp

Lines changed: 57 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,48 @@ static bool isEvenDistributed(llvm::ArrayRef<int64_t> shape,
101101
return true;
102102
}
103103

104+
// Shared verifier for the scattered load/store ops (load_gather and
// store_scatter): checks that the mask, the value vector and the tensor
// descriptor are mutually consistent, accepting either a valid SIMT
// distribution or a SIMD shape match.
//
//   maskTy        - type of the mask operand; its dim-0 must match the
//                   TensorDesc's dim-0.
//   valueTy       - vector type of the value being loaded/stored (null if the
//                   value is not a vector, which is rejected).
//   tdescTy       - the tensor descriptor; must be a scattered TensorDesc.
//   transposeAttr - unit attribute set when the op requests a transpose;
//                   required for the rank-2 SIMD case, forbidden for SIMT.
//   emitError     - callback producing the op-specific diagnostic stream, so
//                   this helper can be shared by several ops' verifiers.
static LogicalResult
isValidGatherScatterParams(Type maskTy, VectorType valueTy,
                           TensorDescType tdescTy, UnitAttr transposeAttr,
                           function_ref<InFlightDiagnostic()> emitError) {

  if (!tdescTy.isScattered())
    return emitError() << "Expects a scattered TensorDesc.";

  if (!valueTy)
    return emitError() << "Expecting a vector type result.";

  auto maskShape = getShapeOf(maskTy);
  auto valueShape = getShapeOf(valueTy);
  auto tdescShape = getShapeOf(tdescTy);
  auto chunkSize = tdescTy.getChunkSize();

  if (valueTy.getElementType() != tdescTy.getElementType())
    return emitError() << "Value should have the same element type as TensorDesc.";

  if (tdescShape[0] != maskShape[0])
    return emitError() << "dim-0 of the Mask and TensorDesc should be the same.";

  // a valid shape for SIMT case: a rank-1 value holding exactly one chunk per
  // lane. In that form neither a layout attribute nor a transpose is needed.
  if (valueTy.getRank() == 1 && valueTy.getNumElements() == chunkSize) {
    if (tdescTy.getLayoutAttr())
      return emitError() << "TensorDesc doesn't need LayoutAttr for SIMT code";
    if (transposeAttr)
      return emitError() << "doesn't need TransposeAttr for SIMT code";
    return success();
  }

  // SIMD case: a rank-2 access must be transposed; normalize tdescShape so the
  // final shape comparison below is done in the value's orientation.
  if (tdescTy.getRank() == 2 && valueTy.getRank() == 2) {
    if (!transposeAttr)
      return emitError() << "rank-2 tensor has to be transposed.";
    transpose({1, 0}, tdescShape);
  }

  if (tdescShape != valueShape)
    return emitError() << "Value shape " << makeString(valueShape)
                       << " is neither a valid distribution for SIMT nor "
                          "consistent with the tensor descriptor for SIMD "
                       << tdescTy;
  return success();
}
145+
104146
//===----------------------------------------------------------------------===//
105147
// XeGPU_CreateNdDescOp
106148
//===----------------------------------------------------------------------===//
@@ -517,12 +559,6 @@ LogicalResult LoadGatherOp::verify() {
517559
auto maskTy = getMaskType();
518560
auto valueTy = getValueType();
519561

520-
if (!valueTy)
521-
return emitOpError("Expecting a vector type result.\n");
522-
523-
if (!tdescTy.isScattered())
524-
return emitOpError("Expects a scattered TensorDesc.\n");
525-
526562
if (!isReadHintOrNone(getL1HintAttr()))
527563
return emitOpError("invalid l1_hint: ") << getL1HintAttr();
528564

@@ -532,52 +568,17 @@ LogicalResult LoadGatherOp::verify() {
532568
if (!isReadHintOrNone(getL3HintAttr()))
533569
return emitOpError("invalid l3_hint: ") << getL3HintAttr();
534570

535-
auto tdescElemTy = tdescTy.getElementType();
536-
auto valueElemTy = getElementType();
537-
if (tdescElemTy != valueElemTy)
538-
return emitOpError(
539-
"Value should have the same element type as TensorDesc.");
540-
541-
auto maskShape = getShapeOf(maskTy);
542-
auto valueShape = getShapeOf(valueTy);
543-
auto tdescShape = getShapeOf(tdescTy);
544-
545-
if (tdescShape[0] != maskShape[0])
546-
return emitOpError("dim-0 of the Mask and TensorDesc should be the same.");
547-
548-
auto chunkSize = tdescTy.getChunkSize();
549-
550-
// a valid shape for SIMT case
551-
if (valueTy.getRank() == 1 && valueTy.getNumElements() == chunkSize) {
552-
if (tdescTy.getLayoutAttr())
553-
return emitOpError()
554-
<< "TensorDesc doesn't need LayoutAttr for SIMT code";
555-
if (getTransposeAttr())
556-
return emitOpError() << "doesn't need TransposeAttr for SIMT code";
557-
return success();
558-
}
559-
560-
if (tdescTy.getRank() == 2 && valueTy.getRank() == 2) {
561-
if (!getTransposeAttr())
562-
return emitOpError("load of rank-2 tensor has to be transposed.");
563-
transpose({1, 0}, tdescShape);
564-
}
565-
566-
if (tdescShape != valueShape)
567-
return emitOpError() << "Result shape " << makeString(valueShape)
568-
<< " is neither a valid distribution for SIMT nor "
569-
"consistent with the tensor descriptor for SIMD "
570-
<< tdescTy;
571-
return success();
571+
return isValidGatherScatterParams(maskTy, valueTy, tdescTy, getTransposeAttr(),
572+
[&]() { return emitOpError(); });
572573
}
573574

574575
//===----------------------------------------------------------------------===//
575576
// XeGPU_StoreScatterOp
576577
//===----------------------------------------------------------------------===//
577578
LogicalResult StoreScatterOp::verify() {
578579
auto tdescTy = getTensorDescType();
579-
if (!tdescTy.isScattered())
580-
return emitOpError("Expects a scattered TensorDesc.\n");
580+
auto maskTy = getMaskType();
581+
auto valueTy = getValueType();
581582

582583
if (!isWriteHintOrNone(getL1HintAttr()))
583584
return emitOpError("invalid l1_hint: ") << getL1HintAttr();
@@ -588,43 +589,8 @@ LogicalResult StoreScatterOp::verify() {
588589
if (!isWriteHintOrNone(getL3HintAttr()))
589590
return emitOpError("invalid l3_hint: ") << getL3HintAttr();
590591

591-
auto maskTy = getMaskType();
592-
auto valueTy = getValueType();
593-
594-
if (!valueTy)
595-
return emitOpError("Expecting a vector type for the value.\n");
596-
597-
auto maskShape = getShapeOf(maskTy);
598-
auto tdescShape = getShapeOf(tdescTy);
599-
auto valueShape = getShapeOf(valueTy);
600-
if (tdescShape[0] != maskShape[0])
601-
return emitOpError("dim-0 of the Mask and TensorDesc should be the same.");
602-
603-
auto chunkSize = tdescTy.getChunkSize();
604-
605-
// a valid shape for SIMT case
606-
if (valueTy.getRank() == 1 && valueTy.getNumElements() == chunkSize) {
607-
if (tdescTy.getLayoutAttr())
608-
return emitOpError()
609-
<< "TensorDesc doesn't need LayoutAttr for SIMT code";
610-
if (getTransposeAttr())
611-
return emitOpError() << "doesn't need TransposeAttr for SIMT code";
612-
return success();
613-
}
614-
615-
if (tdescTy.getRank() == 2 && valueTy.getRank() == 2) {
616-
if (!getTransposeAttr())
617-
return emitOpError("Store of a rank-2 tensor has to be transposed.");
618-
transpose({1, 0}, tdescShape);
619-
}
620-
621-
if (tdescShape != valueShape)
622-
return emitOpError() << "Value shape " << makeString(valueShape)
623-
<< " is neither a valid distribution for SIMT nor "
624-
"consistent with the tensor descriptor for SIMD "
625-
<< tdescTy;
626-
627-
return success();
592+
return isValidGatherScatterParams(maskTy, valueTy, tdescTy, getTransposeAttr(),
593+
[&]() { return emitOpError(); });
628594
}
629595

630596
//===----------------------------------------------------------------------===//
@@ -660,14 +626,18 @@ LogicalResult DpasOp::verify() {
660626
auto rhsShape = getRhsType().getShape();
661627
auto resShape = getResultType().getShape();
662628

663-
if (getAcc()) {
664-
if (getAcc().getType() != getResultType())
665-
return emitOpError("Expecting the acc type to be the same as result.");
666-
}
629+
if (getAcc() && getAcc().getType() != getResultType())
630+
return emitOpError("Expecting the acc type to be the same as result.");
667631

668-
// SIMT code: skip the check since lack of semantic info at this level.
632+
// SIMT code: the size of the B operand has to be a multiple of 32 bits.
633+
// It skips the semantic check since lack of architecture information.
669634
// Users need to ensure the correctness.
670635
if (lhsRank == 1 && rhsRank == 1 && resRank == 1) {
636+
auto numElems = getRhsType().getNumElements();
637+
auto elemTy = getRhsType().getElementType();
638+
auto factor = 32 / elemTy.getIntOrFloatBitWidth();
639+
if (numElems % factor != 0)
640+
return emitOpError("Expecting B operand to be a multiple of 32 bits.");
671641
return success();
672642
} else { // SIMD code
673643
if (lhsRank != 2 || (rhsRank != 2 && rhsRank != 3) || resRank != 2)

mlir/test/Dialect/XeGPU/invalid.mlir

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ func.func @test_load_gather_simt_1(%src: ui64) {
255255
%0 = arith.constant dense<1>: vector<4xi1>
256256
%cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
257257
%1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
258-
// expected-error@+1 {{Result shape [6] is neither a valid distribution for SIMT nor consistent with the tensor descriptor for SIMD}}
258+
// expected-error@+1 {{Value shape [6] is neither a valid distribution for SIMT nor consistent with the tensor descriptor for SIMD}}
259259
%2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint<cached>}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<4xi1> -> vector<6xf32>
260260
return
261261
}
@@ -347,12 +347,19 @@ func.func @test_dpas_4(%a : vector<16x16xf16>, %b: vector<8x16x2xf16>) {
347347
}
348348

349349
// -----
350-
func.func @test_dpas_4(%a : vector<8x16xf16>, %b: vector<8x8x2xf16>) {
350+
func.func @test_dpas_5(%a : vector<8x16xf16>, %b: vector<8x8x2xf16>) {
351351
// expected-error@+1 {{N-dimension mismatch}}
352352
%1 = xegpu.dpas %a, %b : vector<8x16xf16>, vector<8x8x2xf16> -> vector<8x16xf32>
353353
return
354354
}
355355

356+
// -----
357+
func.func @test_dpas_simt_1(%a : vector<8xf16>, %b: vector<15xf16>) {
358+
// expected-error@+1 {{Expecting B operand to be a multiple of 32 bits}}
359+
%1 = xegpu.dpas %a, %b : vector<8xf16>, vector<15xf16> -> vector<8xf32>
360+
return
361+
}
362+
356363
// -----
357364
func.func @test_atomic_rmw(%src: ui64, %value : vector<16x4xf32>, %mask : vector<16xi1>) {
358365
%0 = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xindex>

0 commit comments

Comments
 (0)