Skip to content

Commit babf57e

Browse files
committed
revert xegpu def changes
Signed-off-by: dchigarev <[email protected]>
1 parent 173eb6d commit babf57e

File tree

3 files changed

+44
-80
lines changed

3 files changed

+44
-80
lines changed

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td

Lines changed: 0 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -261,21 +261,6 @@ def XeGPU_PrefetchNdOp : XeGPU_Op<"prefetch_nd", []> {
261261
: !xegpu.tensor_desc<8x16xf16>
262262
```
263263

264-
The operation may take optional offsets for the tensor descriptor.
265-
The number of offsets must be greater than or equal to the rank of the tensor
266-
descriptor and less than or equal to the rank of the source memref.
267-
The offsets are applied to the innermost dimensions of the source memref.
268-
269-
Examples:
270-
```mlir
271-
%tdesc = xegpu.create_nd_tdesc %0: memref<2x8x32x32xf32> -> TensorDesc<8x16xf32>
272-
// memref[0, 0, %off0, %off1]
273-
xegpu.prefetch_nd %tdesc[%off0, %off1] : !xegpu.tensor_desc<8x16xf16>
274-
// memref[0, %off0, %off1, %off2]
275-
xegpu.prefetch_nd %tdesc[%off0, %off1, %off2] : !xegpu.tensor_desc<8x16xf16>
276-
// memref[%off0, %off1, %off2, %off3]
277-
xegpu.prefetch_nd %tdesc[%off0, %off1, %off2, %off3] : !xegpu.tensor_desc<8x16xf16>
278-
```
279264
}];
280265

281266
let arguments = (ins XeGPU_TensorDesc: $TensorDesc,
@@ -365,21 +350,6 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [
365350
: !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
366351
```
367352

368-
The operation may take optional offsets for the tensor descriptor.
369-
The number of offsets must be greater than or equal to the rank of the tensor
370-
descriptor and less than or equal to the rank of the source memref.
371-
The offsets are applied to the innermost dimensions of the source memref.
372-
373-
Examples:
374-
```mlir
375-
%1 = xegpu.create_nd_tdesc %0: memref<2x8x32x32xf32> -> TensorDesc<8x16xf32>
376-
// memref[0, 0, %off0, %off1]
377-
xegpu.load_nd %1[%off0, %off1] : !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
378-
// memref[0, %off0, %off1, %off2]
379-
xegpu.load_nd %1[%off0, %off1, %off2] : !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
380-
// memref[%off0, %off1, %off2, %off3]
381-
xegpu.load_nd %1[%off0, %off1, %off2, %off3] : !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
382-
```
383353

384354
}];
385355

@@ -475,21 +445,6 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [
475445
: vector<8xf16>, !xegpu.tensor_desc<8x16xf16>
476446
```
477447

478-
The operation may take optional offsets for the tensor descriptor.
479-
The number of offsets must be greater than or equal to the rank of the tensor
480-
descriptor and less than or equal to the rank of the source memref.
481-
The offsets are applied to the innermost dimensions of the source memref.
482-
483-
Examples:
484-
```mlir
485-
%2 = xegpu.create_nd_tdesc %0: memref<2x8x32x32xf32> -> TensorDesc<8x16xf32>
486-
// memref[0, 0, %off0, %off1]
487-
xegpu.store_nd %3, %2[%off0, %off1] : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16>
488-
// memref[0, %off0, %off1, %off2]
489-
xegpu.store_nd %3, %2[%off0, %off1, %off2] : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16>
490-
// memref[%off0, %off1, %off2, %off3]
491-
xegpu.store_nd %3, %2[%off0, %off1, %off2, %off3] : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16>
492-
```
493448

494449
}];
495450

mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp

Lines changed: 34 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -121,22 +121,6 @@ isValidGatherScatterParams(Type maskTy, VectorType valueTy,
121121
return success();
122122
}
123123

124-
// Verify that number of offsets matches either the source rank or the tdesc
125-
// rank.
126-
static LogicalResult
127-
isValidNdOffset(TypedValue<TensorDescType> tDesc,
128-
std::optional<llvm::ArrayRef<int64_t>> constOffsets,
129-
int64_t offsetSize,
130-
function_ref<InFlightDiagnostic()> emitError) {
131-
int64_t constOffsetSize = constOffsets ? constOffsets->size() : 0;
132-
auto tDescRank = tDesc.getType().getRank();
133-
if (((offsetSize != 0) && (offsetSize < tDescRank)) ||
134-
((constOffsetSize != 0) && (constOffsetSize < tDescRank)))
135-
return emitError() << "Offsets rank cannot be smaller than tensor "
136-
"descriptor rank.";
137-
return success();
138-
}
139-
140124
static LogicalResult
141125
isValidGatherScatterBufferParams(Type offsetsTy, Type maskTy,
142126
VectorType valueTy, int64_t chunkSize,
@@ -274,10 +258,8 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
274258
auto [memrefStrides, _] = memrefTy.getStridesAndOffset();
275259

276260
// if shape and strides are from Memref, we don't need attributes for them
277-
// to keep the IR print clean (only do so for full-static case, otherwise
278-
// printer would fail trying to print empty array-attr).
279-
if (staticShape == memrefShape && staticStrides == memrefStrides &&
280-
dynamicShape.empty() && dynamicStrides.empty()) {
261+
// to keep the IR print clean.
262+
if (staticShape == memrefShape && staticStrides == memrefStrides) {
281263
staticShapeAttr = DenseI64ArrayAttr();
282264
staticStridesAttr = DenseI64ArrayAttr();
283265
}
@@ -338,10 +320,8 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
338320
auto [memrefStrides, _] = memrefTy.getStridesAndOffset();
339321

340322
// if shape and strides are from Memref, we don't need attributes for them
341-
// to keep the IR print clean (only do so for full-static case, otherwise
342-
// printer would fail trying to print empty array-attr).
343-
if (staticShape == memrefShape && staticStrides == memrefStrides &&
344-
dynamicShape.empty() && dynamicStrides.empty()) {
323+
// to keep the IR print clean.
324+
if (staticShape == memrefShape && staticStrides == memrefStrides) {
345325
staticShapeAttr = DenseI64ArrayAttr();
346326
staticStridesAttr = DenseI64ArrayAttr();
347327
}
@@ -491,9 +471,16 @@ LogicalResult PrefetchNdOp::verify() {
491471
if (!isReadHintOrNone(getL3HintAttr()))
492472
return emitOpError("invalid l3_hint: ") << getL3HintAttr();
493473

494-
auto tDesc = getTensorDesc();
495-
return isValidNdOffset(tDesc, getConstOffsets(), getMixedOffsets().size(),
496-
[&]() { return emitOpError(); });
474+
int64_t tDescRank = tdescTy.getRank();
475+
int64_t offsetSize = static_cast<int64_t>(getOffsets().size());
476+
int64_t constOffsetSize =
477+
getConstOffsetsAttr() ? getConstOffsetsAttr().size() : 0;
478+
if (((offsetSize != 0) && (offsetSize != tDescRank)) ||
479+
((constOffsetSize != 0) && (constOffsetSize != tDescRank)))
480+
return emitOpError(
481+
"Mismatched ranks between offsets and tensor descriptor");
482+
483+
return success();
497484
}
498485

499486
//===----------------------------------------------------------------------===//
@@ -609,9 +596,16 @@ LogicalResult LoadNdOp::verify() {
609596
<< " is not consistent with tensor descriptor "
610597
<< tdescTy;
611598

612-
auto tDesc = getTensorDesc();
613-
return isValidNdOffset(tDesc, getConstOffsets(), getMixedOffsets().size(),
614-
[&]() { return emitOpError(); });
599+
int64_t tDescRank = tdescTy.getRank();
600+
int64_t offsetSize = static_cast<int64_t>(getOffsets().size());
601+
int64_t constOffsetSize =
602+
getConstOffsetsAttr() ? getConstOffsetsAttr().size() : 0;
603+
if (((offsetSize != 0) && (offsetSize != tDescRank)) ||
604+
((constOffsetSize != 0) && (constOffsetSize != tDescRank)))
605+
return emitOpError(
606+
"Mismatched ranks between offsets and tensor descriptor");
607+
608+
return success();
615609
}
616610

617611
//===----------------------------------------------------------------------===//
@@ -696,9 +690,16 @@ LogicalResult StoreNdOp::verify() {
696690
<< " is not consistent with tensor descriptor "
697691
<< dstTy;
698692

699-
auto tDesc = getTensorDesc();
700-
return isValidNdOffset(tDesc, getConstOffsets(), getMixedOffsets().size(),
701-
[&]() { return emitOpError(); });
693+
int64_t tDescRank = dstTy.getRank();
694+
int64_t offsetSize = static_cast<int64_t>(getOffsets().size());
695+
int64_t constOffsetSize =
696+
getConstOffsetsAttr() ? getConstOffsetsAttr().size() : 0;
697+
if (((offsetSize != 0) && (offsetSize != tDescRank)) ||
698+
((constOffsetSize != 0) && (constOffsetSize != tDescRank)))
699+
return emitOpError(
700+
"Mismatched ranks between offsets and tensor descriptor");
701+
702+
return success();
702703
}
703704

704705
//===----------------------------------------------------------------------===//

mlir/test/Dialect/XeGPU/invalid.mlir

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -132,10 +132,18 @@ func.func @subgroup_load_nd_9(%src: memref<4x8x16xf16>) {
132132
return
133133
}
134134

135+
// -----
136+
func.func @subgroup_load_nd_offset_1(%src: memref<4x8x16xf16>, %x : index) {
137+
%1 = xegpu.create_nd_tdesc %src: memref<4x8x16xf16> -> !xegpu.tensor_desc<16xf16>
138+
// expected-error@+1 {{Mismatched ranks between offsets and tensor descriptor}}
139+
%2 = xegpu.load_nd %1[0, 0] : !xegpu.tensor_desc<16xf16> -> vector<16xf16>
140+
return
141+
}
142+
135143
// -----
136144
func.func @subgroup_load_nd_offset_2(%src: memref<4x8x16xf16>, %x : index) {
137145
%3 = xegpu.create_nd_tdesc %src: memref<4x8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
138-
// expected-error@+1 {{Offsets rank cannot be smaller than tensor descriptor rank.}}
146+
// expected-error@+1 {{Mismatched ranks between offsets and tensor descriptor}}
139147
xegpu.prefetch_nd %3[0] : !xegpu.tensor_desc<8x16xf16>
140148
return
141149
}
@@ -144,7 +152,7 @@ func.func @subgroup_load_nd_offset_2(%src: memref<4x8x16xf16>, %x : index) {
144152
func.func @subgroup_load_nd_offset_3(%src: memref<4x8x16xf16>, %x : index) {
145153
%3 = xegpu.create_nd_tdesc %src: memref<4x8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
146154
%5 = xegpu.load_nd %3[0, 0] : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
147-
// expected-error@+1 {{Offsets rank cannot be smaller than tensor descriptor rank.}}
155+
// expected-error@+1 {{Mismatched ranks between offsets and tensor descriptor}}
148156
xegpu.store_nd %5, %3[%x] : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16>
149157
return
150158
}

0 commit comments

Comments
 (0)