Skip to content

Commit 2bc35d8

Browse files
committed
remove OffsetSizeAndStrideOpInterface from CreateNdDescOp
1 parent 3769ce0 commit 2bc35d8

File tree

4 files changed

+93
-133
lines changed

4 files changed

+93
-133
lines changed

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td

Lines changed: 65 additions & 108 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ class XeGPU_Op<string mnemonic, list<Trait> traits = []>:
2929
void printProperties(::mlir::MLIRContext *ctx,
3030
::mlir::OpAsmPrinter &p, const Properties &prop,
3131
::mlir::ArrayRef<::llvm::StringRef> elidedProps) {
32-
32+
3333
DictionaryAttr propAttr = dyn_cast_if_present<mlir::DictionaryAttr>(getPropertiesAsAttr(ctx, prop));
3434

3535
// filter out the elidedProps from propAttr, and get the resultAttr
@@ -43,7 +43,7 @@ class XeGPU_Op<string mnemonic, list<Trait> traits = []>:
4343
}
4444

4545
if (!filteredAttrs.empty()) {
46-
p << "<" << DictionaryAttr::get(ctx, filteredAttrs) << ">";
46+
p << "<" << DictionaryAttr::get(ctx, filteredAttrs) << ">";
4747
}
4848
}
4949

@@ -60,8 +60,7 @@ class XeGPU_Op<string mnemonic, list<Trait> traits = []>:
6060
}
6161

6262

63-
def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface,
64-
AttrSizedOperandSegments, OffsetSizeAndStrideOpInterface]> {
63+
def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface, AttrSizedOperandSegments]> {
6564

6665
let summary = "Create nd-tensor descriptor operation";
6766
let description = [{
@@ -181,82 +180,40 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
181180
return getType().getShape();
182181
}
183182

184-
/// wrapper for matching with OffsetSizeAndStrideOpInterface
185-
OperandRange getSizes() {
186-
return getShape();
183+
SmallVector<OpFoldResult> getMixedOffsets() {
184+
auto statics = getConstOffsets().value_or(SmallVector<int64_t>());
185+
auto dynamics = getOffsets();
186+
if (statics.size() == 0 && dynamics.size() == 0)
187+
return {};
188+
return getMixedValues(statics, dynamics, getContext());
187189
}
188190

189-
ArrayRef<int64_t> getStaticOffsets(){
190-
auto attr = getConstOffsetsAttr();
191-
192-
if (attr)
193-
return attr;
191+
SmallVector<OpFoldResult> getMixedSizes() {
192+
Builder b(getContext());
193+
SmallVector<int64_t> statics;
194194

195-
int64_t rank = getMixedSizes().size();
196-
197-
setConstOffsets(llvm::SmallVector<int64_t, 4>(rank, 0));
195+
/// Get the static sizes/shape, the value passed to const_shape
196+
/// will override the value in memref shape.
197+
if (auto memrefTy = llvm::dyn_cast<MemRefType>(getSourceType()))
198+
statics = llvm::to_vector(memrefTy.getShape());
199+
if (auto attr = getConstShapeAttr())
200+
statics = llvm::to_vector(attr.asArrayRef());
198201

199-
attr = getConstOffsetsAttr();
200-
return attr;
202+
return getMixedValues(statics, getShape(), b);
201203
}
202204

203-
/// wrapper for matching with OffsetSizeAndStrideOpInterface
204-
/// If source is IntegerType or `const_shape` is filled,
205-
/// it will return `const_shape`, such that mixes of `shape`
206-
/// and `const_shape` will be used to represent the shape of
207-
/// source operand. They override static shape from source memref type.
208-
ArrayRef<int64_t> getStaticSizes() {
209-
/// To be compatible with OffsetSizeAndStrideOpInterface, which expects valid return value and perform checks
210-
static llvm::SmallVector<int64_t, 4> emptyShape;
211-
212-
auto attr = getConstShapeAttr();
213-
if (attr)
214-
return attr;
215-
216-
if (llvm::isa<IntegerType>(getSourceType()))
217-
return emptyShape;
218-
219-
auto memrefType = llvm::dyn_cast<MemRefType>(getSourceType());
220-
assert(memrefType && "Incorrect use of getStaticSizes");
221-
return memrefType.getShape();
222-
}
205+
SmallVector<OpFoldResult> getMixedStrides() {
206+
Builder b(getContext());
207+
SmallVector<int64_t> statics;
223208

224-
/// wrapper for matching with OffsetSizeAndStrideOpInterface
225-
/// If source is IntegerType or `const_strides` is filled, it
226-
/// will return `const_strides`, such that mixes of `strides`
227-
/// and `const_strides` will be used to represent the strides of
228-
/// source operand. They overide static strides from source memref type.
229-
ArrayRef<int64_t> getStaticStrides() {
230-
/// To be compatible with OffsetSizeAndStrideOpInterface, which expects valid return value and perform checks
231-
static llvm::SmallVector<int64_t, 4> emptyStrides;
232-
233-
auto attr = getConstStridesAttr();
234-
if (attr)
235-
return attr;
236-
237-
if (llvm::isa<IntegerType>(getSourceType()))
238-
return emptyStrides;
239-
240-
auto memrefType = llvm::dyn_cast<MemRefType>(getSourceType());
241-
assert(memrefType && "Incorrect use of getStaticStrides");
242-
auto [strides, _] = memrefType.getStridesAndOffset();
243-
// reuse the storage of ConstStridesAttr since strides from
244-
// memref is not persistent
245-
setConstStrides(strides);
246-
attr = getConstStridesAttr();
247-
return attr;
248-
}
209+
/// Get the static strides, the value passed to const_strides
210+
/// will override the value in memref.
211+
if (auto memrefTy = llvm::dyn_cast<MemRefType>(getSourceType()))
212+
statics = memrefTy.getStridesAndOffset().first;
213+
if (auto attr = getConstStridesAttr())
214+
statics = llvm::to_vector(attr.asArrayRef());
249215

250-
/// Return the expected rank of each of the`static_offsets`,
251-
/// `static_shape` and `static_strides` attributes.
252-
std::array<unsigned, 3> getArrayAttrMaxRanks() {
253-
unsigned rank;
254-
if (auto ty = llvm::dyn_cast<MemRefType>(getSourceType())) {
255-
rank = ty.getRank();
256-
} else {
257-
rank = (unsigned)getMixedOffsets().size();
258-
}
259-
return {rank, rank, rank};
216+
return getMixedValues(statics, getStrides(), b);
260217
}
261218

262219
/// Return the number of leading operands before the `offsets`,
@@ -314,15 +271,15 @@ def XeGPU_PrefetchNdOp : XeGPU_Op<"prefetch_nd", []> {
314271
}];
315272

316273
let assemblyFormat = [{
317-
$TensorDesc ``
318-
custom<OptionalDynamicIndexList>($offsets, $const_offsets)
274+
$TensorDesc ``
275+
custom<OptionalDynamicIndexList>($offsets, $const_offsets)
319276
prop-dict attr-dict `:` qualified(type($TensorDesc))
320277
}];
321278

322279
let builders = [
323-
OpBuilder<(ins "Value": $TensorDesc,
324-
"xegpu::CachePolicyAttr": $l1_hint,
325-
"xegpu::CachePolicyAttr": $l2_hint,
280+
OpBuilder<(ins "Value": $TensorDesc,
281+
"xegpu::CachePolicyAttr": $l1_hint,
282+
"xegpu::CachePolicyAttr": $l2_hint,
326283
"xegpu::CachePolicyAttr": $l3_hint)>
327284
];
328285

@@ -370,7 +327,7 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [
370327

371328
let arguments = (ins XeGPU_TensorDesc: $TensorDesc,
372329
Variadic<Index>: $offsets,
373-
OptionalAttr<DenseI64ArrayAttr>: $const_offsets,
330+
OptionalAttr<DenseI64ArrayAttr>: $const_offsets,
374331
OptionalAttr<UnitAttr>: $packed,
375332
OptionalAttr<DenseI64ArrayAttr>: $transpose,
376333
OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
@@ -390,16 +347,16 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [
390347
}];
391348

392349
let assemblyFormat = [{
393-
$TensorDesc ``
394-
custom<OptionalDynamicIndexList>($offsets, $const_offsets)
350+
$TensorDesc ``
351+
custom<OptionalDynamicIndexList>($offsets, $const_offsets)
395352
prop-dict attr-dict `:` qualified(type($TensorDesc)) `->` type($value)
396353
}];
397354

398355
let builders = [
399-
OpBuilder<(ins "Type": $value, "Value": $TensorDesc,
356+
OpBuilder<(ins "Type": $value, "Value": $TensorDesc,
400357
"UnitAttr": $packed, "DenseI64ArrayAttr": $transpose,
401-
"xegpu::CachePolicyAttr": $l1_hint,
402-
"xegpu::CachePolicyAttr": $l2_hint,
358+
"xegpu::CachePolicyAttr": $l1_hint,
359+
"xegpu::CachePolicyAttr": $l2_hint,
403360
"xegpu::CachePolicyAttr": $l3_hint)>
404361
];
405362

@@ -442,7 +399,7 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [
442399
let arguments = (ins XeGPU_ValueType: $value,
443400
XeGPU_TensorDesc: $TensorDesc,
444401
Variadic<Index>: $offsets,
445-
OptionalAttr<DenseI64ArrayAttr>: $const_offsets,
402+
OptionalAttr<DenseI64ArrayAttr>: $const_offsets,
446403
OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
447404
OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
448405
OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
@@ -458,16 +415,16 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [
458415
}];
459416

460417
let assemblyFormat = [{
461-
$value `,`
462-
$TensorDesc ``
463-
custom<OptionalDynamicIndexList>($offsets, $const_offsets)
418+
$value `,`
419+
$TensorDesc ``
420+
custom<OptionalDynamicIndexList>($offsets, $const_offsets)
464421
prop-dict attr-dict `:` type($value) `,` qualified(type($TensorDesc))
465422
}];
466423

467424
let builders = [
468-
OpBuilder<(ins "Value": $value, "Value": $TensorDesc,
469-
"xegpu::CachePolicyAttr": $l1_hint,
470-
"xegpu::CachePolicyAttr": $l2_hint,
425+
OpBuilder<(ins "Value": $value, "Value": $TensorDesc,
426+
"xegpu::CachePolicyAttr": $l1_hint,
427+
"xegpu::CachePolicyAttr": $l2_hint,
471428
"xegpu::CachePolicyAttr": $l3_hint)>
472429
];
473430

@@ -635,12 +592,12 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
635592
l3_hint = #xegpu.cache_hint<cached>}
636593
: !xegpu.tensor_desc<16xf16>
637594
```
638-
595+
639596
Example 2:
640597
A variant accepts memref as base pointer and an offset instead of scattered TensorTdesc.
641598
It combines "create scattered TensorTdesc" and "prefetch with scattered TensorTdesc".
642599
The source operand could be a raw pointer (uint64_t).
643-
Please refer to create_tdesc for the restriction of memref.
600+
Please refer to create_tdesc for the restriction of memref.
644601
```mlir
645602
%a = memref.alloc() : memref<1024xf32>
646603
%0 = arith.constant dense<[0, 16, 32, 64]> : vector<4xindex>
@@ -676,16 +633,16 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
676633
}];
677634

678635
let assemblyFormat = [{
679-
$source
636+
$source
680637
(`[` $offsets^ `]`)?
681638
prop-dict
682-
attr-dict `:` type(operands)
639+
attr-dict `:` type(operands)
683640
}];
684-
641+
685642
let builders = [
686643
OpBuilder<(ins "Value": $source,
687-
"xegpu::CachePolicyAttr": $l1_hint,
688-
"xegpu::CachePolicyAttr": $l2_hint,
644+
"xegpu::CachePolicyAttr": $l1_hint,
645+
"xegpu::CachePolicyAttr": $l2_hint,
689646
"xegpu::CachePolicyAttr": $l3_hint)>
690647
];
691648

@@ -723,7 +680,7 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
723680
: !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<memory_space=global, chunk_size=8>>,
724681
vector<16xi1> -> vector<16x8xf32>
725682
```
726-
683+
727684
Example 3 (SIMT mode):
728685
```mlir
729686
%2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint<cached>,
@@ -732,12 +689,12 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
732689
: !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<memory_space=global, chunk_size=8>>
733690
vector<16xi1> -> vector<8xf32>
734691
```
735-
692+
736693
Example 4:
737694
A variant accepts memref as base pointer and an offset instead of scattered TensorTdesc.
738695
It combines "create scattered TensorTdesc" and "load with scattered TensorTdesc".
739696
The source operand could be a raw pointer (uint64_t). Please refer to create_tdesc
740-
for the restriction of memref.
697+
for the restriction of memref.
741698
```mlir
742699
%a = memref.alloc() : memref<1024xf32>
743700
%offsets = vector.step : vector<16xindex>
@@ -794,14 +751,14 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
794751
let assemblyFormat = [{
795752
$source
796753
(`[` $offsets^ `]`)? `,`
797-
$mask prop-dict
754+
$mask prop-dict
798755
attr-dict `:` type(operands) `->` type($value)
799756
}];
800757

801758
let builders = [
802759
OpBuilder<(ins "Type": $value, "Value": $source, "Value": $mask,
803-
"xegpu::CachePolicyAttr": $l1_hint,
804-
"xegpu::CachePolicyAttr": $l2_hint,
760+
"xegpu::CachePolicyAttr": $l1_hint,
761+
"xegpu::CachePolicyAttr": $l2_hint,
805762
"xegpu::CachePolicyAttr": $l3_hint)>
806763
];
807764

@@ -848,7 +805,7 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
848805
A variant accepts memref as base pointer and an offset instead of scattered TensorTdesc.
849806
It combines "create scattered TensorTdesc" and "store with scattered TensorTdesc".
850807
The dest operand could be a raw pointer (uint64_t).
851-
Please refer to create_tdesc for the restriction of memref.
808+
Please refer to create_tdesc for the restriction of memref.
852809
```mlir
853810
%a = memref.alloc() : memref<1024xf32>
854811
%val = arith.constant dense<0.0> : vector<16xf32>
@@ -901,15 +858,15 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
901858
$value `,`
902859
$dest
903860
(`[` $offsets^ `]`)? `,`
904-
$mask
905-
prop-dict
861+
$mask
862+
prop-dict
906863
attr-dict `:` type(operands)
907864
}];
908865

909866
let builders = [
910867
OpBuilder<(ins "Value": $value, "Value": $dest, "Value": $mask,
911-
"xegpu::CachePolicyAttr": $l1_hint,
912-
"xegpu::CachePolicyAttr": $l2_hint,
868+
"xegpu::CachePolicyAttr": $l1_hint,
869+
"xegpu::CachePolicyAttr": $l2_hint,
913870
"xegpu::CachePolicyAttr": $l3_hint)>
914871
];
915872

mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,7 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
265265
}
266266

267267
LogicalResult CreateNdDescOp::verify() {
268-
auto rank = (int64_t)getMixedOffsets().size();
268+
int64_t rank = getMixedSizes().size();
269269
bool invalidRank = false;
270270
bool invalidElemTy = false;
271271

@@ -280,6 +280,9 @@ LogicalResult CreateNdDescOp::verify() {
280280
<< " Source: " << srcMemorySpace
281281
<< ", TensorDesc: " << tdescMemorySpace;
282282

283+
if (int64_t offsetRank = getMixedOffsets().size())
284+
invalidRank |= (offsetRank != rank);
285+
283286
// check source type matches the rank if it is a memref.
284287
// It also should have the same ElementType as TensorDesc.
285288
auto memrefTy = dyn_cast<MemRefType>(getSourceType());
@@ -291,7 +294,7 @@ LogicalResult CreateNdDescOp::verify() {
291294
if (llvm::isa<IntegerType>(getSourceType())) {
292295
// strides and shape must present for integer source.
293296
if (getMixedStrides().empty() || getMixedSizes().empty())
294-
return emitOpError("Expecting strides and shape to be present for "
297+
return emitOpError("expecting strides and shape to be present for "
295298
"integer source.");
296299
}
297300

0 commit comments

Comments
 (0)