Skip to content

Commit 22867ff

Browse files
authored
Add cache attributes to xetile.prefetch_tile (#749)
1 parent 1762be9 commit 22867ff

File tree

4 files changed

+89
-8
lines changed

4 files changed

+89
-8
lines changed

include/imex/Dialect/XeTile/IR/XeTileAttrs.td

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,4 +108,29 @@ def XeTile_AtomicRMWKindAttr : I64EnumAttr<
108108
let cppNamespace = "::imex::xetile";
109109
}
110110

111+
//===----------------------------------------------------------------------===//
112+
// XeTile Cache Enums.
113+
//===----------------------------------------------------------------------===//
114+
// Individual cache-policy cases. Each case gives the C++ enumerator name, its
// integer value, and the keyword used in the textual IR. The trailing comment
// on each case records which kind of memory access the policy is meant for.
def XeTile_CachePolicyCached: I32EnumAttrCase<"CACHED", 0, "cached">; // valid for read and write
115+
def XeTile_CachePolicyUncached: I32EnumAttrCase<"UNCACHED", 1, "uncached">; // valid for read and write
116+
def XeTile_CachePolicyStreaming: I32EnumAttrCase<"STREAMING", 2, "streaming">; // valid for read only
117+
def XeTile_CachePolicyInvalid: I32EnumAttrCase<"READ_INVALIDATE", 3, "read_invalidate">; // valid for read only
118+
def XeTile_CachePolicyWriteBack: I32EnumAttrCase<"WRITE_BACK", 4, "write_back">; // valid for write only
119+
def XeTile_CachePolicyWriteThrough: I32EnumAttrCase<"WRITE_THROUGH", 5, "write_through">; // valid for write only
120+
121+
// 32-bit enum `CachePolicy` in namespace ::imex::xetile, collecting the six
// cases listed above.
def XeTile_CachePolicyEnums : I32EnumAttr<"CachePolicy", "Cache policy",
122+
[XeTile_CachePolicyCached, XeTile_CachePolicyUncached,
123+
XeTile_CachePolicyStreaming, XeTile_CachePolicyInvalid,
124+
XeTile_CachePolicyWriteBack, XeTile_CachePolicyWriteThrough]> {
125+
// Do not generate a specialized attribute class from this enum definition
// itself; only the C++ enum and its helpers are emitted here.
let genSpecializedAttr = 0;
126+
let cppNamespace = "::imex::xetile";
127+
}
128+
129+
// XeTile dialect attribute that carries a CachePolicy enum value under the
// mnemonic `cache_hint`.
def XeTile_CacheHintAttr
130+
: EnumAttr<XeTile_Dialect, XeTile_CachePolicyEnums, "cache_hint"> {
131+
let summary = [{Describe the cache settings for prefetch/load/store operators}];
132+
// Textual form is the enum keyword wrapped in angle brackets,
// e.g. #xetile.cache_hint<uncached>.
let assemblyFormat = "`<` $value `>`";
133+
}
134+
135+
111136
#endif // _XETILE_ATTR_DEF_TD_INCLUDED_

include/imex/Dialect/XeTile/IR/XeTileOps.td

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -340,10 +340,13 @@ def XeTile_PrefetchTileOp : XeTile_Op<"prefetch_tile", []> {
340340

341341
}];
342342

343-
let arguments = (ins XeTile:$tile);
343+
let arguments = (ins XeTile:$tile,
344+
OptionalAttr<XeTile_CacheHintAttr>: $l1_hint,
345+
OptionalAttr<XeTile_CacheHintAttr>: $l2_hint,
346+
OptionalAttr<XeTile_CacheHintAttr>: $l3_hint);
344347

345348
let assemblyFormat = [{
346-
$tile attr-dict `:` qualified(type($tile))
349+
$tile attr-dict `:` qualified(type($tile))
347350
}];
348351
}
349352

lib/Conversion/XeTileToXeGPU/XeTileOpConversion.cpp

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -451,6 +451,28 @@ class SgInitTileOpPattern
451451
}
452452
};
453453

454+
// Translates a XeTile cache-policy attribute into the XeGPU cache policy with
// the same enumerator name.
//
// `val` may be a null attribute, in which case CACHED is returned.
static mlir::xegpu::CachePolicy
455+
translateCachePolicy(imex::xetile::CachePolicyAttr val) {
456+
// Null attribute: fall back to the CACHED policy.
if (!val)
457+
return mlir::xegpu::CachePolicy::CACHED;
458+
459+
// One case per XeTile enumerator; there is deliberately no `default` label,
// so every enumerator is mapped explicitly.
switch (val.getValue()) {
460+
case imex::xetile::CachePolicy::CACHED:
461+
return mlir::xegpu::CachePolicy::CACHED;
462+
case imex::xetile::CachePolicy::UNCACHED:
463+
return mlir::xegpu::CachePolicy::UNCACHED;
464+
case imex::xetile::CachePolicy::STREAMING:
465+
return mlir::xegpu::CachePolicy::STREAMING;
466+
case imex::xetile::CachePolicy::READ_INVALIDATE:
467+
return mlir::xegpu::CachePolicy::READ_INVALIDATE;
468+
case imex::xetile::CachePolicy::WRITE_BACK:
469+
return mlir::xegpu::CachePolicy::WRITE_BACK;
470+
case imex::xetile::CachePolicy::WRITE_THROUGH:
471+
return mlir::xegpu::CachePolicy::WRITE_THROUGH;
472+
}
473+
// Only reached if `val` holds a value not handled by the switch above.
llvm_unreachable("Invalid CachePolicy value");
474+
}
475+
454476
// It lowers a XeTile::prefetch_tile into one or more mlir::xegpu::prefetch_nd.
455477
// The adaptor will provide the set of xegpu.create_nd_desc lowered for
456478
// its input tile.
@@ -481,12 +503,14 @@ struct SgPrefetchTileOpPattern
481503
return mlir::failure();
482504
}
483505

484-
auto L1 = mlir::xegpu::CachePolicyAttr::get(
485-
op.getContext(), mlir::xegpu::CachePolicy::CACHED);
486-
auto L2 = mlir::xegpu::CachePolicyAttr::get(
487-
op.getContext(), mlir::xegpu::CachePolicy::CACHED);
488-
auto L3 = mlir::xegpu::CachePolicyAttr::get(
489-
op.getContext(), mlir::xegpu::CachePolicy::CACHED);
506+
auto getCachePolicy = [&](imex::xetile::CachePolicyAttr val) {
507+
return mlir::xegpu::CachePolicyAttr::get(op.getContext(),
508+
translateCachePolicy(val));
509+
};
510+
511+
auto L1 = getCachePolicy(op.getL1HintAttr());
512+
auto L2 = getCachePolicy(op.getL2HintAttr());
513+
auto L3 = getCachePolicy(op.getL3HintAttr());
490514

491515
for (auto tile : tiles) {
492516
rewriter.create<mlir::xegpu::PrefetchNdOp>(op.getLoc(), tile, L1, L2, L3);
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
// RUN: imex-opt --split-input-file --convert-xetile-to-xegpu %s -verify-diagnostics -o -| FileCheck %s
2+
3+
// CHECK-LABEL: test_prefetch
4+
// CHECK: xegpu.prefetch_nd %{{.*}} <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<cached>}>
5+
// CHECK: xegpu.prefetch_nd %{{.*}} <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<cached>}>
6+
// CHECK: xegpu.prefetch_nd %{{.*}} <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<cached>}>
7+
// CHECK: xegpu.prefetch_nd %{{.*}} <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<cached>}>
8+
// CHECK: xegpu.prefetch_nd %{{.*}} <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<cached>}>
9+
// CHECK: xegpu.prefetch_nd %{{.*}} <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<cached>}>
10+
// CHECK: xegpu.prefetch_nd %{{.*}} <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<cached>}>
11+
// CHECK: xegpu.prefetch_nd %{{.*}} <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<cached>}>
12+
// CHECK: xegpu.prefetch_nd %{{.*}} <{l1_hint = #xegpu.cache_hint<uncached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<streaming>}>
13+
// CHECK: xegpu.prefetch_nd %{{.*}} <{l1_hint = #xegpu.cache_hint<uncached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<streaming>}>
14+
// CHECK: xegpu.prefetch_nd %{{.*}} <{l1_hint = #xegpu.cache_hint<uncached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<streaming>}>
15+
// CHECK: xegpu.prefetch_nd %{{.*}} <{l1_hint = #xegpu.cache_hint<uncached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<streaming>}>
16+
// CHECK: xegpu.prefetch_nd %{{.*}} <{l1_hint = #xegpu.cache_hint<uncached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<streaming>}>
17+
// CHECK: xegpu.prefetch_nd %{{.*}} <{l1_hint = #xegpu.cache_hint<uncached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<streaming>}>
18+
// CHECK: xegpu.prefetch_nd %{{.*}} <{l1_hint = #xegpu.cache_hint<uncached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<streaming>}>
19+
// CHECK: xegpu.prefetch_nd %{{.*}} <{l1_hint = #xegpu.cache_hint<uncached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<streaming>}>
20+
// CHECK: gpu.return
21+
// Kernel under test: one tile is prefetched twice, first without cache hints
// and then with explicit l1/l3 hints.
gpu.module @test_kernel {
22+
gpu.func @test_prefetch(%a: memref<2x64xf16>) {
23+
%c0 = arith.constant 0 : index
24+
%0 = xetile.init_tile %a[%c0, %c0] : memref<2x64xf16> -> !xetile.tile<2x64xf16, #xetile.tile_attr<inner_blocks = [1, 16]>>
25+
// No hints given: the expectations above require cached at l1, l2 and l3.
xetile.prefetch_tile %0 : !xetile.tile<2x64xf16, #xetile.tile_attr<inner_blocks = [1, 16]>>
26+
// l1 and l3 hints given, l2 omitted: the expectations above require
// uncached / cached / streaming respectively.
xetile.prefetch_tile %0 {l1_hint = #xetile.cache_hint<uncached>, l3_hint = #xetile.cache_hint<streaming>} : !xetile.tile<2x64xf16, #xetile.tile_attr<inner_blocks = [1, 16]>>
27+
gpu.return
28+
}
29+
}

0 commit comments

Comments
 (0)