
Commit 2159119

refine verifier for load_nd and store_nd
1 parent 2a1d373 commit 2159119

3 files changed: +43 −33 lines

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td

Lines changed: 3 additions & 1 deletion
@@ -840,7 +840,9 @@ def XeGPU_DpasOp : XeGPU_Op<"dpas", [Pure, AllElementTypesMatch<["lhs", "rhs"]>]
     can be represented as `B: vector<8x16x2xf16>`.

     In SIMT code, each work-item from a subgroup holds a data fragment for A, B, C and the result,
-    which are represented as 1D vectors.
+    which are represented as 1D vectors. Please refer to [OpenCL Intel extentions]
+    (https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_matrix_multiply_accumulate.html)
+    for more details about the fragment distribution.

     Note: on PVC, the hardware can perform load with VNNI transformation when data
     element type is 16-bit or lower precision, taking 2 or 4 elements from

mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp

Lines changed: 26 additions & 27 deletions
@@ -270,33 +270,31 @@ LogicalResult LoadNdOp::verify() {
   if (!isReadHintOrNone(getL3HintAttr()))
     return emitOpError("invalid l3_hint: ") << getL3HintAttr();

-  // Handling a 1D vector as the result can be complex. It may represent the
-  // outcome of a 1D block load in SIMD mode or a fragment of a block load
-  // result in SIMT mode. In the latter case, the tensor descriptor must be
-  // evenly distributed, with each lane holding an equally sized fragment of
-  // the result. Only subgroup size 8 or 16 is supported.
-  if (valueTy.getRank() == 1 &&
-      valueTy.getNumElements() < tdescTy.getNumElements()) {
+  int tdescElems = tdescTy.getNumElements() * tdescTy.getArrayLength();
+  int valueElems = valueTy.getNumElements();
+
+  // If the result vector is 1D and has less elements than the tensor
+  // descriptor, it is supposed to be a SIMT op. The layout attribute in
+  // tensor_desc is not needed.
+  if (valueElems < tdescElems && valueTy.getRank() == 1) {
     // SIMT mode doesn't need LayoutAttr.
     if (tdescTy.getLayoutAttr())
       return emitOpError()
              << "TensorDesc doesn't need LayoutAttr for SIMT code";

-    int tdescElems = tdescTy.getNumElements() * tdescTy.getArrayLength();
-    int valueElems = valueTy.getNumElements();
-
-    int lanes = tdescElems % valueElems == 0 ? tdescElems / valueElems : -1;
-    if (lanes != 16 && lanes != 8) {
+    // For SIMT code, the load is evenly distributed across all lanes in a
+    // subgroup. Since subgroup size is arch dependent, we only check even
+    // distribution here.
+    if (tdescElems % valueElems)
       return emitOpError()
              << "Result shape " << makeString(getShapeOf(valueTy))
              << " is not a valid distribution for tensor descriptor "
              << tdescTy;
-    }
+
     return success();
   }

   // Check SIMD mode.
-  auto array_len = tdescTy.getArrayLength();
   // adjusted tensor descriptor shape tracks the expected shape of the result.
   auto tdescShape = getShapeOf(tdescTy);
   auto valueShape = getShapeOf(valueTy);
@@ -328,6 +326,7 @@ LogicalResult LoadNdOp::verify() {
     }
   }

+  auto array_len = tdescTy.getArrayLength();
   if (array_len > 1) {
     tdescShape.insert(tdescShape.begin(), array_len);
   }
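In practice, the refined check only requires that the descriptor's element count be an exact multiple of the 1D result's element count, leaving the actual subgroup size to the target. A minimal sketch of a SIMT-mode load the verifier should accept, mirroring the syntax of the tests further below (the function name is illustrative):

// SIMT-mode load sketch: the descriptor holds 16 elements and each work-item
// receives a 1-element fragment, so 16 % 1 == 0 and the distribution is even.
func.func @load_nd_simt_sketch(%src: memref<24x32xf32>) -> vector<1xf32> {
  %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<16xf32>
  %1 = xegpu.load_nd %0 : !xegpu.tensor_desc<16xf32> -> vector<1xf32>
  return %1 : vector<1xf32>
}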
@@ -366,25 +365,23 @@ LogicalResult StoreNdOp::verify() {
   if (!isWriteHintOrNone(getL3HintAttr()))
     return emitOpError("invalid l3_hint: ") << getL3HintAttr();

-  auto tdescShape = getShapeOf(dstTy);
-  auto valueShape = getShapeOf(valTy);
+  auto array_len = dstTy.getArrayLength();
+  if (array_len > 1)
+    return emitOpError("array length is not supported by store_nd.\n");
+
+  auto tdescElems = dstTy.getNumElements();
+  auto valueElems = valTy.getNumElements();

-  // Similar to LoadNdOp, handling a 1D vector as the value can be complex. It
-  // may represent the input of a 1D block store in SIMD mode or a fragment of
-  // a block store input in SIMT mode. In the latter case, the tensor descriptor
-  // must be evenly distributed, with each lane holding an equally sized
-  // fragment of the input. Only subgroup size 8 or 16 is supported.
-  if (valTy.getRank() == 1 && valTy.getNumElements() < dstTy.getNumElements()) {
+  // Similar to LoadNdOp, if the value vector is 1D and has less elements than
+  // the tensor descriptor, it is supposed to be a SIMT op. The layout attribute
+  // in tensor_desc is not needed.
+  if (valTy.getRank() == 1 && valueElems < tdescElems) {
     // SIMT mode doesn't need LayoutAttr.
     if (dstTy.getLayoutAttr())
       return emitOpError()
              << "TensorDesc doesn't need LayoutAttr for SIMT code";

-    int tdescElems = dstTy.getNumElements() * dstTy.getArrayLength();
-    int valueElems = valueShape[0];
-
-    int lanes = tdescElems % valueElems == 0 ? tdescElems / valueElems : -1;
-    if (lanes != 16 && lanes != 8) {
+    if (tdescElems % valueElems) {
       return emitOpError()
              << "Value shape " << makeString(getShapeOf(valTy))
              << " is not a valid distribution for tensor descriptor " << dstTy;
@@ -393,6 +390,8 @@ LogicalResult StoreNdOp::verify() {
   }

   // SIMD code should have the same shape as the tensor descriptor.
+  auto tdescShape = getShapeOf(dstTy);
+  auto valueShape = getShapeOf(valTy);
   if (tdescShape != valueShape) {
     return emitOpError() << "Value shape " << makeString(valueShape)
                          << " is not consistent with tensor descriptor "

mlir/test/Dialect/XeGPU/invalid.mlir

Lines changed: 14 additions & 5 deletions
@@ -80,9 +80,9 @@ func.func @test_load_nd_vc_3(%src: memref<8x16xf16>) {
 // -----
 func.func @test_load_nd_layout(%src: memref<24x32xf32>) {
   %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<16xf32>
-  // expected-error@+1 {{Result shape [8] is not a valid distribution for tensor descriptor}}
+  // expected-error@+1 {{Result shape [3] is not a valid distribution for tensor descriptor}}
   %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>,
-    l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16xf32> -> vector<8xf32>
+    l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16xf32> -> vector<3xf32>
   return
 }

@@ -119,10 +119,19 @@ func.func @test_store_nd_vc_2(%dst: memref<16xf16>) {
 }

 // -----
-func.func @test_store_nd_simt(%dst: memref<24x32xf32>, %data: vector<4xf32>) {
+func.func @test_store_nd_vc_3(%dst: memref<24x32xf16>) {
+  %1 = arith.constant dense<1.0>: vector<2x24x32xf16>
+  %2 = xegpu.create_nd_tdesc %dst[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<24x32xf16, #xegpu.block_tdesc_attr<array_length = 2>>
+  // expected-error@+1 {{array length is not supported by store_nd}}
+  xegpu.store_nd %1, %2: vector<2x24x32xf16>, !xegpu.tensor_desc<24x32xf16, #xegpu.block_tdesc_attr<array_length = 2>>
+  return
+}
+
+// -----
+func.func @test_store_nd_simt(%dst: memref<24x32xf32>, %data: vector<3xf32>) {
   %1 = xegpu.create_nd_tdesc %dst[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<16xf32>
-  // expected-error@+1 {{Value shape [4] is not a valid distribution for tensor descriptor}}
-  xegpu.store_nd %data, %1 : vector<4xf32>, !xegpu.tensor_desc<16xf32>
+  // expected-error@+1 {{Value shape [3] is not a valid distribution for tensor descriptor}}
+  xegpu.store_nd %data, %1 : vector<3xf32>, !xegpu.tensor_desc<16xf32>
   return
 }
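The updated shapes reflect the relaxed check: the old negative cases (a vector<8xf32> result and a vector<4xf32> value against tensor_desc<16xf32>) now divide the descriptor evenly and should verify, so the tests switch to vector<3xf32>, which does not divide 16. A hedged sketch of such now-valid counterparts, not part of this commit:

// -----
// Hypothetical positive counterparts (illustrative, not in the commit):
// 16 % 8 == 0 and 16 % 4 == 0, so both distributions are even under the
// refined verifier even though the implied lane counts are 2 and 4.
func.func @now_valid_simt(%src: memref<24x32xf32>, %data: vector<4xf32>) {
  %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<16xf32>
  %1 = xegpu.load_nd %0 : !xegpu.tensor_desc<16xf32> -> vector<8xf32>
  xegpu.store_nd %data, %0 : vector<4xf32>, !xegpu.tensor_desc<16xf32>
  return
}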
