
Commit 1373ffa

add optional offsets to nd load/store/prefetch

1 parent 64c7e7e
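In MLIR assembly, this lets the nd ops take their offsets directly, rather than the offsets living only on the tensor descriptor (via create_nd_tdesc / update_nd_offset). A minimal before/after sketch, adapted from the tests in this commit (%src is a placeholder value):

  %t = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
  // Before: offsets fixed at descriptor creation.
  %v0 = xegpu.load_nd %t <{l1_hint = #xegpu.cache_hint<cached>}> : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
  // After: optional offsets supplied on the op itself.
  %v1 = xegpu.load_nd %t[8, 16] <{l1_hint = #xegpu.cache_hint<cached>}> : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>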

File tree: 3 files changed, +140 -12 lines changed

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td

Lines changed: 63 additions & 7 deletions
@@ -29,9 +29,22 @@ class XeGPU_Op<string mnemonic, list<Trait> traits = []>:
     void printProperties(::mlir::MLIRContext *ctx,
                          ::mlir::OpAsmPrinter &p, const Properties &prop,
                          ::mlir::ArrayRef<::llvm::StringRef> elidedProps) {
-      Attribute propAttr = getPropertiesAsAttr(ctx, prop);
-      if (propAttr)
-        p << "<" << propAttr << ">";
+
+      DictionaryAttr propAttr = dyn_cast_if_present<mlir::DictionaryAttr>(getPropertiesAsAttr(ctx, prop));
+
+      // Filter the elided properties out of propAttr and print only the rest.
+      mlir::SmallVector<mlir::NamedAttribute> filteredAttrs;
+      if (propAttr) {
+        for (auto namedAttr : propAttr.getValue()) {
+          if (llvm::is_contained(elidedProps, namedAttr.getName().strref()))
+            continue;
+          filteredAttrs.push_back(namedAttr);
+        }
+      }
+
+      if (!filteredAttrs.empty()) {
+        p << "<" << DictionaryAttr::get(ctx, filteredAttrs) << ">";
+      }
     }
 
     static ::mlir::ParseResult parseProperties(::mlir::OpAsmParser &parser,
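The printer rewrite above is needed because $offsets/$const_offsets are printed by a custom directive ahead of prop-dict, so const_offsets reaches printProperties as an elided property; printing the raw dictionary would emit it twice, and an all-elided dictionary should print nothing. A hypothetical sketch of the difference:

  // Unfiltered (old printer, hypothetical): const_offsets would appear twice.
  xegpu.prefetch_nd %0[0, 0] <{const_offsets = array<i64: 0, 0>, l1_hint = #xegpu.cache_hint<cached>}> : !xegpu.tensor_desc<8x16xf16>
  // Filtered (new printer): only non-elided properties remain in prop-dict.
  xegpu.prefetch_nd %0[0, 0] <{l1_hint = #xegpu.cache_hint<cached>}> : !xegpu.tensor_desc<8x16xf16>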
@@ -288,6 +301,8 @@ def XeGPU_PrefetchNdOp : XeGPU_Op<"prefetch_nd", []> {
   }];
 
   let arguments = (ins XeGPU_TensorDesc: $TensorDesc,
+                       Variadic<Index>: $offsets,
+                       OptionalAttr<DenseI64ArrayAttr>: $const_offsets,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
@@ -298,7 +313,18 @@ def XeGPU_PrefetchNdOp : XeGPU_Op<"prefetch_nd", []> {
     }
   }];
 
-  let assemblyFormat = "$TensorDesc prop-dict attr-dict `:` qualified(type($TensorDesc))";
+  let assemblyFormat = [{
+    $TensorDesc ``
+    custom<OptionalDynamicIndexList>($offsets, $const_offsets)
+    prop-dict attr-dict `:` qualified(type($TensorDesc))
+  }];
+
+  let builders = [
+    OpBuilder<(ins "Value": $TensorDesc,
+                   "xegpu::CachePolicyAttr": $l1_hint,
+                   "xegpu::CachePolicyAttr": $l2_hint,
+                   "xegpu::CachePolicyAttr": $l3_hint)>
+  ];
 
   let hasVerifier = 1;
 }
@@ -343,6 +369,8 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [
   }];
 
   let arguments = (ins XeGPU_TensorDesc: $TensorDesc,
+                       Variadic<Index>: $offsets,
+                       OptionalAttr<DenseI64ArrayAttr>: $const_offsets,
                        OptionalAttr<UnitAttr>: $packed,
                        OptionalAttr<DenseI64ArrayAttr>: $transpose,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
@@ -361,7 +389,20 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [
     }
   }];
 
-  let assemblyFormat = "$TensorDesc prop-dict attr-dict `:` qualified(type($TensorDesc)) `->` type($value)";
+  let assemblyFormat = [{
+    $TensorDesc ``
+    custom<OptionalDynamicIndexList>($offsets, $const_offsets)
+    prop-dict attr-dict `:` qualified(type($TensorDesc)) `->` type($value)
+  }];
+
+  let builders = [
+    OpBuilder<(ins "Type": $value, "Value": $TensorDesc,
+                   "UnitAttr": $packed, "DenseI64ArrayAttr": $transpose,
+                   "xegpu::CachePolicyAttr": $l1_hint,
+                   "xegpu::CachePolicyAttr": $l2_hint,
+                   "xegpu::CachePolicyAttr": $l3_hint)>
+  ];
 
   let hasVerifier = 1;
 }

@@ -400,6 +441,8 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [
 
   let arguments = (ins XeGPU_ValueType: $value,
                        XeGPU_TensorDesc: $TensorDesc,
+                       Variadic<Index>: $offsets,
+                       OptionalAttr<DenseI64ArrayAttr>: $const_offsets,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
@@ -414,8 +457,21 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [
     }
   }];
 
-  let assemblyFormat = [{$value `,` $TensorDesc prop-dict attr-dict
-                        `:` type($value) `,` qualified(type($TensorDesc))}];
+  let assemblyFormat = [{
+    $value `,`
+    $TensorDesc ``
+    custom<OptionalDynamicIndexList>($offsets, $const_offsets)
+    prop-dict attr-dict `:` type($value) `,` qualified(type($TensorDesc))
+  }];
+
+  let builders = [
+    OpBuilder<(ins "Value": $value, "Value": $TensorDesc,
+                   "xegpu::CachePolicyAttr": $l1_hint,
+                   "xegpu::CachePolicyAttr": $l2_hint,
+                   "xegpu::CachePolicyAttr": $l3_hint)>
+  ];
 
   let hasVerifier = 1;
 }

mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp

Lines changed: 27 additions & 5 deletions
@@ -331,16 +331,24 @@ ParseResult parseOptionalDynamicIndexList(
 
 void printOptionalDynamicIndexList(
     OpAsmPrinter &printer, Operation *op, OperandRange values,
-    ArrayRef<int64_t> integers, TypeRange valueTypes = TypeRange(),
-    AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) {
+    DenseI64ArrayAttr integers) {
 
-  return printDynamicIndexList(printer, op, values, integers,
-                               /*scalableFlags=*/{}, valueTypes, delimiter);
-}
+  if (!integers)
+    return;
 
+  return printDynamicIndexList(printer, op, values, integers,
+                               /*scalableFlags=*/{}, {},
+                               AsmParser::Delimiter::Square);
+}
 //===----------------------------------------------------------------------===//
 // XeGPU_PrefetchNdOp
 //===----------------------------------------------------------------------===//
+
+void PrefetchNdOp::build(OpBuilder &builder, OperationState &state,
+                         Value tensorDesc, xegpu::CachePolicyAttr l1_hint,
+                         xegpu::CachePolicyAttr l2_hint,
+                         xegpu::CachePolicyAttr l3_hint) {
+  return build(builder, state, tensorDesc, ValueRange(), DenseI64ArrayAttr(),
+               l1_hint, l2_hint, l3_hint);
+}
+
 LogicalResult PrefetchNdOp::verify() {
   auto tdescTy = getTensorDescType();
   if (tdescTy.isScattered())
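Because the DenseI64ArrayAttr overload above returns early when const_offsets is absent, the offset list is genuinely optional and both spellings round-trip. A sketch (values and types are placeholders):

  xegpu.prefetch_nd %t : !xegpu.tensor_desc<8x16xf16>           // no offsets, old form
  xegpu.prefetch_nd %t[0, 8] : !xegpu.tensor_desc<8x16xf16>     // constant offsets
  xegpu.prefetch_nd %t[%i, 8] : !xegpu.tensor_desc<8x16xf16>    // mixed SSA/constant offsets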
@@ -361,6 +369,13 @@ LogicalResult PrefetchNdOp::verify() {
 //===----------------------------------------------------------------------===//
 // XeGPU_LoadNdOp
 //===----------------------------------------------------------------------===//
+
+void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType,
+                     Value tensorDesc, UnitAttr packed,
+                     DenseI64ArrayAttr transpose,
+                     xegpu::CachePolicyAttr l1_hint,
+                     xegpu::CachePolicyAttr l2_hint,
+                     xegpu::CachePolicyAttr l3_hint) {
+  return build(builder, state, retType, tensorDesc, ValueRange(),
+               DenseI64ArrayAttr(), packed, transpose, l1_hint, l2_hint,
+               l3_hint);
+}
+
 LogicalResult LoadNdOp::verify() {
   auto tdescTy = getTensorDescType();
   auto valueTy = getType();
@@ -448,6 +463,13 @@ LogicalResult LoadNdOp::verify() {
 //===----------------------------------------------------------------------===//
 // XeGPU_StoreNdOp
 //===----------------------------------------------------------------------===//
+
+void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value,
+                      Value tensorDesc, xegpu::CachePolicyAttr l1_hint,
+                      xegpu::CachePolicyAttr l2_hint,
+                      xegpu::CachePolicyAttr l3_hint) {
+  return build(builder, state, value, tensorDesc, ValueRange(),
+               DenseI64ArrayAttr(), l1_hint, l2_hint, l3_hint);
+}
+
 LogicalResult StoreNdOp::verify() {
   auto dstTy = getTensorDescType(); // Tile
   auto valTy = getValueType();      // Vector
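The three build overloads above keep pre-existing C++ call sites working by forwarding empty offsets (ValueRange() and a null DenseI64ArrayAttr). A minimal usage sketch, assuming an OpBuilder b, a Location loc, and placeholder values tdesc and vecTy in scope; the enum-attr getter shown is the standard generated form:

  // Hints are optional; null attributes are allowed for any of them.
  auto cached = xegpu::CachePolicyAttr::get(b.getContext(), xegpu::CachePolicy::CACHED);
  b.create<xegpu::PrefetchNdOp>(loc, tdesc, cached, cached, cached);
  auto ld = b.create<xegpu::LoadNdOp>(loc, vecTy, tdesc, /*packed=*/UnitAttr(),
                                      /*transpose=*/DenseI64ArrayAttr(),
                                      cached, cached, cached);
  b.create<xegpu::StoreNdOp>(loc, ld.getResult(), tdesc, cached, cached, cached);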

mlir/test/Dialect/XeGPU/ops.mlir

Lines changed: 50 additions & 0 deletions
@@ -121,6 +121,15 @@ gpu.func @prefetch_nd_2(%src: memref<8x24x32x48x64xf16>) {
   gpu.return
 }
 
+// CHECK: gpu.func @prefetch_nd_offset_1(%[[arg0:.*]]: memref<8x24x32x48x64xf16>) {
+gpu.func @prefetch_nd_offset_1(%src: memref<8x24x32x48x64xf16>) {
+  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0, 0, 0, 0] : memref<8x24x32x48x64xf16> -> !xegpu.tensor_desc<1x2x4x8x16xf16>
+  %1 = xegpu.create_nd_tdesc %src[0, 0, 0, 0, 0] : memref<8x24x32x48x64xf16> -> !xegpu.tensor_desc<1x2x4x8x16xf16>
+  // CHECK: xegpu.prefetch_nd %[[R0]][0, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<1x2x4x8x16xf16>
+  xegpu.prefetch_nd %1[0, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>: !xegpu.tensor_desc<1x2x4x8x16xf16>
+  gpu.return
+}
+
 // CHECK: func @subgroup_load_nd(%[[arg0:.*]]: memref<8x16xf16>) {
 gpu.func @subgroup_load_nd(%src: memref<8x16xf16>) {
   // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
@@ -260,6 +269,15 @@ gpu.func @subgroup_load_nd_8(%src: memref<24x32xf32>) {
   gpu.return
 }
 
+// CHECK: func @subgroup_load_nd_offset_1(%[[arg0:.*]]: memref<24x32xf32>) {
+gpu.func @subgroup_load_nd_offset_1(%src: memref<24x32xf32>) {
+  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<16x8xf32>
+  %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<16x8xf32>
+  // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]][0, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<16x8xf32> -> vector<8x16xf32>
+  %2 = xegpu.load_nd %1[0, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<16x8xf32> -> vector<8x16xf32>
+  gpu.return
+}
+
 // CHECK: func @simt_load_nd_8(%[[arg0:.*]]: memref<24x32xf32>) {
 gpu.func @simt_load_nd_8(%src: memref<24x32xf32>) {
   // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<16x8xf32>
@@ -269,6 +287,16 @@ gpu.func @simt_load_nd_8(%src: memref<24x32xf32>) {
   gpu.return
 }
 
+
+// CHECK: func @simt_load_nd_offset_1(%[[arg0:.*]]: memref<24x32xf32>) {
+gpu.func @simt_load_nd_offset_1(%src: memref<24x32xf32>) {
+  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<16x8xf32>
+  %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<16x8xf32>
+  // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]][0, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<16x8xf32> -> vector<8xf32>
+  %2 = xegpu.load_nd %1[0, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<16x8xf32> -> vector<8xf32>
+  gpu.return
+}
+
 // CHECK: func @subgroup_store_nd(%[[arg0:.*]]: memref<24x32xf16>) {
 gpu.func @subgroup_store_nd(%dst: memref<24x32xf16>) {
   // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<24x32xf16>
@@ -293,6 +321,17 @@ gpu.func @simt_store_nd(%src: memref<24x32xf16>) {
 
 // CHECK: func @subgroup_store_nd_2(%[[arg0:.*]]: memref<24x32xf16>) {
 gpu.func @subgroup_store_nd_2(%dst: memref<24x32xf16>) {
+  // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<32xf16>
+  %1 = arith.constant dense<1.0>: vector<32xf16>
+  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16>
+  %2 = xegpu.create_nd_tdesc %dst[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16>
+  // CHECK: xegpu.store_nd %[[C]], %[[R0]][0] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<32xf16>, !xegpu.tensor_desc<32xf16>
+  xegpu.store_nd %1, %2[0] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>: vector<32xf16>, !xegpu.tensor_desc<32xf16>
+  gpu.return
+}
+
+// CHECK: func @subgroup_store_nd_offset_1(%[[arg0:.*]]: memref<24x32xf16>) {
+gpu.func @subgroup_store_nd_offset_1(%dst: memref<24x32xf16>) {
   // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<32xf16>
   %1 = arith.constant dense<1.0>: vector<32xf16>
   // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16>
@@ -313,6 +352,17 @@ gpu.func @simt_store_nd_2(%src: memref<24x32xf16>) {
   gpu.return
 }
 
+// CHECK: func @simt_store_nd_offset_1(%[[arg0:.*]]: memref<24x32xf16>) {
+gpu.func @simt_store_nd_offset_1(%src: memref<24x32xf16>) {
+  // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<2xf16>
+  %1 = arith.constant dense<1.0>: vector<2xf16>
+  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16>
+  %2 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16>
+  // CHECK: xegpu.store_nd %[[C]], %[[R0]][0] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<2xf16>, !xegpu.tensor_desc<32xf16>
+  xegpu.store_nd %1, %2[0] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>: vector<2xf16>, !xegpu.tensor_desc<32xf16>
+  gpu.return
+}
+
 // CHECK: gpu.func @update_nd_tdesc(%[[arg0:.*]]: memref<24x32xf32>) {
 gpu.func @update_nd_tdesc(%src: memref<24x32xf32>) {
   // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
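These additions follow the file's existing round-trip pattern: lit feeds the file through mlir-opt and matches the re-printed output with FileCheck. The RUN line sits at the top of ops.mlir and is not part of this diff; it is typically of the form:

  // RUN: mlir-opt %s | FileCheck %s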
