Skip to content

Commit fcf33a3

Browse files
authored
(feat) add L2 cache hints to triton (#6278)
Attempt at fixing triton-lang/triton#3438. Continues work from triton-lang/triton#3470. <!--- The core Triton is a small number of people, and we receive many PRs (thank you!). To help us review your code more quickly, **if you are a new contributor (less than 3 PRs merged) we ask that you complete the following tasks and include the filled-out checklist in your PR description.** Complete the following tasks before sending your PR, and replace `[ ]` with `[x]` to indicate you have done them. --> # New contributor declaration - [X] I am not making a trivial change, such as fixing a typo in a comment. - [x] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how). - [X] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`. - Select one of the following. - [X] I have added tests. - `/test` for `lit` tests - `/unittest` for C++ tests - `/python/test` for end-to-end tests - [ ] This PR does not need a test because `FILL THIS IN`. - Select one of the following. - [ ] I have not added any `lit` tests. - [X] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)
1 parent 80a80cf commit fcf33a3

File tree

2 files changed

+84
-19
lines changed

2 files changed

+84
-19
lines changed

test/Conversion/tritongpu_to_llvm.mlir

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -87,17 +87,46 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} {
8787
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
8888
// CHECK-LABEL: store_with_cache_attr
8989
tt.func @store_with_cache_attr(%a_ptr_init : tensor<256x!tt.ptr<f32>, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf32, #blocked0>) {
90-
// CHECK: llvm.inline_asm
91-
// CHECK-SAME: st.global.L1::evict_last.b32
92-
// CHECK: llvm.inline_asm
93-
// CHECK-SAME: st.global.L1::evict_last.b32
90+
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att {{.*}} "mov.u64 $0, 0x0;\0A\09createpolicy.fractional.L2::evict_last.b64 $0, 1.0;"
91+
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att {{.*}} "@$3 st.global.L1::evict_last.L2::cache_hint.b32 [ $1 + 0 ], { $0 }, $2;"
92+
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att {{.*}} "mov.u64 $0, 0x0;\0A\09createpolicy.fractional.L2::evict_last.b64 $0, 1.0;"
93+
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att {{.*}} "@$3 st.global.L1::evict_last.L2::cache_hint.b32 [ $1 + 0 ], { $0 }, $2;"
9494
tt.store %a_ptr_init, %cst_0, %cst evictionPolicy = evict_last cacheModifier = ca : tensor<256x!tt.ptr<f32>, #blocked0>
9595
tt.return
9696
}
9797
}
9898

9999
// -----
100100

101+
#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
102+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
103+
// CHECK-LABEL: load_with_l2_cache_hint
104+
tt.func @load_with_l2_cache_hint(%a_ptr_init : tensor<256x!tt.ptr<f32>, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf32, #blocked0>) {
105+
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att {{.*}} "mov.u64 $0, 0x0;\0A\09createpolicy.fractional.L2::evict_first.b64 $0, 1.0;"
106+
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att {{.*}} "mov.u32 $0, $1;\0A\09@$4 ld.global.L1::evict_first.L2::cache_hint.b32 { $0 }, [ $2 + 0 ], $3;"
107+
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att {{.*}} "mov.u64 $0, 0x0;\0A\09createpolicy.fractional.L2::evict_first.b64 $0, 1.0;"
108+
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att {{.*}} "mov.u32 $0, $1;\0A\09@$4 ld.global.L1::evict_first.L2::cache_hint.b32 { $0 }, [ $2 + 0 ], $3;"
109+
%1 = tt.load %a_ptr_init, %cst, %cst_0 evictionPolicy = evict_first : tensor<256x!tt.ptr<f32>, #blocked0>
110+
tt.return
111+
}
112+
}
113+
114+
// -----
115+
#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
116+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
117+
// CHECK-LABEL: store_with_l2_cache_hint
118+
tt.func @store_with_l2_cache_hint(%a_ptr_init : tensor<256x!tt.ptr<f32>, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf32, #blocked0>) {
119+
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att {{.*}} "mov.u64 $0, 0x0;\0A\09createpolicy.fractional.L2::evict_last.b64 $0, 1.0;"
120+
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att {{.*}} "@$3 st.global.L1::evict_last.L2::cache_hint.b32 [ $1 + 0 ], { $0 }, $2;"
121+
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att {{.*}} "mov.u64 $0, 0x0;\0A\09createpolicy.fractional.L2::evict_last.b64 $0, 1.0;"
122+
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att {{.*}} "@$3 st.global.L1::evict_last.L2::cache_hint.b32 [ $1 + 0 ], { $0 }, $2;"
123+
tt.store %a_ptr_init, %cst_0, %cst evictionPolicy = evict_last : tensor<256x!tt.ptr<f32>, #blocked0>
124+
tt.return
125+
}
126+
}
127+
128+
// -----
129+
101130
#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
102131
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32} {
103132
// CHECK-LABEL: global_load_store_no_vec

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 51 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,37 @@ std::string getRegisterSizeCode(int size, bool is_float) {
9999
}
100100
}
101101

102+
Value createCachePolicy(triton::EvictionPolicy opEvict,
103+
ConversionPatternRewriter &rewriter, Location loc) {
104+
// Emit createpolicy.fractional.L2::policy.b64 xx 1.0
105+
PTXBuilder ptxBuilder;
106+
const bool hasL2EvictPolicy =
107+
opEvict == triton::EvictionPolicy::EVICT_FIRST ||
108+
opEvict == triton::EvictionPolicy::EVICT_LAST;
109+
Value policyRet;
110+
111+
if (hasL2EvictPolicy) {
112+
auto &policy =
113+
ptxBuilder.create<>("createpolicy.fractional")
114+
->o("L2::evict_first",
115+
opEvict == triton::EvictionPolicy::EVICT_FIRST)
116+
.o("L2::evict_last", opEvict == triton::EvictionPolicy::EVICT_LAST)
117+
.b(64);
118+
119+
const std::string writeConstraint = "=l";
120+
// prepare asm operands
121+
auto *dstOpr = ptxBuilder.newOperand(writeConstraint, /*init=*/true);
122+
std::string fractionStr = "1.0";
123+
auto *fractionOpr = ptxBuilder.newConstantOperand(fractionStr);
124+
policy(dstOpr, fractionOpr);
125+
126+
Type policyRetTy = rewriter.getI64Type();
127+
policyRet = ptxBuilder.launch(rewriter, loc, policyRetTy);
128+
}
129+
130+
return policyRet;
131+
}
132+
102133
// Contains some helper functions for both Load and Store conversions.
103134
struct LoadStoreConversionBase {
104135
explicit LoadStoreConversionBase(const NVIDIA::TargetInfo &targetInfo,
@@ -246,10 +277,6 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern<triton::LoadOp>,
246277
const size_t movWidth = width < 16 ? 16 : width;
247278
assert(wordNElems * nWords * numVecs == numElems);
248279

249-
// TODO(Superjomn) Add cache policy fields to StoreOp.
250-
// TODO(Superjomn) Deal with cache policy here.
251-
const bool hasL2EvictPolicy = false;
252-
253280
PTXBuilder ptxBuilder;
254281

255282
Value pred = mask ? maskElems[vecStart] : Value{};
@@ -305,6 +332,9 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern<triton::LoadOp>,
305332
auto *addrOpr =
306333
ptxBuilder.newAddrOperand(ptrElems[vecStart], "l", in_off);
307334

335+
// Create L2 cache policy register if needed
336+
Value l2PolicyReg = createCachePolicy(op.getEvict(), rewriter, loc);
337+
308338
// Define the instruction opcode
309339
auto &ld = ptxBuilder.create<>("ld")
310340
->o("volatile", op.getIsVolatile())
@@ -315,15 +345,13 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern<triton::LoadOp>,
315345
op.getEvict() == triton::EvictionPolicy::EVICT_FIRST)
316346
.o("L1::evict_last",
317347
op.getEvict() == triton::EvictionPolicy::EVICT_LAST)
318-
.o("L1::cache_hint", hasL2EvictPolicy)
348+
.o("L2::cache_hint", l2PolicyReg != Value())
319349
.v(nWords)
320350
.b(width);
321351

322-
PTXBuilder::Operand *evictOpr{};
323-
324-
// Here lack a mlir::Value to bind to this operation, so disabled.
325-
// if (has_l2_evict_policy)
326-
// evictOpr = ptxBuilder.newOperand(l2Evict, "l");
352+
PTXBuilder::Operand *evictOpr = nullptr;
353+
if (l2PolicyReg)
354+
evictOpr = ptxBuilder.newOperand(l2PolicyReg, "l");
327355

328356
if (!evictOpr)
329357
ld(dstsOpr, addrOpr).maybePredicate(pred, "b");
@@ -336,10 +364,6 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern<triton::LoadOp>,
336364
? LLVM::LLVMStructType::getLiteral(getContext(), retTys)
337365
: retTys[0];
338366

339-
// TODO: if (has_l2_evict_policy)
340-
// auto asmDialectAttr =
341-
// LLVM::AsmDialectAttr::get(rewriter.getContext(),
342-
// LLVM::AsmDialect::AD_ATT);
343367
Value ret = ptxBuilder.launch(rewriter, loc, retTy);
344368

345369
// Extract and store return values
@@ -492,6 +516,9 @@ struct StoreOpConversion : public ConvertOpToLLVMPattern<triton::StoreOp>,
492516
auto *asmAddr =
493517
ptxBuilder.newAddrOperand(ptrElems[vecStart], "l", in_off);
494518

519+
// Create L2 cache policy register if needed
520+
Value l2PolicyReg = createCachePolicy(op.getEvict(), rewriter, loc);
521+
495522
auto &ptxStoreInstr =
496523
ptxBuilder.create<>("st")
497524
->global()
@@ -503,9 +530,18 @@ struct StoreOpConversion : public ConvertOpToLLVMPattern<triton::StoreOp>,
503530
op.getEvict() == triton::EvictionPolicy::EVICT_FIRST)
504531
.o("L1::evict_last",
505532
op.getEvict() == triton::EvictionPolicy::EVICT_LAST)
533+
.o("L2::cache_hint", l2PolicyReg != Value())
506534
.v(nWords)
507535
.b(width);
508-
ptxStoreInstr(asmAddr, asmArgList).maybePredicate(pred, "b");
536+
537+
PTXBuilder::Operand *evictOpr = nullptr;
538+
if (l2PolicyReg)
539+
evictOpr = ptxBuilder.newOperand(l2PolicyReg, "l");
540+
541+
if (!evictOpr)
542+
ptxStoreInstr(asmAddr, asmArgList).maybePredicate(pred, "b");
543+
else
544+
ptxStoreInstr(asmAddr, asmArgList, evictOpr).maybePredicate(pred, "b");
509545

510546
auto asmReturnTy = void_ty(ctx);
511547
ptxBuilder.launch(rewriter, loc, asmReturnTy);

0 commit comments

Comments
 (0)