Skip to content

Commit 76b6977

Browse files
[NVIDIA] L2 cache hints only for sm >= 80 (#7219)
L2 cache hints are available in PTX only for sm >= 80. This PR will cause cache hints to be ignored for L2 on sm < 80. Motivation: some Inductor tests are failing on sm75 due to cache hints added in Inductor: e.g. https://github.com/pytorch/pytorch/actions/runs/15712699449/job/44296072018
1 parent e671c0f commit 76b6977

File tree

4 files changed

+39
-11
lines changed

4 files changed

+39
-11
lines changed

test/Conversion/tritongpu_to_llvm_volta.mlir

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,16 @@ module attributes {"ttg.target" = "cuda:70", "ttg.num-ctas" = 1 : i32, "ttg.num-
1616
tt.return
1717
}
1818
}
19+
20+
// -----
21+
22+
#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
23+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
24+
// CHECK-LABEL: store_with_cache_attr
25+
tt.func @store_with_cache_attr(%a_ptr_init : tensor<256x!tt.ptr<f32>, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf32, #blocked0>) {
26+
// CHECK-NOT: createpolicy.fractional
27+
// CHECK: st.global.L1::evict_last.b32
28+
tt.store %a_ptr_init, %cst_0, %cst evictionPolicy = evict_last cacheModifier = ca : tensor<256x!tt.ptr<f32>, #blocked0>
29+
tt.return
30+
}
31+
}

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -103,15 +103,18 @@ std::string getRegisterSizeCode(int size, bool is_float) {
103103
}
104104

105105
Value createCachePolicy(triton::EvictionPolicy opEvict,
106-
ConversionPatternRewriter &rewriter, Location loc) {
106+
ConversionPatternRewriter &rewriter, Location loc,
107+
int computeCapability) {
107108
// Emit createpolicy.fractional.L2::policy.b64 xx 1.0
108109
PTXBuilder ptxBuilder;
109110
const bool hasL2EvictPolicy =
110111
opEvict == triton::EvictionPolicy::EVICT_FIRST ||
111112
opEvict == triton::EvictionPolicy::EVICT_LAST;
112113
Value policyRet;
113114

114-
if (hasL2EvictPolicy) {
115+
const bool hardwareSupport = computeCapability >= 80;
116+
117+
if (hasL2EvictPolicy && hardwareSupport) {
115118
auto &policy =
116119
ptxBuilder.create<>("createpolicy.fractional")
117120
->o("L2::evict_first",
@@ -170,10 +173,11 @@ struct LoadStoreConversionBase {
170173
struct LoadOpConversion : public ConvertOpToLLVMPattern<triton::LoadOp>,
171174
public LoadStoreConversionBase {
172175
LoadOpConversion(LLVMTypeConverter &converter,
173-
const NVIDIA::TargetInfo &targetInfo,
176+
const NVIDIA::TargetInfo &targetInfo, int computeCapability,
174177
ModuleAxisInfoAnalysis &axisAnalysisPass,
175178
PatternBenefit benefit)
176179
: ConvertOpToLLVMPattern<triton::LoadOp>(converter, benefit),
180+
computeCapability(computeCapability),
177181
LoadStoreConversionBase(targetInfo, axisAnalysisPass) {}
178182

179183
LogicalResult
@@ -336,7 +340,8 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern<triton::LoadOp>,
336340
ptxBuilder.newAddrOperand(ptrElems[vecStart], "l", in_off);
337341

338342
// Create L2 cache policy register if needed
339-
Value l2PolicyReg = createCachePolicy(op.getEvict(), rewriter, loc);
343+
Value l2PolicyReg =
344+
createCachePolicy(op.getEvict(), rewriter, loc, computeCapability);
340345

341346
// Define the instruction opcode
342347
auto &ld = ptxBuilder.create<>("ld")
@@ -397,15 +402,18 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern<triton::LoadOp>,
397402
rewriter.replaceOp(op, {resultStruct});
398403
return success();
399404
}
405+
406+
int computeCapability;
400407
};
401408

402409
struct StoreOpConversion : public ConvertOpToLLVMPattern<triton::StoreOp>,
403410
public LoadStoreConversionBase {
404411
StoreOpConversion(LLVMTypeConverter &converter,
405-
const NVIDIA::TargetInfo &targetInfo,
412+
const NVIDIA::TargetInfo &targetInfo, int computeCapability,
406413
ModuleAxisInfoAnalysis &axisAnalysisPass,
407414
PatternBenefit benefit)
408415
: ConvertOpToLLVMPattern<triton::StoreOp>(converter, benefit),
416+
computeCapability(computeCapability),
409417
LoadStoreConversionBase(targetInfo, axisAnalysisPass) {}
410418

411419
LogicalResult
@@ -519,7 +527,8 @@ struct StoreOpConversion : public ConvertOpToLLVMPattern<triton::StoreOp>,
519527
ptxBuilder.newAddrOperand(ptrElems[vecStart], "l", in_off);
520528

521529
// Create L2 cache policy register if needed
522-
Value l2PolicyReg = createCachePolicy(op.getEvict(), rewriter, loc);
530+
Value l2PolicyReg =
531+
createCachePolicy(op.getEvict(), rewriter, loc, computeCapability);
523532

524533
auto &ptxStoreInstr =
525534
ptxBuilder.create<>("st")
@@ -551,6 +560,8 @@ struct StoreOpConversion : public ConvertOpToLLVMPattern<triton::StoreOp>,
551560
rewriter.eraseOp(op);
552561
return success();
553562
}
563+
564+
int computeCapability;
554565
};
555566

556567
void createBarrier(ConversionPatternRewriter &rewriter, Location loc,
@@ -1866,11 +1877,13 @@ struct TMAStoreWaitOpConversion
18661877

18671878
void mlir::triton::NVIDIA::populateLoadStoreOpToLLVMPatterns(
18681879
LLVMTypeConverter &typeConverter, const TargetInfo &targetInfo,
1869-
RewritePatternSet &patterns, ModuleAxisInfoAnalysis &axisInfoAnalysis,
1870-
PatternBenefit benefit) {
1880+
int computeCapability, RewritePatternSet &patterns,
1881+
ModuleAxisInfoAnalysis &axisInfoAnalysis, PatternBenefit benefit) {
18711882
patterns.add<AsyncCopyGlobalToLocalOpConversion, AtomicCASOpConversion,
1872-
AtomicRMWOpConversion, LoadOpConversion, StoreOpConversion>(
1873-
typeConverter, targetInfo, axisInfoAnalysis, benefit);
1883+
AtomicRMWOpConversion>(typeConverter, targetInfo,
1884+
axisInfoAnalysis, benefit);
1885+
patterns.add<LoadOpConversion, StoreOpConversion>(
1886+
typeConverter, targetInfo, computeCapability, axisInfoAnalysis, benefit);
18741887
patterns.add<AsyncCommitGroupOpConversion>(typeConverter, benefit);
18751888
patterns.add<AsyncWaitOpConversion>(typeConverter, benefit);
18761889
patterns.add<AsyncTMACopyGlobalToLocalOpConversion,

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/PatternTritonGPUOpToLLVM.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ void populateFp4ToFpToLLVMPatterns(LLVMTypeConverter &typeConverter,
4747

4848
void populateLoadStoreOpToLLVMPatterns(LLVMTypeConverter &typeConverter,
4949
const TargetInfo &targetInfo,
50+
int computeCapability,
5051
RewritePatternSet &patterns,
5152
ModuleAxisInfoAnalysis &axisInfoAnalysis,
5253
PatternBenefit benefit);

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,8 @@ struct ConvertTritonGPUToLLVM
120120
populateClampFOpToLLVMPattern(typeConverter, patterns, axisInfoAnalysis,
121121
computeCapability,
122122
patternBenefitClampOptimizedPattern);
123-
populateLoadStoreOpToLLVMPatterns(typeConverter, targetInfo, patterns,
123+
populateLoadStoreOpToLLVMPatterns(typeConverter, targetInfo,
124+
computeCapability, patterns,
124125
axisInfoAnalysis, benefit);
125126
mlir::triton::populateReduceOpToLLVMPatterns(typeConverter, patterns,
126127
targetInfo, benefit);

0 commit comments

Comments
 (0)