@@ -29,7 +29,7 @@ class XeGPU_Op<string mnemonic, list<Trait> traits = []>:
29
29
void printProperties(::mlir::MLIRContext *ctx,
30
30
::mlir::OpAsmPrinter &p, const Properties &prop,
31
31
::mlir::ArrayRef<::llvm::StringRef> elidedProps) {
32
-
32
+
33
33
DictionaryAttr propAttr = dyn_cast_if_present<mlir::DictionaryAttr>(getPropertiesAsAttr(ctx, prop));
34
34
35
35
// filter out the elidedProps from propAttr, and get the resultAttr
@@ -43,7 +43,7 @@ class XeGPU_Op<string mnemonic, list<Trait> traits = []>:
43
43
}
44
44
45
45
if (!filteredAttrs.empty()) {
46
- p << "<" << DictionaryAttr::get(ctx, filteredAttrs) << ">";
46
+ p << "<" << DictionaryAttr::get(ctx, filteredAttrs) << ">";
47
47
}
48
48
}
49
49
@@ -60,8 +60,7 @@ class XeGPU_Op<string mnemonic, list<Trait> traits = []>:
60
60
}
61
61
62
62
63
- def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface,
64
- AttrSizedOperandSegments, OffsetSizeAndStrideOpInterface]> {
63
+ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface, AttrSizedOperandSegments]> {
65
64
66
65
let summary = "Create nd-tensor descriptor operation";
67
66
let description = [{
@@ -181,82 +180,38 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
181
180
return getType().getShape();
182
181
}
183
182
184
- /// wrapper for matching with OffsetSizeAndStrideOpInterface
185
- OperandRange getSizes() {
186
- return getShape();
183
+ SmallVector<OpFoldResult> getMixedOffsets() {
184
+ auto statics = getConstOffsets().value_or(SmallVector<int64_t>());
185
+ auto dynamics = getOffsets();
186
+ if (statics.size() == 0 && dynamics.size() == 0)
187
+ return {};
188
+ return getMixedValues(statics, dynamics, getContext());
187
189
}
188
190
189
- ArrayRef<int64_t> getStaticOffsets(){
190
- auto attr = getConstOffsetsAttr();
191
-
192
- if (attr)
193
- return attr;
191
+ SmallVector<OpFoldResult> getMixedSizes() {
192
+ SmallVector<int64_t> statics;
194
193
195
- int64_t rank = getMixedSizes().size();
196
-
197
- setConstOffsets(llvm::SmallVector<int64_t, 4>(rank, 0));
194
+ /// Get the static sizes/shape, the value passed to const_shape
195
+ /// will overide the value in memref shape.
196
+ if (auto memrefTy = llvm::dyn_cast<MemRefType>(getSourceType()))
197
+ statics = llvm::to_vector(memrefTy.getShape());
198
+ if (auto attr = getConstShapeAttr())
199
+ statics = llvm::to_vector(attr.asArrayRef());
198
200
199
- attr = getConstOffsetsAttr();
200
- return attr;
201
+ return getMixedValues(statics, getShape(), getContext());
201
202
}
202
203
203
- /// wrapper for matching with OffsetSizeAndStrideOpInterface
204
- /// If source is IntegerType or `const_shape` is filled,
205
- /// it will return `const_shape`, such that mixes of `shape`
206
- /// and `const_shape` will be used to represent the shape of
207
- /// source operand. They overide static shape from source memref type.
208
- ArrayRef<int64_t> getStaticSizes() {
209
- /// To be compatible with OffsetSizeAndStrideOpInterface, which expects valid return value and perform checks
210
- static llvm::SmallVector<int64_t, 4> emptyShape;
211
-
212
- auto attr = getConstShapeAttr();
213
- if (attr)
214
- return attr;
215
-
216
- if (llvm::isa<IntegerType>(getSourceType()))
217
- return emptyShape;
218
-
219
- auto memrefType = llvm::dyn_cast<MemRefType>(getSourceType());
220
- assert(memrefType && "Incorrect use of getStaticSizes");
221
- return memrefType.getShape();
222
- }
204
+ SmallVector<OpFoldResult> getMixedStrides() {
205
+ SmallVector<int64_t> statics;
223
206
224
- /// wrapper for matching with OffsetSizeAndStrideOpInterface
225
- /// If source is IntegerType or `const_strides` is filled, it
226
- /// will return `const_strides`, such that mixes of `strides`
227
- /// and `const_strides` will be used to represent the strides of
228
- /// source operand. They overide static strides from source memref type.
229
- ArrayRef<int64_t> getStaticStrides() {
230
- /// To be compatible with OffsetSizeAndStrideOpInterface, which expects valid return value and perform checks
231
- static llvm::SmallVector<int64_t, 4> emptyStrides;
232
-
233
- auto attr = getConstStridesAttr();
234
- if (attr)
235
- return attr;
236
-
237
- if (llvm::isa<IntegerType>(getSourceType()))
238
- return emptyStrides;
239
-
240
- auto memrefType = llvm::dyn_cast<MemRefType>(getSourceType());
241
- assert(memrefType && "Incorrect use of getStaticStrides");
242
- auto [strides, _] = memrefType.getStridesAndOffset();
243
- // reuse the storage of ConstStridesAttr since strides from
244
- // memref is not persistant
245
- setConstStrides(strides);
246
- attr = getConstStridesAttr();
247
- return attr;
248
- }
207
+ /// Get the static strides, the value passed to const_strides
208
+ /// will overide the value in memref.
209
+ if (auto memrefTy = llvm::dyn_cast<MemRefType>(getSourceType()))
210
+ statics = memrefTy.getStridesAndOffset().first;
211
+ if (auto attr = getConstStridesAttr())
212
+ statics = llvm::to_vector(attr.asArrayRef());
249
213
250
- /// Return the expected rank of each of the`static_offsets`,
251
- /// `static_shape` and `static_strides` attributes.
252
- std::array<unsigned, 3> getArrayAttrMaxRanks() {
253
- unsigned rank;
254
- if (auto ty = llvm::dyn_cast<MemRefType>(getSourceType())) {
255
- rank = ty.getRank();
256
- } else {
257
- rank = (unsigned)getMixedOffsets().size();
258
- }
259
- return {rank, rank, rank};
214
+ return getMixedValues(statics, getStrides(), getContext());
260
215
}
261
216
262
217
/// Return the number of leading operands before the `offsets`,
@@ -314,15 +269,15 @@ def XeGPU_PrefetchNdOp : XeGPU_Op<"prefetch_nd", []> {
314
269
}];
315
270
316
271
let assemblyFormat = [{
317
- $TensorDesc ``
318
- custom<OptionalDynamicIndexList>($offsets, $const_offsets)
272
+ $TensorDesc ``
273
+ custom<OptionalDynamicIndexList>($offsets, $const_offsets)
319
274
prop-dict attr-dict `:` qualified(type($TensorDesc))
320
275
}];
321
276
322
277
let builders = [
323
- OpBuilder<(ins "Value": $TensorDesc,
324
- "xegpu::CachePolicyAttr": $l1_hint,
325
- "xegpu::CachePolicyAttr": $l2_hint,
278
+ OpBuilder<(ins "Value": $TensorDesc,
279
+ "xegpu::CachePolicyAttr": $l1_hint,
280
+ "xegpu::CachePolicyAttr": $l2_hint,
326
281
"xegpu::CachePolicyAttr": $l3_hint)>
327
282
];
328
283
@@ -370,7 +325,7 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [
370
325
371
326
let arguments = (ins XeGPU_TensorDesc: $TensorDesc,
372
327
Variadic<Index>: $offsets,
373
- OptionalAttr<DenseI64ArrayAttr>: $const_offsets,
328
+ OptionalAttr<DenseI64ArrayAttr>: $const_offsets,
374
329
OptionalAttr<UnitAttr>: $packed,
375
330
OptionalAttr<DenseI64ArrayAttr>: $transpose,
376
331
OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
@@ -390,16 +345,16 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [
390
345
}];
391
346
392
347
let assemblyFormat = [{
393
- $TensorDesc ``
394
- custom<OptionalDynamicIndexList>($offsets, $const_offsets)
348
+ $TensorDesc ``
349
+ custom<OptionalDynamicIndexList>($offsets, $const_offsets)
395
350
prop-dict attr-dict `:` qualified(type($TensorDesc)) `->` type($value)
396
351
}];
397
352
398
353
let builders = [
399
- OpBuilder<(ins "Type": $value, "Value": $TensorDesc,
354
+ OpBuilder<(ins "Type": $value, "Value": $TensorDesc,
400
355
"UnitAttr": $packed, "DenseI64ArrayAttr": $transpose,
401
- "xegpu::CachePolicyAttr": $l1_hint,
402
- "xegpu::CachePolicyAttr": $l2_hint,
356
+ "xegpu::CachePolicyAttr": $l1_hint,
357
+ "xegpu::CachePolicyAttr": $l2_hint,
403
358
"xegpu::CachePolicyAttr": $l3_hint)>
404
359
];
405
360
@@ -442,7 +397,7 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [
442
397
let arguments = (ins XeGPU_ValueType: $value,
443
398
XeGPU_TensorDesc: $TensorDesc,
444
399
Variadic<Index>: $offsets,
445
- OptionalAttr<DenseI64ArrayAttr>: $const_offsets,
400
+ OptionalAttr<DenseI64ArrayAttr>: $const_offsets,
446
401
OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
447
402
OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
448
403
OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
@@ -458,16 +413,16 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [
458
413
}];
459
414
460
415
let assemblyFormat = [{
461
- $value `,`
462
- $TensorDesc ``
463
- custom<OptionalDynamicIndexList>($offsets, $const_offsets)
416
+ $value `,`
417
+ $TensorDesc ``
418
+ custom<OptionalDynamicIndexList>($offsets, $const_offsets)
464
419
prop-dict attr-dict `:` type($value) `,` qualified(type($TensorDesc))
465
420
}];
466
421
467
422
let builders = [
468
- OpBuilder<(ins "Value": $value, "Value": $TensorDesc,
469
- "xegpu::CachePolicyAttr": $l1_hint,
470
- "xegpu::CachePolicyAttr": $l2_hint,
423
+ OpBuilder<(ins "Value": $value, "Value": $TensorDesc,
424
+ "xegpu::CachePolicyAttr": $l1_hint,
425
+ "xegpu::CachePolicyAttr": $l2_hint,
471
426
"xegpu::CachePolicyAttr": $l3_hint)>
472
427
];
473
428
@@ -635,12 +590,12 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
635
590
l3_hint = #xegpu.cache_hint<cached>}
636
591
: !xegpu.tensor_desc<16xf16>
637
592
```
638
-
593
+
639
594
Example 2:
640
595
A variant accepts memref as base pointer and an offset instead of scattered TensorTdesc.
641
596
It combines "create scattered TensorTdesc" and "prefetch with scattered TensorTdesc".
642
597
The source operand could be a raw pointer (uint64_t).
643
- Please refer to create_tdesc for the restriction of memref.
598
+ Please refer to create_tdesc for the restriction of memref.
644
599
```mlir
645
600
%a = memref.alloc() : memref<1024xf32>
646
601
%0 = arith.constant dense<[0, 16, 32, 64]> : vector<4xindex>
@@ -676,16 +631,16 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
676
631
}];
677
632
678
633
let assemblyFormat = [{
679
- $source
634
+ $source
680
635
(`[` $offsets^ `]`)?
681
636
prop-dict
682
- attr-dict `:` type(operands)
637
+ attr-dict `:` type(operands)
683
638
}];
684
-
639
+
685
640
let builders = [
686
641
OpBuilder<(ins "Value": $source,
687
- "xegpu::CachePolicyAttr": $l1_hint,
688
- "xegpu::CachePolicyAttr": $l2_hint,
642
+ "xegpu::CachePolicyAttr": $l1_hint,
643
+ "xegpu::CachePolicyAttr": $l2_hint,
689
644
"xegpu::CachePolicyAttr": $l3_hint)>
690
645
];
691
646
@@ -723,7 +678,7 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
723
678
: !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<memory_space=global, chunk_size=8>>,
724
679
vector<16xi1> -> vector<16x8xf32>
725
680
```
726
-
681
+
727
682
Example 3 (SIMT mode):
728
683
```mlir
729
684
%2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint<cached>,
@@ -732,12 +687,12 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
732
687
: !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<memory_space=global, chunk_size=8>>
733
688
vector<16xi1> -> vector<8xf32>
734
689
```
735
-
690
+
736
691
Example 4:
737
692
A variant accepts memref as base pointer and an offset instead of scattered TensorTdesc.
738
693
It combines "create scattered TensorTdesc" and "load with scattered TensorTdesc".
739
694
The source operand could be a raw pointer (uint64_t). Please refer to create_tdesc
740
- for the restriction of memref.
695
+ for the restriction of memref.
741
696
```mlir
742
697
%a = memref.alloc() : memref<1024xf32>
743
698
%offsets = vector.step : vector<16xindex>
@@ -794,14 +749,14 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
794
749
let assemblyFormat = [{
795
750
$source
796
751
(`[` $offsets^ `]`)? `,`
797
- $mask prop-dict
752
+ $mask prop-dict
798
753
attr-dict `:` type(operands) `->` type($value)
799
754
}];
800
755
801
756
let builders = [
802
757
OpBuilder<(ins "Type": $value, "Value": $source, "Value": $mask,
803
- "xegpu::CachePolicyAttr": $l1_hint,
804
- "xegpu::CachePolicyAttr": $l2_hint,
758
+ "xegpu::CachePolicyAttr": $l1_hint,
759
+ "xegpu::CachePolicyAttr": $l2_hint,
805
760
"xegpu::CachePolicyAttr": $l3_hint)>
806
761
];
807
762
@@ -848,7 +803,7 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
848
803
A variant accepts memref as base pointer and an offset instead of scattered TensorTdesc.
849
804
It combines "create scattered TensorTdesc" and "store with scattered TensorTdesc".
850
805
The dest operand could be a raw pointer (uint64_t).
851
- Please refer to create_tdesc for the restriction of memref.
806
+ Please refer to create_tdesc for the restriction of memref.
852
807
```mlir
853
808
%a = memref.alloc() : memref<1024xf32>
854
809
%val = arith.constant dense<0.0> : vector<16xf32>
@@ -901,15 +856,15 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
901
856
$value `,`
902
857
$dest
903
858
(`[` $offsets^ `]`)? `,`
904
- $mask
905
- prop-dict
859
+ $mask
860
+ prop-dict
906
861
attr-dict `:` type(operands)
907
862
}];
908
863
909
864
let builders = [
910
865
OpBuilder<(ins "Value": $value, "Value": $dest, "Value": $mask,
911
- "xegpu::CachePolicyAttr": $l1_hint,
912
- "xegpu::CachePolicyAttr": $l2_hint,
866
+ "xegpu::CachePolicyAttr": $l1_hint,
867
+ "xegpu::CachePolicyAttr": $l2_hint,
913
868
"xegpu::CachePolicyAttr": $l3_hint)>
914
869
];
915
870
0 commit comments