Skip to content

Commit 769bf19

Browse files
committed
add tests
1 parent 0b91942 commit 769bf19

File tree

4 files changed

+147
-43
lines changed

4 files changed

+147
-43
lines changed

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -676,9 +676,7 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
676676
let hasVerifier = 1;
677677
}
678678

679-
def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
680-
AllElementTypesMatch<["value", "source"]>, MemoryEffects<[MemRead]>
681-
]> {
679+
def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
682680
let summary = "load a set of scattered data points from memory.";
683681

684682
let description = [{ It (aka. load) load data per each work-item. The output
@@ -721,8 +719,7 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
721719
}];
722720

723721
let arguments = (ins XeGPU_TensorDesc_or_MemRef: $source,
724-
Variadic<Index>: $offsets,
725-
OptionalAttr<DenseI64ArrayAttr>: $const_offsets,
722+
Optional<XeGPU_OffsetType>: $offsets,
726723
XeGPU_MaskType: $mask,
727724
OptionalAttr<I64Attr>: $chunk_size,
728725
OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
@@ -760,11 +757,15 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
760757
}];
761758

762759
let assemblyFormat = [{
763-
$source ``
764-
custom<OptionalDynamicIndexList>($offsets, $const_offsets) `,`
760+
$source
761+
(`[` $offsets^ `]`)? `,`
765762
$mask prop-dict
766-
attr-dict `:` qualified(type($source)) `,` type($mask) `->` type($value)
763+
attr-dict `:` type(operands) `->` type($value)
767764
}];
765+
766+
// functional-type(operands, results)
767+
// type($source) (type($offsets)^ )? `,` type($mask) `->` type($value)
768+
768769

769770
let builders = [
770771
OpBuilder<(ins "Type": $value, "Value": $source, "Value": $mask,
@@ -776,9 +777,7 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
776777
let hasVerifier = 1;
777778
}
778779

779-
def XeGPU_StoreScatterOp : XeGPU_Op<"store", [
780-
AllElementTypesMatch<["value", "dest"]>, MemoryEffects<[MemWrite]>
781-
]> {
780+
def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
782781
let summary = "store data to scattered memory locations.";
783782
let description = [{ It (aka. store) stores data to scattered memory locations. The value is
784783
typically a 1D vector. But when the chunk size of the TensorDesc is larger than 1, it will be
@@ -818,8 +817,7 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [
818817
let arguments = (ins
819818
XeGPU_ValueType: $value,
820819
XeGPU_TensorDesc_or_MemRef: $dest,
821-
Variadic<Index>: $offsets,
822-
OptionalAttr<DenseI64ArrayAttr>: $const_offsets,
820+
Optional<XeGPU_OffsetType>: $offsets,
823821
XeGPU_MaskType: $mask,
824822
OptionalAttr<I64Attr>: $chunk_size,
825823
OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
@@ -850,12 +848,13 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [
850848

851849
let assemblyFormat = [{
852850
$value `,`
853-
$dest ``
854-
custom<OptionalDynamicIndexList>($offsets, $const_offsets) `,`
851+
$dest
852+
(`[` $offsets^ `]`)? `,`
855853
$mask
856854
prop-dict
857-
attr-dict `:` type($value) `,` qualified(type($dest)) `,` type($mask)
855+
attr-dict `:` type(operands)
858856
}];
857+
// type($value) `,` qualified(type($dest)) (type($offsets)^)? `,` type($mask)
859858

860859
let builders = [
861860
OpBuilder<(ins "Value": $value, "Value": $dest, "Value": $mask,

mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp

Lines changed: 89 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,66 @@ isValidGatherScatterParams(Type maskTy, VectorType valueTy,
110110
return success();
111111
}
112112

113+
/// Verifies the value/mask/memref triple of a gather/scatter op that addresses
/// raw memory through a MemRef (no TensorDesc).
///
/// \param maskTy     type of the per-lane mask operand.
/// \param valueTy    vector type of the loaded/stored value (null if the
///                   result is not a vector, which is rejected).
/// \param memTy      the MemRef being gathered from / scattered to; only its
///                   element type participates in verification here.
/// \param chunkSize  number of contiguous elements accessed per lane
///                   (defaults to 1 at the call sites).
/// \param emitError  callback producing the op's diagnostic stream.
static LogicalResult
isValidGatherScatterMemRefParams(Type maskTy, VectorType valueTy,
                                 MemRefType memTy, int64_t chunkSize,
                                 function_ref<InFlightDiagnostic()> emitError) {

  if (!valueTy)
    return emitError() << "Expecting a vector type result.";

  auto maskShape = getShapeOf(maskTy);
  auto valueShape = getShapeOf(valueTy);

  // The value's element type must agree with the memory's element type.
  if (valueTy.getElementType() != memTy.getElementType())
    return emitError() << "Value should have the same element type as MemRef.";

  // SIMT case: a rank-1 value holds exactly one lane's chunk, so its element
  // count must equal the chunk size; the mask shape is not checked here.
  if (valueTy.getRank() == 1) {
    if (valueTy.getNumElements() != chunkSize)
      return emitError() << "value elements must match chunk size " << chunkSize
                         << " for SIMT code.";
    return success();
  }

  // SIMD case: the mask must match the value shape, minus the trailing
  // chunk-size dimension when chunking is in effect.
  llvm::SmallVector<int64_t> expectedMaskShape(valueShape);
  if (chunkSize > 1)
    expectedMaskShape.pop_back();
  if (expectedMaskShape != maskShape)
    return emitError() << "Mask should match value except the chunk size dim.";

  return success();
}
144+
145+
/// Verifies the value/mask pair of a gather/scatter op that addresses memory
/// through a raw pointer (e.g. ui64), where no element-type cross-check
/// against the memory operand is possible.
///
/// \param maskTy     type of the per-lane mask operand.
/// \param valueTy    vector type of the loaded/stored value (null is rejected).
/// \param chunkSize  number of contiguous elements accessed per lane.
/// \param emitError  callback producing the op's diagnostic stream.
static LogicalResult
isValidGatherScatterRawptrParams(Type maskTy, VectorType valueTy,
                                 int64_t chunkSize,
                                 function_ref<InFlightDiagnostic()> emitError) {

  // A non-vector result is rejected up front.
  if (!valueTy)
    return emitError() << "Expecting a vector type result.";

  // SIMT case: a rank-1 value carries a single lane's chunk, so its element
  // count must be exactly the chunk size.
  if (valueTy.getRank() == 1) {
    if (valueTy.getNumElements() != chunkSize)
      return emitError() << "value elements must match chunk size " << chunkSize
                         << " for SIMT code.";
    return success();
  }

  // SIMD case: derive the mask shape the value implies — the value's shape
  // with the trailing chunk dimension dropped when chunking is active — and
  // require the actual mask to match it.
  auto maskDims = getShapeOf(maskTy);
  auto valueDims = getShapeOf(valueTy);
  llvm::SmallVector<int64_t> impliedMaskDims(valueDims);
  if (chunkSize > 1)
    impliedMaskDims.pop_back();
  if (impliedMaskDims != maskDims)
    return emitError() << "Mask should match value except the chunk size dim.";

  return success();
}
172+
113173
//===----------------------------------------------------------------------===//
114174
// XeGPU_CreateNdDescOp
115175
//===----------------------------------------------------------------------===//
@@ -683,17 +743,27 @@ LogicalResult LoadGatherOp::verify() {
683743
if (!isReadHintOrNone(getL3HintAttr()))
684744
return emitOpError("invalid l3_hint: ") << getL3HintAttr();
685745

686-
return isValidGatherScatterParams(maskTy, valueTy, tdescTy,
687-
[&]() { return emitOpError(); });
746+
if (tdescTy)
747+
return isValidGatherScatterParams(maskTy, valueTy, tdescTy,
748+
[&]() { return emitOpError(); });
749+
auto srcTy = getSourceType();
750+
uint64_t chunkSize = static_cast<int64_t>(getChunkSize().value_or(1));
751+
auto memTy = dyn_cast<MemRefType>(srcTy);
752+
753+
if (memTy)
754+
return isValidGatherScatterMemRefParams(maskTy, valueTy, memTy, chunkSize,
755+
[&]() { return emitOpError(); });
756+
return isValidGatherScatterRawptrParams(maskTy, valueTy, chunkSize,
757+
[&]() { return emitOpError(); });
688758
}
689759

690760
void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
691761
Type valueType, Value source, Value mask,
692762
xegpu::CachePolicyAttr l1_hint,
693763
xegpu::CachePolicyAttr l2_hint,
694764
xegpu::CachePolicyAttr l3_hint) {
695-
build(builder, state, valueType, source, ValueRange(), DenseI64ArrayAttr(),
696-
mask, IntegerAttr(), l1_hint, l2_hint, l3_hint);
765+
build(builder, state, valueType, source, Value(), mask, IntegerAttr(),
766+
l1_hint, l2_hint, l3_hint);
697767
}
698768

699769
//===----------------------------------------------------------------------===//
@@ -713,17 +783,28 @@ LogicalResult StoreScatterOp::verify() {
713783
if (!isWriteHintOrNone(getL3HintAttr()))
714784
return emitOpError("invalid l3_hint: ") << getL3HintAttr();
715785

716-
return isValidGatherScatterParams(maskTy, valueTy, tdescTy,
717-
[&]() { return emitOpError(); });
786+
if (tdescTy)
787+
return isValidGatherScatterParams(maskTy, valueTy, tdescTy,
788+
[&]() { return emitOpError(); });
789+
790+
auto destTy = getDestType();
791+
uint64_t chunkSize = static_cast<int64_t>(getChunkSize().value_or(1));
792+
auto memTy = dyn_cast<MemRefType>(destTy);
793+
794+
if (memTy)
795+
return isValidGatherScatterMemRefParams(maskTy, valueTy, memTy, chunkSize,
796+
[&]() { return emitOpError(); });
797+
return isValidGatherScatterRawptrParams(maskTy, valueTy, chunkSize,
798+
[&]() { return emitOpError(); });
718799
}
719800

720801
void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
721802
Value value, Value dest, Value mask,
722803
xegpu::CachePolicyAttr l1_hint,
723804
xegpu::CachePolicyAttr l2_hint,
724805
xegpu::CachePolicyAttr l3_hint) {
725-
build(builder, state, value, dest, ValueRange(), DenseI64ArrayAttr(), mask,
726-
IntegerAttr(), l1_hint, l2_hint, l3_hint);
806+
build(builder, state, value, dest, Value(), mask, IntegerAttr(), l1_hint,
807+
l2_hint, l3_hint);
727808
}
728809

729810
//===----------------------------------------------------------------------===//

mlir/test/Dialect/XeGPU/invalid.mlir

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -384,6 +384,28 @@ func.func @load_gather_vc_3(%src: ui64) {
384384
return
385385
}
386386

387+
// -----
388+
func.func @load_offset(%src: ui64) {
389+
%offsets = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
390+
%mask = arith.constant dense<1>: vector<8xi1>
391+
// expected-error@+1 {{Mask should match value except the chunk size dim}}
392+
%2 = xegpu.load %src[%offsets], %mask
393+
: ui64, vector<4xindex>, vector<8xi1>
394+
-> vector<4x2xf32>
395+
return
396+
}
397+
398+
// -----
399+
func.func @store_offset(%src: ui64) {
400+
%val = arith.constant dense<2.9>: vector<4x2xf16>
401+
%offsets = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
402+
%mask = arith.constant dense<1>: vector<8xi1>
403+
// expected-error@+1 {{Mask should match value except the chunk size dim}}
404+
xegpu.store %val, %src[%offsets], %mask
405+
: vector<4x2xf16>, ui64, vector<4xindex>, vector<8xi1>
406+
return
407+
}
408+
387409
// -----
388410
func.func @store_scatter_vc_1(%src: memref<24x32xf32>) {
389411
%0 = arith.constant dense<1>: vector<4xi1>

mlir/test/Dialect/XeGPU/ops.mlir

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -130,15 +130,6 @@ gpu.func @prefetch_nd_offset_1(%src: memref<48x64xf16>, %x : index, %y : index)
130130
gpu.return
131131
}
132132

133-
// CHECK: gpu.func @prefetch_nd_offset_1(%[[arg0:.*]]: memref<8x24x32x48x64xf16>) {
134-
gpu.func @prefetch_nd_offset_1(%src: memref<8x24x32x48x64xf16>) {
135-
// CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0, 0, 0, 0] : memref<8x24x32x48x64xf16> -> !xegpu.tensor_desc<1x2x4x8x16xf16>
136-
%1 = xegpu.create_nd_tdesc %src[0, 0, 0, 0, 0] : memref<8x24x32x48x64xf16> -> !xegpu.tensor_desc<1x2x4x8x16xf16>
137-
// CHECK: xegpu.prefetch_nd %[[R0]][0, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<1x2x4x8x16xf16>
138-
xegpu.prefetch_nd %1[0, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>: !xegpu.tensor_desc<1x2x4x8x16xf16>
139-
gpu.return
140-
}
141-
142133
// CHECK: func @subgroup_load_nd(%[[arg0:.*]]: memref<8x16xf16>) {
143134
gpu.func @subgroup_load_nd(%src: memref<8x16xf16>) {
144135
// CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
@@ -339,16 +330,6 @@ gpu.func @subgroup_store_nd_2(%dst: memref<24x32xf16>, %x : index) {
339330
gpu.return
340331
}
341332

342-
// CHECK: func @subgroup_store_nd_3(%[[arg0:.*]]: memref<24x32xf16>) {
343-
gpu.func @subgroup_store_nd_3(%dst: memref<24x32xf16>) {
344-
// CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<32xf16>
345-
%1 = arith.constant dense<1.0>: vector<32xf16>
346-
// CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16>
347-
%2 = xegpu.create_nd_tdesc %dst[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16>
348-
// CHECK: xegpu.store_nd %[[C]], %[[R0]][0] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<32xf16>, !xegpu.tensor_desc<32xf16>
349-
xegpu.store_nd %1, %2[0] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>: vector<32xf16>, !xegpu.tensor_desc<32xf16>
350-
gpu.return
351-
}
352333

353334
// CHECK: func @subgroup_store_nd_offset_1(%[[arg0:.*]]: memref<24x32xf16>) {
354335
gpu.func @subgroup_store_nd_offset_1(%dst: memref<24x32xf16>) {
@@ -541,6 +522,16 @@ gpu.func @subgroup_load_4(%src: ui64) {
541522
gpu.return
542523
}
543524

525+
// CHECK: gpu.func @subgroup_load_offset_1(%arg0: memref<?xf16>) {
526+
gpu.func @subgroup_load_offset_1(%src: memref<?xf16>) {
527+
%offset = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
528+
%mask = arith.constant dense<1>: vector<4xi1>
529+
//CHECK: %[[R1:.*]] = xegpu.load %arg0[%cst], %cst_0 <{chunk_size = 2 : i64, l1_hint = #xegpu.cache_hint<cached>}> : memref<?xf16>, vector<4xindex>, vector<4xi1> -> vector<4x2xf16>
530+
%val = xegpu.load %src[%offset], %mask <{chunk_size=2, l1_hint = #xegpu.cache_hint<cached>}>
531+
: memref<?xf16>, vector<4xindex>, vector<4xi1> -> vector<4x2xf16>
532+
gpu.return
533+
}
534+
544535
// CHECK: gpu.func @subgroup_store(%[[arg0:.*]]: ui64) {
545536
gpu.func @subgroup_store(%src: ui64) {
546537
//CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
@@ -646,6 +637,17 @@ gpu.func @subgroup_store_4(%src: ui64) {
646637
gpu.return
647638
}
648639

640+
// CHECK: gpu.func @subgroup_store_offset_1(%arg0: memref<?xf16>) {
641+
gpu.func @subgroup_store_offset_1(%dest: memref<?xf16>) {
642+
%val = arith.constant dense<2.9>: vector<4x2xf16>
643+
%offset = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
644+
%mask = arith.constant dense<1>: vector<4xi1>
645+
//CHECK: xegpu.store %[[R0:.*]], %arg0[%cst_0], %cst_1 <{chunk_size = 2 : i64, l1_hint = #xegpu.cache_hint<cached>}> : vector<4x2xf16>, memref<?xf16>, vector<4xindex>, vector<4xi1>
646+
xegpu.store %val, %dest[%offset], %mask <{chunk_size=2, l1_hint = #xegpu.cache_hint<cached>}>
647+
: vector<4x2xf16>, memref<?xf16>, vector<4xindex>, vector<4xi1>
648+
gpu.return
649+
}
650+
649651
// CHECK: gpu.func @prefetch(%[[arg0:.*]]: ui64) {
650652
gpu.func @prefetch(%src: ui64) {
651653
//CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>

0 commit comments

Comments
 (0)