@@ -29,7 +29,7 @@ class XeGPU_Op<string mnemonic, list<Trait> traits = []>:
     void printProperties(::mlir::MLIRContext *ctx,
                          ::mlir::OpAsmPrinter &p, const Properties &prop,
                          ::mlir::ArrayRef<::llvm::StringRef> elidedProps) {
-
+
       DictionaryAttr propAttr = dyn_cast_if_present<mlir::DictionaryAttr>(getPropertiesAsAttr(ctx, prop));

       // filter out the elidedProps from propAttr, and get the resultAttr
@@ -43,7 +43,7 @@ class XeGPU_Op<string mnemonic, list<Trait> traits = []>:
       }

       if (!filteredAttrs.empty()) {
-        p << "<" << DictionaryAttr::get(ctx, filteredAttrs) << ">";
+        p << "<" << DictionaryAttr::get(ctx, filteredAttrs) << ">";
       }
     }

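The two hunks above are whitespace-only, but they sit inside the custom `printProperties` helper defined on the `XeGPU_Op` base class: it drops any elided property names from the op's property dictionary and prints whatever remains as `<{...}>`. The standalone C++ sketch below models just that filtering step as an illustration; `printFilteredProps`, `NamedAttr`, and the string-based dictionary are stand-ins for MLIR's `NamedAttribute`/`DictionaryAttr`, not part of the patch.

```cpp
// Simplified, standalone model of the printProperties filtering step above.
// std::pair<std::string, std::string> stands in for mlir::NamedAttribute and
// the printed "<{...}>" form stands in for a DictionaryAttr; all names here
// are illustrative only.
#include <algorithm>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

using NamedAttr = std::pair<std::string, std::string>;

void printFilteredProps(std::ostream &os, const std::vector<NamedAttr> &props,
                        const std::vector<std::string> &elidedProps) {
  std::vector<NamedAttr> filtered;
  for (const NamedAttr &attr : props) {
    // Skip any property whose name appears in the elided list.
    if (std::find(elidedProps.begin(), elidedProps.end(), attr.first) !=
        elidedProps.end())
      continue;
    filtered.push_back(attr);
  }
  // Only emit the "<{...}>" wrapper when something is left to print.
  if (filtered.empty())
    return;
  os << "<{";
  for (size_t i = 0; i < filtered.size(); ++i)
    os << (i ? ", " : "") << filtered[i].first << " = " << filtered[i].second;
  os << "}>";
}

int main() {
  printFilteredProps(std::cout,
                     {{"l1_hint", "#xegpu.cache_hint<cached>"}, {"packed", "unit"}},
                     {"packed"});
  std::cout << "\n"; // prints: <{l1_hint = #xegpu.cache_hint<cached>}>
}
```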
@@ -60,8 +60,7 @@ class XeGPU_Op<string mnemonic, list<Trait> traits = []>:
 }


-def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface,
-                        AttrSizedOperandSegments, OffsetSizeAndStrideOpInterface]> {
+def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface, AttrSizedOperandSegments]> {

   let summary = "Create nd-tensor descriptor operation";
   let description = [{
@@ -181,82 +180,40 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
       return getType().getShape();
     }

-    /// wrapper for matching with OffsetSizeAndStrideOpInterface
-    OperandRange getSizes() {
-      return getShape();
+    SmallVector<OpFoldResult> getMixedOffsets() {
+      auto statics = getConstOffsets().value_or(SmallVector<int64_t>());
+      auto dynamics = getOffsets();
+      if (statics.size() == 0 && dynamics.size() == 0)
+        return {};
+      return getMixedValues(statics, dynamics, getContext());
     }

-    ArrayRef<int64_t> getStaticOffsets(){
-      auto attr = getConstOffsetsAttr();
-
-      if (attr)
-        return attr;
+    SmallVector<OpFoldResult> getMixedSizes() {
+      Builder b(getContext());
+      SmallVector<int64_t> statics;

-      int64_t rank = getMixedSizes().size();
-
-      setConstOffsets(llvm::SmallVector<int64_t, 4>(rank, 0));
+      /// Get the static sizes/shape, the value passed to const_shape
+      /// will overide the value in memref shape.
+      if (auto memrefTy = llvm::dyn_cast<MemRefType>(getSourceType()))
+        statics = llvm::to_vector(memrefTy.getShape());
+      if (auto attr = getConstShapeAttr())
+        statics = llvm::to_vector(attr.asArrayRef());

-      attr = getConstOffsetsAttr();
-      return attr;
+      return getMixedValues(statics, getShape(), b);
     }

-    /// wrapper for matching with OffsetSizeAndStrideOpInterface
-    /// If source is IntegerType or `const_shape` is filled,
-    /// it will return `const_shape`, such that mixes of `shape`
-    /// and `const_shape` will be used to represent the shape of
-    /// source operand. They overide static shape from source memref type.
-    ArrayRef<int64_t> getStaticSizes() {
-      /// To be compatible with OffsetSizeAndStrideOpInterface, which expects valid return value and perform checks
-      static llvm::SmallVector<int64_t, 4> emptyShape;
-
-      auto attr = getConstShapeAttr();
-      if (attr)
-        return attr;
-
-      if (llvm::isa<IntegerType>(getSourceType()))
-        return emptyShape;
-
-      auto memrefType = llvm::dyn_cast<MemRefType>(getSourceType());
-      assert(memrefType && "Incorrect use of getStaticSizes");
-      return memrefType.getShape();
-    }
+    SmallVector<OpFoldResult> getMixedStrides() {
+      Builder b(getContext());
+      SmallVector<int64_t> statics;

-    /// wrapper for matching with OffsetSizeAndStrideOpInterface
-    /// If source is IntegerType or `const_strides` is filled, it
-    /// will return `const_strides`, such that mixes of `strides`
-    /// and `const_strides` will be used to represent the strides of
-    /// source operand. They overide static strides from source memref type.
-    ArrayRef<int64_t> getStaticStrides() {
-      /// To be compatible with OffsetSizeAndStrideOpInterface, which expects valid return value and perform checks
-      static llvm::SmallVector<int64_t, 4> emptyStrides;
-
-      auto attr = getConstStridesAttr();
-      if (attr)
-        return attr;
-
-      if (llvm::isa<IntegerType>(getSourceType()))
-        return emptyStrides;
-
-      auto memrefType = llvm::dyn_cast<MemRefType>(getSourceType());
-      assert(memrefType && "Incorrect use of getStaticStrides");
-      auto [strides, _] = memrefType.getStridesAndOffset();
-      // reuse the storage of ConstStridesAttr since strides from
-      // memref is not persistant
-      setConstStrides(strides);
-      attr = getConstStridesAttr();
-      return attr;
-    }
+      /// Get the static strides, the value passed to const_strides
+      /// will overide the value in memref.
+      if (auto memrefTy = llvm::dyn_cast<MemRefType>(getSourceType()))
+        statics = memrefTy.getStridesAndOffset().first;
+      if (auto attr = getConstStridesAttr())
+        statics = llvm::to_vector(attr.asArrayRef());

-    /// Return the expected rank of each of the`static_offsets`,
-    /// `static_shape` and `static_strides` attributes.
-    std::array<unsigned, 3> getArrayAttrMaxRanks() {
-      unsigned rank;
-      if (auto ty = llvm::dyn_cast<MemRefType>(getSourceType())) {
-        rank = ty.getRank();
-      } else {
-        rank = (unsigned)getMixedOffsets().size();
-      }
-      return {rank, rank, rank};
+      return getMixedValues(statics, getStrides(), b);
     }

     /// Return the number of leading operands before the `offsets`,
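The hunk above is the substantive change of the patch: the `OffsetSizeAndStrideOpInterface`-style `getStatic*` wrappers are replaced by `getMixedOffsets`/`getMixedSizes`/`getMixedStrides`, matching the removal of that interface from the trait list earlier in the diff. The new helpers lean on MLIR's `getMixedValues` to merge a static `int64_t` list with the op's variadic SSA operands into one `OpFoldResult` list. The standalone sketch below models that merge with plain C++ types as an illustration only; `mixValues`, `FoldResult`, and `kDynamic` are stand-ins for `getMixedValues`, `OpFoldResult`, and the dynamic-entry sentinel (`ShapedType::kDynamic` in MLIR), not the MLIR implementation.

```cpp
// Standalone model of the static/dynamic merge performed by getMixedValues,
// which getMixedOffsets/getMixedSizes/getMixedStrides rely on.
// std::variant<int64_t, std::string> stands in for OpFoldResult (Attribute or
// Value); kDynamic plays the role of ShapedType::kDynamic. Names are illustrative.
#include <cassert>
#include <cstdint>
#include <iostream>
#include <limits>
#include <string>
#include <variant>
#include <vector>

using FoldResult = std::variant<int64_t, std::string>;
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

std::vector<FoldResult> mixValues(const std::vector<int64_t> &statics,
                                  const std::vector<std::string> &dynamics) {
  std::vector<FoldResult> mixed;
  size_t nextDynamic = 0;
  for (int64_t s : statics) {
    if (s == kDynamic)
      mixed.push_back(dynamics[nextDynamic++]); // take the next SSA operand
    else
      mixed.push_back(s);                       // keep the constant as-is
  }
  assert(nextDynamic == dynamics.size() && "unused dynamic operands");
  return mixed;
}

int main() {
  // E.g. static strides [kDynamic, 1] plus one dynamic stride operand %s0
  // yield the mixed list [%s0, 1].
  for (const FoldResult &r : mixValues({kDynamic, 1}, {"%s0"})) {
    if (std::holds_alternative<int64_t>(r))
      std::cout << std::get<int64_t>(r) << " ";
    else
      std::cout << std::get<std::string>(r) << " ";
  }
  std::cout << "\n"; // prints: %s0 1
}
```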
@@ -314,15 +271,15 @@ def XeGPU_PrefetchNdOp : XeGPU_Op<"prefetch_nd", []> {
   }];

   let assemblyFormat = [{
-    $TensorDesc ``
-    custom<OptionalDynamicIndexList>($offsets, $const_offsets)
+    $TensorDesc ``
+    custom<OptionalDynamicIndexList>($offsets, $const_offsets)
     prop-dict attr-dict `:` qualified(type($TensorDesc))
   }];

   let builders = [
-    OpBuilder<(ins "Value": $TensorDesc,
-               "xegpu::CachePolicyAttr": $l1_hint,
-               "xegpu::CachePolicyAttr": $l2_hint,
+    OpBuilder<(ins "Value": $TensorDesc,
+               "xegpu::CachePolicyAttr": $l1_hint,
+               "xegpu::CachePolicyAttr": $l2_hint,
                "xegpu::CachePolicyAttr": $l3_hint)>
   ];

@@ -370,7 +327,7 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [

   let arguments = (ins XeGPU_TensorDesc: $TensorDesc,
                        Variadic<Index>: $offsets,
-                       OptionalAttr<DenseI64ArrayAttr>: $const_offsets,
+                       OptionalAttr<DenseI64ArrayAttr>: $const_offsets,
                        OptionalAttr<UnitAttr>: $packed,
                        OptionalAttr<DenseI64ArrayAttr>: $transpose,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
@@ -390,16 +347,16 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [
   }];

   let assemblyFormat = [{
-    $TensorDesc ``
-    custom<OptionalDynamicIndexList>($offsets, $const_offsets)
+    $TensorDesc ``
+    custom<OptionalDynamicIndexList>($offsets, $const_offsets)
     prop-dict attr-dict `:` qualified(type($TensorDesc)) `->` type($value)
   }];

   let builders = [
-    OpBuilder<(ins "Type": $value, "Value": $TensorDesc,
+    OpBuilder<(ins "Type": $value, "Value": $TensorDesc,
               "UnitAttr": $packed, "DenseI64ArrayAttr": $transpose,
-               "xegpu::CachePolicyAttr": $l1_hint,
-               "xegpu::CachePolicyAttr": $l2_hint,
+               "xegpu::CachePolicyAttr": $l1_hint,
+               "xegpu::CachePolicyAttr": $l2_hint,
                "xegpu::CachePolicyAttr": $l3_hint)>
   ];

@@ -442,7 +399,7 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [
   let arguments = (ins XeGPU_ValueType: $value,
                        XeGPU_TensorDesc: $TensorDesc,
                        Variadic<Index>: $offsets,
-                       OptionalAttr<DenseI64ArrayAttr>: $const_offsets,
+                       OptionalAttr<DenseI64ArrayAttr>: $const_offsets,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
@@ -458,16 +415,16 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [
   }];

   let assemblyFormat = [{
-    $value `,`
-    $TensorDesc ``
-    custom<OptionalDynamicIndexList>($offsets, $const_offsets)
+    $value `,`
+    $TensorDesc ``
+    custom<OptionalDynamicIndexList>($offsets, $const_offsets)
     prop-dict attr-dict `:` type($value) `,` qualified(type($TensorDesc))
   }];

   let builders = [
-    OpBuilder<(ins "Value": $value, "Value": $TensorDesc,
-               "xegpu::CachePolicyAttr": $l1_hint,
-               "xegpu::CachePolicyAttr": $l2_hint,
+    OpBuilder<(ins "Value": $value, "Value": $TensorDesc,
+               "xegpu::CachePolicyAttr": $l1_hint,
+               "xegpu::CachePolicyAttr": $l2_hint,
                "xegpu::CachePolicyAttr": $l3_hint)>
   ];

@@ -635,12 +592,12 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
                               l3_hint = #xegpu.cache_hint<cached>}
         : !xegpu.tensor_desc<16xf16>
     ```
-
+
     Example 2:
     A variant accepts memref as base pointer and an offset instead of scattered TensorTdesc.
     It combines "create scattered TensorTdesc" and "prefetch with scattered TensorTdesc".
     The source operand could be a raw pointer (uint64_t).
-    Please refer to create_tdesc for the restriction of memref.
+    Please refer to create_tdesc for the restriction of memref.
     ```mlir
       %a = memref.alloc() : memref<1024xf32>
       %0 = arith.constant dense<[0, 16, 32, 64]> : vector<4xindex>
@@ -676,16 +633,16 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
   }];

   let assemblyFormat = [{
-    $source
+    $source
     (`[` $offsets^ `]`)?
     prop-dict
-    attr-dict `:` type(operands)
+    attr-dict `:` type(operands)
   }];
-
+
   let builders = [
     OpBuilder<(ins "Value": $source,
-               "xegpu::CachePolicyAttr": $l1_hint,
-               "xegpu::CachePolicyAttr": $l2_hint,
+               "xegpu::CachePolicyAttr": $l1_hint,
+               "xegpu::CachePolicyAttr": $l2_hint,
                "xegpu::CachePolicyAttr": $l3_hint)>
   ];

@@ -723,7 +680,7 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
           : !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<memory_space=global, chunk_size=8>>,
             vector<16xi1> -> vector<16x8xf32>
     ```
-
+
     Example 3 (SIMT mode):
     ```mlir
       %2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint<cached>,
@@ -732,12 +689,12 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
           : !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<memory_space=global, chunk_size=8>>
             vector<16xi1> -> vector<8xf32>
     ```
-
+
     Example 4:
     A variant accepts memref as base pointer and an offset instead of scattered TensorTdesc.
     It combines "create scattered TensorTdesc" and "load with scattered TensorTdesc".
     The source operand could be a raw pointer (uint64_t). Please refer to create_tdesc
-    for the restriction of memref.
+    for the restriction of memref.
     ```mlir
       %a = memref.alloc() : memref<1024xf32>
       %offsets = vector.step : vector<16xindex>
@@ -794,14 +751,14 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
   let assemblyFormat = [{
     $source
     (`[` $offsets^ `]`)? `,`
-    $mask prop-dict
+    $mask prop-dict
     attr-dict `:` type(operands) `->` type($value)
   }];

   let builders = [
     OpBuilder<(ins "Type": $value, "Value": $source, "Value": $mask,
-               "xegpu::CachePolicyAttr": $l1_hint,
-               "xegpu::CachePolicyAttr": $l2_hint,
+               "xegpu::CachePolicyAttr": $l1_hint,
+               "xegpu::CachePolicyAttr": $l2_hint,
                "xegpu::CachePolicyAttr": $l3_hint)>
   ];

@@ -848,7 +805,7 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
     A variant accepts memref as base pointer and an offset instead of scattered TensorTdesc.
     It combines "create scattered TensorTdesc" and "store with scattered TensorTdesc".
     The dest operand could be a raw pointer (uint64_t).
-    Please refer to create_tdesc for the restriction of memref.
+    Please refer to create_tdesc for the restriction of memref.
     ```mlir
       %a = memref.alloc() : memref<1024xf32>
       %val = arith.constant dense<0.0> : vector<16xf32>
@@ -901,15 +858,15 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
     $value `,`
     $dest
     (`[` $offsets^ `]`)? `,`
-    $mask
-    prop-dict
+    $mask
+    prop-dict
     attr-dict `:` type(operands)
   }];

   let builders = [
     OpBuilder<(ins "Value": $value, "Value": $dest, "Value": $mask,
-               "xegpu::CachePolicyAttr": $l1_hint,
-               "xegpu::CachePolicyAttr": $l2_hint,
+               "xegpu::CachePolicyAttr": $l1_hint,
+               "xegpu::CachePolicyAttr": $l2_hint,
                "xegpu::CachePolicyAttr": $l3_hint)>
   ];
