Commit fab2b22

[mlir][xegpu] Remove OffsetSizeAndStrideOpInterface from CreateNdDescOp (#152773)
1 parent: 327c64c

File tree: 4 files changed, +94 -142 lines

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td

Lines changed: 63 additions & 108 deletions
@@ -29,7 +29,7 @@ class XeGPU_Op<string mnemonic, list<Trait> traits = []>:
    void printProperties(::mlir::MLIRContext *ctx,
                         ::mlir::OpAsmPrinter &p, const Properties &prop,
                         ::mlir::ArrayRef<::llvm::StringRef> elidedProps) {
-
+
      DictionaryAttr propAttr = dyn_cast_if_present<mlir::DictionaryAttr>(getPropertiesAsAttr(ctx, prop));

      // filter out the elidedProps from propAttr, and get the resultAttr
@@ -43,7 +43,7 @@ class XeGPU_Op<string mnemonic, list<Trait> traits = []>:
      }

      if (!filteredAttrs.empty()) {
-        p << "<" << DictionaryAttr::get(ctx, filteredAttrs) << ">";
+        p << "<" << DictionaryAttr::get(ctx, filteredAttrs) << ">";
      }
    }

@@ -60,8 +60,7 @@ class XeGPU_Op<string mnemonic, list<Trait> traits = []>:
  }


-def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface,
-                        AttrSizedOperandSegments, OffsetSizeAndStrideOpInterface]> {
+def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface, AttrSizedOperandSegments]> {

  let summary = "Create nd-tensor descriptor operation";
  let description = [{
@@ -181,82 +180,38 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
      return getType().getShape();
    }

-    /// wrapper for matching with OffsetSizeAndStrideOpInterface
-    OperandRange getSizes() {
-      return getShape();
+    SmallVector<OpFoldResult> getMixedOffsets() {
+      auto statics = getConstOffsets().value_or(SmallVector<int64_t>());
+      auto dynamics = getOffsets();
+      if (statics.size() == 0 && dynamics.size() == 0)
+        return {};
+      return getMixedValues(statics, dynamics, getContext());
    }

-    ArrayRef<int64_t> getStaticOffsets(){
-      auto attr = getConstOffsetsAttr();
-
-      if (attr)
-        return attr;
+    SmallVector<OpFoldResult> getMixedSizes() {
+      SmallVector<int64_t> statics;

-      int64_t rank = getMixedSizes().size();
-
-      setConstOffsets(llvm::SmallVector<int64_t, 4>(rank, 0));
+      /// Get the static sizes/shape, the value passed to const_shape
+      /// will overide the value in memref shape.
+      if (auto memrefTy = llvm::dyn_cast<MemRefType>(getSourceType()))
+        statics = llvm::to_vector(memrefTy.getShape());
+      if (auto attr = getConstShapeAttr())
+        statics = llvm::to_vector(attr.asArrayRef());

-      attr = getConstOffsetsAttr();
-      return attr;
+      return getMixedValues(statics, getShape(), getContext());
    }

-    /// wrapper for matching with OffsetSizeAndStrideOpInterface
-    /// If source is IntegerType or `const_shape` is filled,
-    /// it will return `const_shape`, such that mixes of `shape`
-    /// and `const_shape` will be used to represent the shape of
-    /// source operand. They overide static shape from source memref type.
-    ArrayRef<int64_t> getStaticSizes() {
-      /// To be compatible with OffsetSizeAndStrideOpInterface, which expects valid return value and perform checks
-      static llvm::SmallVector<int64_t, 4> emptyShape;
-
-      auto attr = getConstShapeAttr();
-      if (attr)
-        return attr;
-
-      if (llvm::isa<IntegerType>(getSourceType()))
-        return emptyShape;
-
-      auto memrefType = llvm::dyn_cast<MemRefType>(getSourceType());
-      assert(memrefType && "Incorrect use of getStaticSizes");
-      return memrefType.getShape();
-    }
+    SmallVector<OpFoldResult> getMixedStrides() {
+      SmallVector<int64_t> statics;

-    /// wrapper for matching with OffsetSizeAndStrideOpInterface
-    /// If source is IntegerType or `const_strides` is filled, it
-    /// will return `const_strides`, such that mixes of `strides`
-    /// and `const_strides` will be used to represent the strides of
-    /// source operand. They overide static strides from source memref type.
-    ArrayRef<int64_t> getStaticStrides() {
-      /// To be compatible with OffsetSizeAndStrideOpInterface, which expects valid return value and perform checks
-      static llvm::SmallVector<int64_t, 4> emptyStrides;
-
-      auto attr = getConstStridesAttr();
-      if (attr)
-        return attr;
-
-      if (llvm::isa<IntegerType>(getSourceType()))
-        return emptyStrides;
-
-      auto memrefType = llvm::dyn_cast<MemRefType>(getSourceType());
-      assert(memrefType && "Incorrect use of getStaticStrides");
-      auto [strides, _] = memrefType.getStridesAndOffset();
-      // reuse the storage of ConstStridesAttr since strides from
-      // memref is not persistant
-      setConstStrides(strides);
-      attr = getConstStridesAttr();
-      return attr;
-    }
+      /// Get the static strides, the value passed to const_strides
+      /// will overide the value in memref.
+      if (auto memrefTy = llvm::dyn_cast<MemRefType>(getSourceType()))
+        statics = memrefTy.getStridesAndOffset().first;
+      if (auto attr = getConstStridesAttr())
+        statics = llvm::to_vector(attr.asArrayRef());

-    /// Return the expected rank of each of the `static_offsets`,
-    /// `static_shape` and `static_strides` attributes.
-    std::array<unsigned, 3> getArrayAttrMaxRanks() {
-      unsigned rank;
-      if (auto ty = llvm::dyn_cast<MemRefType>(getSourceType())) {
-        rank = ty.getRank();
-      } else {
-        rank = (unsigned)getMixedOffsets().size();
-      }
-      return {rank, rank, rank};
+      return getMixedValues(statics, getStrides(), getContext());
    }

    /// Return the number of leading operands before the `offsets`,
@@ -314,15 +269,15 @@ def XeGPU_PrefetchNdOp : XeGPU_Op<"prefetch_nd", []> {
  }];

  let assemblyFormat = [{
-    $TensorDesc ``
-    custom<OptionalDynamicIndexList>($offsets, $const_offsets)
+    $TensorDesc ``
+    custom<OptionalDynamicIndexList>($offsets, $const_offsets)
    prop-dict attr-dict `:` qualified(type($TensorDesc))
  }];

  let builders = [
-    OpBuilder<(ins "Value": $TensorDesc,
-               "xegpu::CachePolicyAttr": $l1_hint,
-               "xegpu::CachePolicyAttr": $l2_hint,
+    OpBuilder<(ins "Value": $TensorDesc,
+               "xegpu::CachePolicyAttr": $l1_hint,
+               "xegpu::CachePolicyAttr": $l2_hint,
               "xegpu::CachePolicyAttr": $l3_hint)>
  ];

@@ -370,7 +325,7 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [

  let arguments = (ins XeGPU_TensorDesc: $TensorDesc,
                       Variadic<Index>: $offsets,
-                       OptionalAttr<DenseI64ArrayAttr>: $const_offsets,
+                       OptionalAttr<DenseI64ArrayAttr>: $const_offsets,
                       OptionalAttr<UnitAttr>: $packed,
                       OptionalAttr<DenseI64ArrayAttr>: $transpose,
                       OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
@@ -390,16 +345,16 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [
  }];

  let assemblyFormat = [{
-    $TensorDesc ``
-    custom<OptionalDynamicIndexList>($offsets, $const_offsets)
+    $TensorDesc ``
+    custom<OptionalDynamicIndexList>($offsets, $const_offsets)
    prop-dict attr-dict `:` qualified(type($TensorDesc)) `->` type($value)
  }];

  let builders = [
-    OpBuilder<(ins "Type": $value, "Value": $TensorDesc,
+    OpBuilder<(ins "Type": $value, "Value": $TensorDesc,
               "UnitAttr": $packed, "DenseI64ArrayAttr": $transpose,
-               "xegpu::CachePolicyAttr": $l1_hint,
-               "xegpu::CachePolicyAttr": $l2_hint,
+               "xegpu::CachePolicyAttr": $l1_hint,
+               "xegpu::CachePolicyAttr": $l2_hint,
               "xegpu::CachePolicyAttr": $l3_hint)>
  ];

@@ -442,7 +397,7 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [
  let arguments = (ins XeGPU_ValueType: $value,
                       XeGPU_TensorDesc: $TensorDesc,
                       Variadic<Index>: $offsets,
-                       OptionalAttr<DenseI64ArrayAttr>: $const_offsets,
+                       OptionalAttr<DenseI64ArrayAttr>: $const_offsets,
                       OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
                       OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
                       OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
@@ -458,16 +413,16 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [
  }];

  let assemblyFormat = [{
-    $value `,`
-    $TensorDesc ``
-    custom<OptionalDynamicIndexList>($offsets, $const_offsets)
+    $value `,`
+    $TensorDesc ``
+    custom<OptionalDynamicIndexList>($offsets, $const_offsets)
    prop-dict attr-dict `:` type($value) `,` qualified(type($TensorDesc))
  }];

  let builders = [
-    OpBuilder<(ins "Value": $value, "Value": $TensorDesc,
-               "xegpu::CachePolicyAttr": $l1_hint,
-               "xegpu::CachePolicyAttr": $l2_hint,
+    OpBuilder<(ins "Value": $value, "Value": $TensorDesc,
+               "xegpu::CachePolicyAttr": $l1_hint,
+               "xegpu::CachePolicyAttr": $l2_hint,
               "xegpu::CachePolicyAttr": $l3_hint)>
  ];

@@ -635,12 +590,12 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
                             l3_hint = #xegpu.cache_hint<cached>}
        : !xegpu.tensor_desc<16xf16>
    ```
-
+
    Example 2:
    A variant accepts memref as base pointer and an offset instead of scattered TensorTdesc.
    It combines "create scattered TensorTdesc" and "prefetch with scattered TensorTdesc".
    The source operand could be a raw pointer (uint64_t).
-    Please refer to create_tdesc for the restriction of memref.
+    Please refer to create_tdesc for the restriction of memref.
    ```mlir
      %a = memref.alloc() : memref<1024xf32>
      %0 = arith.constant dense<[0, 16, 32, 64]> : vector<4xindex>
@@ -676,16 +631,16 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
  }];

  let assemblyFormat = [{
-    $source
+    $source
    (`[` $offsets^ `]`)?
    prop-dict
-    attr-dict `:` type(operands)
+    attr-dict `:` type(operands)
  }];
-
+
  let builders = [
    OpBuilder<(ins "Value": $source,
-               "xegpu::CachePolicyAttr": $l1_hint,
-               "xegpu::CachePolicyAttr": $l2_hint,
+               "xegpu::CachePolicyAttr": $l1_hint,
+               "xegpu::CachePolicyAttr": $l2_hint,
               "xegpu::CachePolicyAttr": $l3_hint)>
  ];

@@ -723,7 +678,7 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
      : !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<memory_space=global, chunk_size=8>>,
        vector<16xi1> -> vector<16x8xf32>
    ```
-
+
    Example 3 (SIMT mode):
    ```mlir
      %2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint<cached>,
@@ -732,12 +687,12 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
      : !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<memory_space=global, chunk_size=8>>
        vector<16xi1> -> vector<8xf32>
    ```
-
+
    Example 4:
    A variant accepts memref as base pointer and an offset instead of scattered TensorTdesc.
    It combines "create scattered TensorTdesc" and "load with scattered TensorTdesc".
    The source operand could be a raw pointer (uint64_t). Please refer to create_tdesc
-    for the restriction of memref.
+    for the restriction of memref.
    ```mlir
      %a = memref.alloc() : memref<1024xf32>
      %offsets = vector.step : vector<16xindex>
@@ -794,14 +749,14 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
  let assemblyFormat = [{
    $source
    (`[` $offsets^ `]`)? `,`
-    $mask prop-dict
+    $mask prop-dict
    attr-dict `:` type(operands) `->` type($value)
  }];

  let builders = [
    OpBuilder<(ins "Type": $value, "Value": $source, "Value": $mask,
-               "xegpu::CachePolicyAttr": $l1_hint,
-               "xegpu::CachePolicyAttr": $l2_hint,
+               "xegpu::CachePolicyAttr": $l1_hint,
+               "xegpu::CachePolicyAttr": $l2_hint,
               "xegpu::CachePolicyAttr": $l3_hint)>
  ];

@@ -848,7 +803,7 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
    A variant accepts memref as base pointer and an offset instead of scattered TensorTdesc.
    It combines "create scattered TensorTdesc" and "store with scattered TensorTdesc".
    The dest operand could be a raw pointer (uint64_t).
-    Please refer to create_tdesc for the restriction of memref.
+    Please refer to create_tdesc for the restriction of memref.
    ```mlir
      %a = memref.alloc() : memref<1024xf32>
      %val = arith.constant dense<0.0> : vector<16xf32>
@@ -901,15 +856,15 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
    $value `,`
    $dest
    (`[` $offsets^ `]`)? `,`
-    $mask
-    prop-dict
+    $mask
+    prop-dict
    attr-dict `:` type(operands)
  }];

  let builders = [
    OpBuilder<(ins "Value": $value, "Value": $dest, "Value": $mask,
-               "xegpu::CachePolicyAttr": $l1_hint,
-               "xegpu::CachePolicyAttr": $l2_hint,
+               "xegpu::CachePolicyAttr": $l1_hint,
+               "xegpu::CachePolicyAttr": $l2_hint,
               "xegpu::CachePolicyAttr": $l3_hint)>
  ];

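The TableGen hunk above replaces the OffsetSizeAndStrideOpInterface plumbing (`getStaticOffsets`, `getStaticSizes`, `getStaticStrides`, `getArrayAttrMaxRanks`) with self-contained `getMixedOffsets`/`getMixedSizes`/`getMixedStrides` helpers that fold the `const_*` attributes and the source memref type into a single `OpFoldResult` list. As a minimal sketch of what callers now see (illustrative only, not part of this commit; `walkDescriptorSizes` and the printing are assumed names):

```cpp
// Illustrative sketch, not from this commit: consuming the new mixed
// accessors on xegpu::CreateNdDescOp.
#include <optional>

#include "mlir/Dialect/Utils/StaticValueUtils.h" // getConstantIntValue
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "llvm/Support/raw_ostream.h"

using namespace mlir;

// Each OpFoldResult is either an IntegerAttr (folded from const_shape or
// from the source memref type, with const_shape taking precedence) or a
// dynamic SSA Value from the op's `shape` operands.
static void walkDescriptorSizes(xegpu::CreateNdDescOp op) {
  for (OpFoldResult ofr : op.getMixedSizes()) {
    if (std::optional<int64_t> cst = getConstantIntValue(ofr))
      llvm::outs() << "static size: " << *cst << "\n";   // known extent
    else
      llvm::outs() << "dynamic size: " << cast<Value>(ofr) << "\n";
  }
}
```

The same pattern applies to `getMixedOffsets` and `getMixedStrides`; an empty result from `getMixedOffsets` now simply means the op carries no offsets, whereas the removed `getStaticOffsets` had to materialize an all-zero `const_offsets` attribute to satisfy the interface.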

mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp

Lines changed: 8 additions & 11 deletions
@@ -265,8 +265,8 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
}

LogicalResult CreateNdDescOp::verify() {
-  auto rank = (int64_t)getMixedOffsets().size();
-  bool invalidRank = false;
+  size_t rank = getMixedSizes().size();
+  bool invalidRank = rank != getMixedStrides().size();
  bool invalidElemTy = false;

  // Memory space of created TensorDesc should match with the source.
@@ -280,31 +280,28 @@ LogicalResult CreateNdDescOp::verify() {
           << " Source: " << srcMemorySpace
           << ", TensorDesc: " << tdescMemorySpace;

+  if (size_t offsetRank = getMixedOffsets().size())
+    invalidRank |= (offsetRank != rank);
+
  // check source type matches the rank if it is a memref.
  // It also should have the same ElementType as TensorDesc.
-  auto memrefTy = dyn_cast<MemRefType>(getSourceType());
-  if (memrefTy) {
-    invalidRank |= (memrefTy.getRank() != rank);
+  if (auto memrefTy = dyn_cast<MemRefType>(getSourceType()))
    invalidElemTy |= memrefTy.getElementType() != getElementType();
-  }

  if (llvm::isa<IntegerType>(getSourceType())) {
    // strides and shape must present for integer source.
    if (getMixedStrides().empty() || getMixedSizes().empty())
-      return emitOpError("Expecting strides and shape to be present for "
+      return emitOpError("expecting strides and shape to be present for "
                         "integer source.");
  }

-  // mismatches among shape, strides, and offsets are
-  // already handeled by OffsetSizeAndStrideOpInterface.
-  // So they are not check here.
  if (invalidRank)
    return emitOpError(
        "Expecting the rank of shape, strides, offsets, and source (if source "
        "is a memref) should match with each other.");

  // check result TensorDesc rank
-  if (getType().getRank() > rank)
+  if (getType().getRank() > (int64_t)rank)
    return emitOpError(
        "Expecting the TensorDesc rank is not greater than the "
        "ranks of shape, strides, offsets or the memref source.");

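The rewritten verifier derives the expected rank from `getMixedSizes()`, requires `getMixedStrides()` to have the same rank, and checks offsets only when they are present, since the removed interface no longer enforces these invariants. It leans on the contract of `mlir::getMixedValues` (declared in `mlir/Dialect/Utils/StaticValueUtils.h`): each static entry becomes an `IntegerAttr`, except entries equal to `ShapedType::kDynamic`, which are replaced in order by the dynamic operands, so the result always holds one entry per dimension. A hedged sketch under those assumptions (`foldSizes` and `dynSize` are hypothetical names; the helper itself is real):

```cpp
// Hedged sketch of the getMixedValues contract the new accessors and the
// verifier rely on; foldSizes and dynSize are hypothetical.
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h" // ShapedType::kDynamic

using namespace mlir;

static SmallVector<OpFoldResult> foldSizes(Builder &builder, Value dynSize) {
  // One static extent (8) and one dynamic placeholder.
  SmallVector<int64_t> statics = {8, ShapedType::kDynamic};
  SmallVector<Value> dynamics = {dynSize};
  // Yields {IntegerAttr(8), dynSize}: one OpFoldResult per extent, which is
  // why getMixedSizes().size() can serve as the rank in verify().
  return getMixedValues(statics, dynamics, builder.getContext());
}
```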