Skip to content

Commit 90944b8

Browse files
authored
[MLIR][XeGPU] Add offset operands to load_nd/store_nd/prefetch_nd (#149424)
This PR allows load_nd/store_nd/prefetch_nd to take an additional offset operand. It is based on PR #148335. Now users can create an nd_tdesc with no offset, and instead set the offset with the load_nd operation.
1 parent 01b23c8 commit 90944b8

File tree

7 files changed

+255
-21
lines changed

7 files changed

+255
-21
lines changed

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td

Lines changed: 63 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -29,9 +29,22 @@ class XeGPU_Op<string mnemonic, list<Trait> traits = []>:
2929
void printProperties(::mlir::MLIRContext *ctx,
3030
::mlir::OpAsmPrinter &p, const Properties &prop,
3131
::mlir::ArrayRef<::llvm::StringRef> elidedProps) {
32-
Attribute propAttr = getPropertiesAsAttr(ctx, prop);
33-
if (propAttr)
34-
p << "<" << propAttr << ">";
32+
33+
DictionaryAttr propAttr = dyn_cast_if_present<mlir::DictionaryAttr>(getPropertiesAsAttr(ctx, prop));
34+
35+
// filter out the elidedProps from propAttr, and get the resultAttr
36+
mlir::SmallVector<mlir::NamedAttribute> filteredAttrs;
37+
if (propAttr) {
38+
for (auto namedAttr : propAttr.getValue()) {
39+
if (llvm::is_contained(elidedProps, namedAttr.getName().strref()))
40+
continue;
41+
filteredAttrs.push_back(namedAttr);
42+
}
43+
}
44+
45+
if (!filteredAttrs.empty()) {
46+
p << "<" << DictionaryAttr::get(ctx, filteredAttrs) << ">";
47+
}
3548
}
3649

3750
static ::mlir::ParseResult parseProperties(::mlir::OpAsmParser &parser,
@@ -288,6 +301,8 @@ def XeGPU_PrefetchNdOp : XeGPU_Op<"prefetch_nd", []> {
288301
}];
289302

290303
let arguments = (ins XeGPU_TensorDesc: $TensorDesc,
304+
Variadic<Index>: $offsets,
305+
OptionalAttr<DenseI64ArrayAttr>: $const_offsets,
291306
OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
292307
OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
293308
OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
@@ -298,7 +313,18 @@ def XeGPU_PrefetchNdOp : XeGPU_Op<"prefetch_nd", []> {
298313
}
299314
}];
300315

301-
let assemblyFormat = "$TensorDesc prop-dict attr-dict `:` qualified(type($TensorDesc))";
316+
let assemblyFormat = [{
317+
$TensorDesc ``
318+
custom<OptionalDynamicIndexList>($offsets, $const_offsets)
319+
prop-dict attr-dict `:` qualified(type($TensorDesc))
320+
}];
321+
322+
let builders = [
323+
OpBuilder<(ins "Value": $TensorDesc,
324+
"xegpu::CachePolicyAttr": $l1_hint,
325+
"xegpu::CachePolicyAttr": $l2_hint,
326+
"xegpu::CachePolicyAttr": $l3_hint)>
327+
];
302328

303329
let hasVerifier = 1;
304330
}
@@ -343,6 +369,8 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [
343369
}];
344370

345371
let arguments = (ins XeGPU_TensorDesc: $TensorDesc,
372+
Variadic<Index>: $offsets,
373+
OptionalAttr<DenseI64ArrayAttr>: $const_offsets,
346374
OptionalAttr<UnitAttr>: $packed,
347375
OptionalAttr<DenseI64ArrayAttr>: $transpose,
348376
OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
@@ -361,7 +389,20 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [
361389
}
362390
}];
363391

364-
let assemblyFormat = "$TensorDesc prop-dict attr-dict `:` qualified(type($TensorDesc)) `->` type($value)";
392+
let assemblyFormat = [{
393+
$TensorDesc ``
394+
custom<OptionalDynamicIndexList>($offsets, $const_offsets)
395+
prop-dict attr-dict `:` qualified(type($TensorDesc)) `->` type($value)
396+
}];
397+
398+
let builders = [
399+
OpBuilder<(ins "Type": $value, "Value": $TensorDesc,
400+
"UnitAttr": $packed, "DenseI64ArrayAttr": $transpose,
401+
"xegpu::CachePolicyAttr": $l1_hint,
402+
"xegpu::CachePolicyAttr": $l2_hint,
403+
"xegpu::CachePolicyAttr": $l3_hint)>
404+
];
405+
365406
let hasVerifier = 1;
366407
}
367408

@@ -400,6 +441,8 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [
400441

401442
let arguments = (ins XeGPU_ValueType: $value,
402443
XeGPU_TensorDesc: $TensorDesc,
444+
Variadic<Index>: $offsets,
445+
OptionalAttr<DenseI64ArrayAttr>: $const_offsets,
403446
OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
404447
OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
405448
OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
@@ -414,8 +457,21 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [
414457
}
415458
}];
416459

417-
let assemblyFormat = [{$value `,` $TensorDesc prop-dict attr-dict
418-
`:` type($value) `,` qualified(type($TensorDesc))}];
460+
let assemblyFormat = [{
461+
$value `,`
462+
$TensorDesc ``
463+
custom<OptionalDynamicIndexList>($offsets, $const_offsets)
464+
prop-dict attr-dict `:` type($value) `,` qualified(type($TensorDesc))
465+
}];
466+
467+
let builders = [
468+
OpBuilder<(ins "Value": $value, "Value": $TensorDesc,
469+
"xegpu::CachePolicyAttr": $l1_hint,
470+
"xegpu::CachePolicyAttr": $l2_hint,
471+
"xegpu::CachePolicyAttr": $l3_hint)>
472+
];
473+
474+
419475
let hasVerifier = 1;
420476
}
421477

mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp

Lines changed: 68 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -329,18 +329,30 @@ ParseResult parseOptionalDynamicIndexList(
329329
return success();
330330
}
331331

332-
void printOptionalDynamicIndexList(
333-
OpAsmPrinter &printer, Operation *op, OperandRange values,
334-
ArrayRef<int64_t> integers, TypeRange valueTypes = TypeRange(),
335-
AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) {
332+
void printOptionalDynamicIndexList(OpAsmPrinter &printer, Operation *op,
333+
OperandRange values,
334+
DenseI64ArrayAttr integers) {
335+
336+
if (!integers)
337+
return;
336338

337339
return printDynamicIndexList(printer, op, values, integers,
338-
/*scalableFlags=*/{}, valueTypes, delimiter);
340+
/*scalableFlags=*/{}, {},
341+
AsmParser::Delimiter::Square);
339342
}
340-
341343
//===----------------------------------------------------------------------===//
342344
// XeGPU_PrefetchNdOp
343345
//===----------------------------------------------------------------------===//
346+
347+
void PrefetchNdOp::build(OpBuilder &builder, OperationState &state,
348+
Value tensorDesc, xegpu::CachePolicyAttr l1_hint,
349+
xegpu::CachePolicyAttr l2_hint,
350+
xegpu::CachePolicyAttr l3_hint) {
351+
352+
return build(builder, state, tensorDesc, ValueRange(), DenseI64ArrayAttr(),
353+
l1_hint, l2_hint, l3_hint);
354+
}
355+
344356
LogicalResult PrefetchNdOp::verify() {
345357
auto tdescTy = getTensorDescType();
346358
if (tdescTy.isScattered())
@@ -355,12 +367,34 @@ LogicalResult PrefetchNdOp::verify() {
355367
if (!isReadHintOrNone(getL3HintAttr()))
356368
return emitOpError("invalid l3_hint: ") << getL3HintAttr();
357369

370+
int64_t tDescRank = tdescTy.getRank();
371+
int64_t offsetSize = static_cast<int64_t>(getOffsets().size());
372+
int64_t constOffsetSize =
373+
getConstOffsetsAttr() ? getConstOffsetsAttr().size() : 0;
374+
if (((offsetSize != 0) && (offsetSize != tDescRank)) ||
375+
((constOffsetSize != 0) && (constOffsetSize != tDescRank)))
376+
return emitOpError(
377+
"Mismatched ranks between offsets and tensor descriptor");
378+
358379
return success();
359380
}
360381

361382
//===----------------------------------------------------------------------===//
362383
// XeGPU_LoadNdOp
363384
//===----------------------------------------------------------------------===//
385+
386+
void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType,
387+
Value tensorDesc, UnitAttr packed,
388+
DenseI64ArrayAttr transpose,
389+
xegpu::CachePolicyAttr l1_hint,
390+
xegpu::CachePolicyAttr l2_hint,
391+
xegpu::CachePolicyAttr l3_hint) {
392+
393+
return build(builder, state, retType, tensorDesc, ValueRange(),
394+
DenseI64ArrayAttr(), packed, transpose, l1_hint, l2_hint,
395+
l3_hint);
396+
}
397+
364398
LogicalResult LoadNdOp::verify() {
365399
auto tdescTy = getTensorDescType();
366400
auto valueTy = getType();
@@ -442,12 +476,31 @@ LogicalResult LoadNdOp::verify() {
442476
<< " is not consistent with tensor descriptor "
443477
<< tdescTy;
444478

479+
int64_t tDescRank = tdescTy.getRank();
480+
int64_t offsetSize = static_cast<int64_t>(getOffsets().size());
481+
int64_t constOffsetSize =
482+
getConstOffsetsAttr() ? getConstOffsetsAttr().size() : 0;
483+
if (((offsetSize != 0) && (offsetSize != tDescRank)) ||
484+
((constOffsetSize != 0) && (constOffsetSize != tDescRank)))
485+
return emitOpError(
486+
"Mismatched ranks between offsets and tensor descriptor");
487+
445488
return success();
446489
}
447490

448491
//===----------------------------------------------------------------------===//
449492
// XeGPU_StoreNdOp
450493
//===----------------------------------------------------------------------===//
494+
495+
void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value,
496+
Value tensorDesc, xegpu::CachePolicyAttr l1_hint,
497+
xegpu::CachePolicyAttr l2_hint,
498+
xegpu::CachePolicyAttr l3_hint) {
499+
500+
return build(builder, state, value, tensorDesc, ValueRange(),
501+
DenseI64ArrayAttr(), l1_hint, l2_hint, l3_hint);
502+
}
503+
451504
LogicalResult StoreNdOp::verify() {
452505
auto dstTy = getTensorDescType(); // Tile
453506
auto valTy = getValueType(); // Vector
@@ -502,6 +555,15 @@ LogicalResult StoreNdOp::verify() {
502555
<< " is not consistent with tensor descriptor "
503556
<< dstTy;
504557

558+
int64_t tDescRank = dstTy.getRank();
559+
int64_t offsetSize = static_cast<int64_t>(getOffsets().size());
560+
int64_t constOffsetSize =
561+
getConstOffsetsAttr() ? getConstOffsetsAttr().size() : 0;
562+
if (((offsetSize != 0) && (offsetSize != tDescRank)) ||
563+
((constOffsetSize != 0) && (constOffsetSize != tDescRank)))
564+
return emitOpError(
565+
"Mismatched ranks between offsets and tensor descriptor");
566+
505567
return success();
506568
}
507569

mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp

Lines changed: 14 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -352,6 +352,10 @@ struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
352352
if (!storeOp)
353353
return failure();
354354

355+
int64_t offsetSize = static_cast<int64_t>(storeOp.getOffsets().size());
356+
if ((offsetSize != 0) || storeOp.getConstOffsetsAttr())
357+
return failure();
358+
355359
xegpu::TensorDescType tensorDescTy = storeOp.getTensorDescType();
356360
xegpu::LayoutAttr layout = tensorDescTy.getLayoutAttr();
357361
if (!layout)
@@ -464,6 +468,11 @@ struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
464468
warpOp, "warp result is not a xegpu::LoadNd op");
465469

466470
auto loadOp = operand->get().getDefiningOp<xegpu::LoadNdOp>();
471+
472+
int64_t offsetSize = static_cast<int64_t>(loadOp.getOffsets().size());
473+
if ((offsetSize != 0) || loadOp.getConstOffsetsAttr())
474+
return failure();
475+
467476
xegpu::TensorDescType tensorDescTy = loadOp.getTensorDescType();
468477
xegpu::LayoutAttr layout = tensorDescTy.getLayoutAttr();
469478
if (!layout)
@@ -767,6 +776,11 @@ struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
767776
auto prefetchOp = dyn_cast_or_null<xegpu::PrefetchNdOp>(lastNode);
768777
if (!prefetchOp)
769778
return failure();
779+
780+
int64_t offsetSize = static_cast<int64_t>(prefetchOp.getOffsets().size());
781+
if ((offsetSize != 0) || prefetchOp.getConstOffsetsAttr())
782+
return failure();
783+
770784
xegpu::LayoutAttr layout = prefetchOp.getTensorDescType().getLayoutAttr();
771785
if (!layout)
772786
return rewriter.notifyMatchFailure(

mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -218,6 +218,10 @@ struct UnrollPrefetchNdOp : public UnrollPattern<xegpu::PrefetchNdOp> {
218218
if (!targetShape)
219219
return failure();
220220

221+
int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
222+
if ((offsetSize != 0) || op.getConstOffsetsAttr())
223+
return failure();
224+
221225
SmallVector<Type> convertedTdescTypes =
222226
getUnrolledTypes(tdescTy, *targetShape);
223227
SmallVector<Value> convertedTdesc = pack(
@@ -245,6 +249,10 @@ struct UnrollLoadNdOp : public UnrollPattern<xegpu::LoadNdOp> {
245249
if (!targetShape)
246250
return failure();
247251

252+
int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
253+
if ((offsetSize != 0) || op.getConstOffsetsAttr())
254+
return failure();
255+
248256
Type elemTy = tdescTy.getElementType();
249257
VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);
250258

@@ -279,6 +287,10 @@ struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> {
279287
if (!targetShape)
280288
return failure();
281289

290+
int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
291+
if ((offsetSize != 0) || op.getConstOffsetsAttr())
292+
return failure();
293+
282294
SmallVector<Type> convertedValTypes =
283295
getUnrolledTypes(valueTy, *targetShape);
284296
SmallVector<Type> convertedTdescTypes =

mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Lines changed: 15 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -219,6 +219,11 @@ struct WgToSgLoadNdOp : public OpConversionPattern<xegpu::LoadNdOp> {
219219
matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor,
220220
ConversionPatternRewriter &rewriter) const override {
221221
SmallVector<Value> newLoadOps;
222+
223+
int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
224+
if ((offsetSize != 0) || op.getConstOffsetsAttr())
225+
return failure();
226+
222227
for (auto src : adaptor.getTensorDesc()) {
223228
xegpu::TensorDescType tdescTy =
224229
dyn_cast<xegpu::TensorDescType>(src.getType());
@@ -241,6 +246,11 @@ struct WgToSgStoreNdOp : public OpConversionPattern<xegpu::StoreNdOp> {
241246
LogicalResult
242247
matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor,
243248
ConversionPatternRewriter &rewriter) const override {
249+
250+
int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
251+
if ((offsetSize != 0) || op.getConstOffsetsAttr())
252+
return failure();
253+
244254
for (auto [v, t] : llvm::zip(adaptor.getValue(), adaptor.getTensorDesc()))
245255
xegpu::StoreNdOp::create(rewriter, op.getLoc(), v, t, op.getL1HintAttr(),
246256
op.getL2HintAttr(), op.getL3HintAttr());
@@ -323,6 +333,11 @@ struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
323333
LogicalResult
324334
matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor,
325335
ConversionPatternRewriter &rewriter) const override {
336+
337+
int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
338+
if ((offsetSize != 0) || op.getConstOffsetsAttr())
339+
return failure();
340+
326341
for (auto src : adaptor.getTensorDesc())
327342
xegpu::PrefetchNdOp::create(rewriter, op.getLoc(), TypeRange(), src,
328343
op->getAttrs());

mlir/test/Dialect/XeGPU/invalid.mlir

Lines changed: 25 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -132,6 +132,31 @@ func.func @subgroup_load_nd_9(%src: memref<4x8x16xf16>) {
132132
return
133133
}
134134

135+
// -----
136+
func.func @subgroup_load_nd_offset_1(%src: memref<4x8x16xf16>, %x : index) {
137+
%1 = xegpu.create_nd_tdesc %src: memref<4x8x16xf16> -> !xegpu.tensor_desc<16xf16>
138+
// expected-error@+1 {{Mismatched ranks between offsets and tensor descriptor}}
139+
%2 = xegpu.load_nd %1[0, 0] : !xegpu.tensor_desc<16xf16> -> vector<16xf16>
140+
return
141+
}
142+
143+
// -----
144+
func.func @subgroup_load_nd_offset_2(%src: memref<4x8x16xf16>, %x : index) {
145+
%3 = xegpu.create_nd_tdesc %src: memref<4x8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
146+
// expected-error@+1 {{Mismatched ranks between offsets and tensor descriptor}}
147+
xegpu.prefetch_nd %3[0] : !xegpu.tensor_desc<8x16xf16>
148+
return
149+
}
150+
151+
// -----
152+
func.func @subgroup_load_nd_offset_3(%src: memref<4x8x16xf16>, %x : index) {
153+
%3 = xegpu.create_nd_tdesc %src: memref<4x8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
154+
%5 = xegpu.load_nd %3[0, 0] : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
155+
// expected-error@+1 {{Mismatched ranks between offsets and tensor descriptor}}
156+
xegpu.store_nd %5, %3[%x] : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16>
157+
return
158+
}
159+
135160
// -----
136161
func.func @load_nd_layout(%src: memref<24x32xf32>) {
137162
%1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<16xf32>

0 commit comments

Comments (0)