@@ -29,7 +29,7 @@ class XeGPU_Op<string mnemonic, list<Trait> traits = []>:
2929 void printProperties(::mlir::MLIRContext *ctx,
3030 ::mlir::OpAsmPrinter &p, const Properties &prop,
3131 ::mlir::ArrayRef<::llvm::StringRef> elidedProps) {
32-
32+
3333 DictionaryAttr propAttr = dyn_cast_if_present<mlir::DictionaryAttr>(getPropertiesAsAttr(ctx, prop));
3434
3535 // filter out the elidedProps from propAttr, and get the resultAttr
@@ -43,7 +43,7 @@ class XeGPU_Op<string mnemonic, list<Trait> traits = []>:
4343 }
4444
4545 if (!filteredAttrs.empty()) {
46- p << "<" << DictionaryAttr::get(ctx, filteredAttrs) << ">";
46+ p << "<" << DictionaryAttr::get(ctx, filteredAttrs) << ">";
4747 }
4848 }
4949
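The custom `printProperties` above drops the elided properties before printing the remainder as an inline dictionary. The filtering loop itself sits in the lines elided between the two hunks, so the following is only a plausible sketch of such a filter, not the actual elided code; it assumes `propAttr` and `elidedProps` exactly as declared above.

```cpp
// Illustrative only: keep the named attributes whose name is not in elidedProps.
SmallVector<NamedAttribute> filteredAttrs;
if (propAttr)
  llvm::copy_if(propAttr.getValue(), std::back_inserter(filteredAttrs),
                [&](NamedAttribute attr) {
                  return !llvm::is_contained(elidedProps, attr.getName().strref());
                });
// If anything remains, the printer above emits e.g. `<{l1_hint = #xegpu.cache_hint<cached>}>`.
```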
@@ -60,8 +60,7 @@ class XeGPU_Op<string mnemonic, list<Trait> traits = []>:
6060}
6161
6262
63- def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface,
64- AttrSizedOperandSegments, OffsetSizeAndStrideOpInterface]> {
63+ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface, AttrSizedOperandSegments]> {
6564
6665 let summary = "Create nd-tensor descriptor operation";
6766 let description = [{
@@ -181,82 +180,38 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
181180 return getType().getShape();
182181 }
183182
184- /// wrapper for matching with OffsetSizeAndStrideOpInterface
185- OperandRange getSizes() {
186- return getShape();
183+ SmallVector<OpFoldResult> getMixedOffsets() {
184+ auto statics = getConstOffsets().value_or(SmallVector<int64_t>());
185+ auto dynamics = getOffsets();
186+ if (statics.size() == 0 && dynamics.size() == 0)
187+ return {};
188+ return getMixedValues(statics, dynamics, getContext());
187189 }
188190
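`getMixedValues` comes from `mlir/Dialect/Utils/StaticValueUtils.h`: it walks the static array, substitutes the next dynamic `Value` wherever the `ShapedType::kDynamic` sentinel appears, and wraps every other entry as an integer attribute. A rough behavioural sketch (not the library implementation):

```cpp
// Rough sketch of the getMixedValues behaviour relied on by getMixedOffsets/Sizes/Strides.
SmallVector<OpFoldResult> mixValues(ArrayRef<int64_t> statics,
                                    ValueRange dynamics, MLIRContext *ctx) {
  Builder b(ctx);
  SmallVector<OpFoldResult> result;
  unsigned dynIdx = 0;
  for (int64_t s : statics)
    result.push_back(ShapedType::isDynamic(s)
                         ? OpFoldResult(dynamics[dynIdx++])       // SSA operand
                         : OpFoldResult(b.getI64IntegerAttr(s))); // constant
  return result;
}
```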
189- ArrayRef<int64_t> getStaticOffsets(){
190- auto attr = getConstOffsetsAttr();
191-
192- if (attr)
193- return attr;
191+ SmallVector<OpFoldResult> getMixedSizes() {
192+ SmallVector<int64_t> statics;
194193
195- int64_t rank = getMixedSizes().size();
196-
197- setConstOffsets(llvm::SmallVector<int64_t, 4>(rank, 0));
194+      /// Get the static sizes/shape; the value passed to const_shape
195+      /// will override the shape from the memref type.
196+ if (auto memrefTy = llvm::dyn_cast<MemRefType>(getSourceType()))
197+ statics = llvm::to_vector(memrefTy.getShape());
198+ if (auto attr = getConstShapeAttr())
199+ statics = llvm::to_vector(attr.asArrayRef());
198200
199- attr = getConstOffsetsAttr();
200- return attr;
201+ return getMixedValues(statics, getShape(), getContext());
201202 }
202203
203- /// wrapper for matching with OffsetSizeAndStrideOpInterface
204- /// If source is IntegerType or `const_shape` is filled,
205- /// it will return `const_shape`, such that mixes of `shape`
206- /// and `const_shape` will be used to represent the shape of
207- /// source operand. They overide static shape from source memref type.
208- ArrayRef<int64_t> getStaticSizes() {
209- /// To be compatible with OffsetSizeAndStrideOpInterface, which expects valid return value and perform checks
210- static llvm::SmallVector<int64_t, 4> emptyShape;
211-
212- auto attr = getConstShapeAttr();
213- if (attr)
214- return attr;
215-
216- if (llvm::isa<IntegerType>(getSourceType()))
217- return emptyShape;
218-
219- auto memrefType = llvm::dyn_cast<MemRefType>(getSourceType());
220- assert(memrefType && "Incorrect use of getStaticSizes");
221- return memrefType.getShape();
222- }
204+ SmallVector<OpFoldResult> getMixedStrides() {
205+ SmallVector<int64_t> statics;
223206
224- /// wrapper for matching with OffsetSizeAndStrideOpInterface
225- /// If source is IntegerType or `const_strides` is filled, it
226- /// will return `const_strides`, such that mixes of `strides`
227- /// and `const_strides` will be used to represent the strides of
228- /// source operand. They overide static strides from source memref type.
229- ArrayRef<int64_t> getStaticStrides() {
230- /// To be compatible with OffsetSizeAndStrideOpInterface, which expects valid return value and perform checks
231- static llvm::SmallVector<int64_t, 4> emptyStrides;
232-
233- auto attr = getConstStridesAttr();
234- if (attr)
235- return attr;
236-
237- if (llvm::isa<IntegerType>(getSourceType()))
238- return emptyStrides;
239-
240- auto memrefType = llvm::dyn_cast<MemRefType>(getSourceType());
241- assert(memrefType && "Incorrect use of getStaticStrides");
242- auto [strides, _] = memrefType.getStridesAndOffset();
243- // reuse the storage of ConstStridesAttr since strides from
244- // memref is not persistant
245- setConstStrides(strides);
246- attr = getConstStridesAttr();
247- return attr;
248- }
207+      /// Get the static strides; the value passed to const_strides
208+      /// will override the strides from the memref type.
209+ if (auto memrefTy = llvm::dyn_cast<MemRefType>(getSourceType()))
210+ statics = memrefTy.getStridesAndOffset().first;
211+ if (auto attr = getConstStridesAttr())
212+ statics = llvm::to_vector(attr.asArrayRef());
249213
250- /// Return the expected rank of each of the`static_offsets`,
251- /// `static_shape` and `static_strides` attributes.
252- std::array<unsigned, 3> getArrayAttrMaxRanks() {
253- unsigned rank;
254- if (auto ty = llvm::dyn_cast<MemRefType>(getSourceType())) {
255- rank = ty.getRank();
256- } else {
257- rank = (unsigned)getMixedOffsets().size();
258- }
259- return {rank, rank, rank};
214+ return getMixedValues(statics, getStrides(), getContext());
260215 }
261216
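Callers that previously went through `OffsetSizeAndStrideOpInterface` can consume these helpers directly. A hypothetical lowering snippet (variable names are illustrative) that materializes each mixed stride as an index `Value`:

```cpp
// Hypothetical: turn each OpFoldResult into a Value, creating arith constants
// for the static entries. getValueOrCreateConstantIndexOp is from
// mlir/Dialect/Arith/Utils/Utils.h.
SmallVector<Value> strideVals;
for (OpFoldResult ofr : createNdOp.getMixedStrides())
  strideVals.push_back(getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
```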
262217 /// Return the number of leading operands before the `offsets`,
@@ -314,15 +269,15 @@ def XeGPU_PrefetchNdOp : XeGPU_Op<"prefetch_nd", []> {
314269 }];
315270
316271 let assemblyFormat = [{
317- $TensorDesc ``
318- custom<OptionalDynamicIndexList>($offsets, $const_offsets)
272+ $TensorDesc ``
273+ custom<OptionalDynamicIndexList>($offsets, $const_offsets)
319274 prop-dict attr-dict `:` qualified(type($TensorDesc))
320275 }];
321276
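Assuming `OptionalDynamicIndexList` renders the offsets as a bracketed, possibly mixed static/SSA index list (as elsewhere in the dialect), the printed form would look roughly like:

```mlir
// Illustrative only; shape and hint are made up.
xegpu.prefetch_nd %tdesc[%row, 16] <{l1_hint = #xegpu.cache_hint<cached>}>
    : !xegpu.tensor_desc<8x16xf16>
```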
322277 let builders = [
323- OpBuilder<(ins "Value": $TensorDesc,
324- "xegpu::CachePolicyAttr": $l1_hint,
325- "xegpu::CachePolicyAttr": $l2_hint,
278+ OpBuilder<(ins "Value": $TensorDesc,
279+ "xegpu::CachePolicyAttr": $l1_hint,
280+ "xegpu::CachePolicyAttr": $l2_hint,
326281 "xegpu::CachePolicyAttr": $l3_hint)>
327282 ];
328283
@@ -370,7 +325,7 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [
370325
371326 let arguments = (ins XeGPU_TensorDesc: $TensorDesc,
372327 Variadic<Index>: $offsets,
373- OptionalAttr<DenseI64ArrayAttr>: $const_offsets,
328+ OptionalAttr<DenseI64ArrayAttr>: $const_offsets,
374329 OptionalAttr<UnitAttr>: $packed,
375330 OptionalAttr<DenseI64ArrayAttr>: $transpose,
376331 OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
@@ -390,16 +345,16 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [
390345 }];
391346
392347 let assemblyFormat = [{
393- $TensorDesc ``
394- custom<OptionalDynamicIndexList>($offsets, $const_offsets)
348+ $TensorDesc ``
349+ custom<OptionalDynamicIndexList>($offsets, $const_offsets)
395350 prop-dict attr-dict `:` qualified(type($TensorDesc)) `->` type($value)
396351 }];
397352
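Under the same assumption about the offsets directive, a load through a descriptor would print roughly as:

```mlir
// Illustrative only; the 2D descriptor and the hint are made up.
%v = xegpu.load_nd %tdesc[%row, 0] <{l1_hint = #xegpu.cache_hint<cached>}>
    : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
```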
398353 let builders = [
399- OpBuilder<(ins "Type": $value, "Value": $TensorDesc,
354+ OpBuilder<(ins "Type": $value, "Value": $TensorDesc,
400355 "UnitAttr": $packed, "DenseI64ArrayAttr": $transpose,
401- "xegpu::CachePolicyAttr": $l1_hint,
402- "xegpu::CachePolicyAttr": $l2_hint,
356+ "xegpu::CachePolicyAttr": $l1_hint,
357+ "xegpu::CachePolicyAttr": $l2_hint,
403358 "xegpu::CachePolicyAttr": $l3_hint)>
404359 ];
405360
@@ -442,7 +397,7 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [
442397 let arguments = (ins XeGPU_ValueType: $value,
443398 XeGPU_TensorDesc: $TensorDesc,
444399 Variadic<Index>: $offsets,
445- OptionalAttr<DenseI64ArrayAttr>: $const_offsets,
400+ OptionalAttr<DenseI64ArrayAttr>: $const_offsets,
446401 OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
447402 OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
448403 OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
@@ -458,16 +413,16 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [
458413 }];
459414
460415 let assemblyFormat = [{
461- $value `,`
462- $TensorDesc ``
463- custom<OptionalDynamicIndexList>($offsets, $const_offsets)
416+ $value `,`
417+ $TensorDesc ``
418+ custom<OptionalDynamicIndexList>($offsets, $const_offsets)
464419 prop-dict attr-dict `:` type($value) `,` qualified(type($TensorDesc))
465420 }];
466421
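Correspondingly, a store with offsets would look roughly like the following (types illustrative, no cache hints supplied):

```mlir
// Illustrative only.
xegpu.store_nd %val, %tdesc[%row, 0]
    : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16>
```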
467422 let builders = [
468- OpBuilder<(ins "Value": $value, "Value": $TensorDesc,
469- "xegpu::CachePolicyAttr": $l1_hint,
470- "xegpu::CachePolicyAttr": $l2_hint,
423+ OpBuilder<(ins "Value": $value, "Value": $TensorDesc,
424+ "xegpu::CachePolicyAttr": $l1_hint,
425+ "xegpu::CachePolicyAttr": $l2_hint,
471426 "xegpu::CachePolicyAttr": $l3_hint)>
472427 ];
473428
@@ -635,12 +590,12 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
635590 l3_hint = #xegpu.cache_hint<cached>}
636591 : !xegpu.tensor_desc<16xf16>
637592 ```
638-
593+
639594 Example 2:
640595 A variant accepts memref as base pointer and an offset instead of scattered TensorTdesc.
641596 It combines "create scattered TensorTdesc" and "prefetch with scattered TensorTdesc".
642597 The source operand could be a raw pointer (uint64_t).
643- Please refer to create_tdesc for the restriction of memref.
598+ Please refer to create_tdesc for the restriction of memref.
644599 ```mlir
645600 %a = memref.alloc() : memref<1024xf32>
646601 %0 = arith.constant dense<[0, 16, 32, 64]> : vector<4xindex>
@@ -676,16 +631,16 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
676631 }];
677632
678633 let assemblyFormat = [{
679- $source
634+ $source
680635 (`[` $offsets^ `]`)?
681636 prop-dict
682- attr-dict `:` type(operands)
637+ attr-dict `:` type(operands)
683638 }];
684-
639+
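With the optional bracketed offsets, the memref-plus-offsets variant sketched in the description above would print roughly as (offsets and hint illustrative):

```mlir
// Illustrative only: prefetch from a raw base with per-lane offsets.
xegpu.prefetch %src[%offsets] <{l1_hint = #xegpu.cache_hint<cached>}>
    : memref<1024xf32>, vector<4xindex>
```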
685640 let builders = [
686641 OpBuilder<(ins "Value": $source,
687- "xegpu::CachePolicyAttr": $l1_hint,
688- "xegpu::CachePolicyAttr": $l2_hint,
642+ "xegpu::CachePolicyAttr": $l1_hint,
643+ "xegpu::CachePolicyAttr": $l2_hint,
689644 "xegpu::CachePolicyAttr": $l3_hint)>
690645 ];
691646
@@ -723,7 +678,7 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
723678 : !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<memory_space=global, chunk_size=8>>,
724679 vector<16xi1> -> vector<16x8xf32>
725680 ```
726-
681+
727682 Example 3 (SIMT mode):
728683 ```mlir
729684 %2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint<cached>,
@@ -732,12 +687,12 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
732687 : !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<memory_space=global, chunk_size=8>>
733688 vector<16xi1> -> vector<8xf32>
734689 ```
735-
690+
736691 Example 4:
737692 A variant accepts memref as base pointer and an offset instead of scattered TensorTdesc.
738693 It combines "create scattered TensorTdesc" and "load with scattered TensorTdesc".
739694 The source operand could be a raw pointer (uint64_t). Please refer to create_tdesc
740- for the restriction of memref.
695+ for the restriction of memref.
741696 ```mlir
742697 %a = memref.alloc() : memref<1024xf32>
743698 %offsets = vector.step : vector<16xindex>
@@ -794,14 +749,14 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
794749 let assemblyFormat = [{
795750 $source
796751 (`[` $offsets^ `]`)? `,`
797- $mask prop-dict
752+ $mask prop-dict
798753 attr-dict `:` type(operands) `->` type($value)
799754 }];
800755
801756 let builders = [
802757 OpBuilder<(ins "Type": $value, "Value": $source, "Value": $mask,
803- "xegpu::CachePolicyAttr": $l1_hint,
804- "xegpu::CachePolicyAttr": $l2_hint,
758+ "xegpu::CachePolicyAttr": $l1_hint,
759+ "xegpu::CachePolicyAttr": $l2_hint,
805760 "xegpu::CachePolicyAttr": $l3_hint)>
806761 ];
807762
@@ -848,7 +803,7 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
848803 A variant accepts memref as base pointer and an offset instead of scattered TensorTdesc.
849804 It combines "create scattered TensorTdesc" and "store with scattered TensorTdesc".
850805 The dest operand could be a raw pointer (uint64_t).
851- Please refer to create_tdesc for the restriction of memref.
806+ Please refer to create_tdesc for the restriction of memref.
852807 ```mlir
853808 %a = memref.alloc() : memref<1024xf32>
854809 %val = arith.constant dense<0.0> : vector<16xf32>
@@ -901,15 +856,15 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
901856 $value `,`
902857 $dest
903858 (`[` $offsets^ `]`)? `,`
904- $mask
905- prop-dict
859+ $mask
860+ prop-dict
906861 attr-dict `:` type(operands)
907862 }];
908863
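And the matching scatter store, again with illustrative types and no hints:

```mlir
// Illustrative only: scatter 16 scalars under a mask.
xegpu.store %val, %dest[%offsets], %mask
    : vector<16xf32>, memref<1024xf32>, vector<16xindex>, vector<16xi1>
```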
909864 let builders = [
910865 OpBuilder<(ins "Value": $value, "Value": $dest, "Value": $mask,
911- "xegpu::CachePolicyAttr": $l1_hint,
912- "xegpu::CachePolicyAttr": $l2_hint,
866+ "xegpu::CachePolicyAttr": $l1_hint,
867+ "xegpu::CachePolicyAttr": $l2_hint,
913868 "xegpu::CachePolicyAttr": $l3_hint)>
914869 ];
915870