@@ -103,15 +103,18 @@ std::string getRegisterSizeCode(int size, bool is_float) {
103103}
104104
105105Value createCachePolicy (triton::EvictionPolicy opEvict,
106- ConversionPatternRewriter &rewriter, Location loc) {
106+ ConversionPatternRewriter &rewriter, Location loc,
107+ int computeCapability) {
107108 // Emit createpolicy.fractional.L2::policy.b64 xx 1.0
108109 PTXBuilder ptxBuilder;
109110 const bool hasL2EvictPolicy =
110111 opEvict == triton::EvictionPolicy::EVICT_FIRST ||
111112 opEvict == triton::EvictionPolicy::EVICT_LAST;
112113 Value policyRet;
113114
114- if (hasL2EvictPolicy) {
115+ const bool hardwareSupport = computeCapability >= 80 ;
116+
117+ if (hasL2EvictPolicy && hardwareSupport) {
115118 auto &policy =
116119 ptxBuilder.create <>(" createpolicy.fractional" )
117120 ->o (" L2::evict_first" ,
@@ -170,10 +173,11 @@ struct LoadStoreConversionBase {
170173struct LoadOpConversion : public ConvertOpToLLVMPattern <triton::LoadOp>,
171174 public LoadStoreConversionBase {
172175 LoadOpConversion (LLVMTypeConverter &converter,
173- const NVIDIA::TargetInfo &targetInfo,
176+ const NVIDIA::TargetInfo &targetInfo, int computeCapability,
174177 ModuleAxisInfoAnalysis &axisAnalysisPass,
175178 PatternBenefit benefit)
176179 : ConvertOpToLLVMPattern<triton::LoadOp>(converter, benefit),
180+ computeCapability (computeCapability),
177181 LoadStoreConversionBase (targetInfo, axisAnalysisPass) {}
178182
179183 LogicalResult
@@ -336,7 +340,8 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern<triton::LoadOp>,
336340 ptxBuilder.newAddrOperand (ptrElems[vecStart], " l" , in_off);
337341
338342 // Create L2 cache policy register if needed
339- Value l2PolicyReg = createCachePolicy (op.getEvict (), rewriter, loc);
343+ Value l2PolicyReg =
344+ createCachePolicy (op.getEvict (), rewriter, loc, computeCapability);
340345
341346 // Define the instruction opcode
342347 auto &ld = ptxBuilder.create <>(" ld" )
@@ -397,15 +402,18 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern<triton::LoadOp>,
397402 rewriter.replaceOp (op, {resultStruct});
398403 return success ();
399404 }
405+
406+ int computeCapability;
400407};
401408
402409struct StoreOpConversion : public ConvertOpToLLVMPattern <triton::StoreOp>,
403410 public LoadStoreConversionBase {
404411 StoreOpConversion (LLVMTypeConverter &converter,
405- const NVIDIA::TargetInfo &targetInfo,
412+ const NVIDIA::TargetInfo &targetInfo, int computeCapability,
406413 ModuleAxisInfoAnalysis &axisAnalysisPass,
407414 PatternBenefit benefit)
408415 : ConvertOpToLLVMPattern<triton::StoreOp>(converter, benefit),
416+ computeCapability (computeCapability),
409417 LoadStoreConversionBase (targetInfo, axisAnalysisPass) {}
410418
411419 LogicalResult
@@ -519,7 +527,8 @@ struct StoreOpConversion : public ConvertOpToLLVMPattern<triton::StoreOp>,
519527 ptxBuilder.newAddrOperand (ptrElems[vecStart], " l" , in_off);
520528
521529 // Create L2 cache policy register if needed
522- Value l2PolicyReg = createCachePolicy (op.getEvict (), rewriter, loc);
530+ Value l2PolicyReg =
531+ createCachePolicy (op.getEvict (), rewriter, loc, computeCapability);
523532
524533 auto &ptxStoreInstr =
525534 ptxBuilder.create <>(" st" )
@@ -551,6 +560,8 @@ struct StoreOpConversion : public ConvertOpToLLVMPattern<triton::StoreOp>,
551560 rewriter.eraseOp (op);
552561 return success ();
553562 }
563+
564+ int computeCapability;
554565};
555566
556567void createBarrier (ConversionPatternRewriter &rewriter, Location loc,
@@ -1866,11 +1877,13 @@ struct TMAStoreWaitOpConversion
18661877
18671878void mlir::triton::NVIDIA::populateLoadStoreOpToLLVMPatterns (
18681879 LLVMTypeConverter &typeConverter, const TargetInfo &targetInfo,
1869- RewritePatternSet &patterns, ModuleAxisInfoAnalysis &axisInfoAnalysis ,
1870- PatternBenefit benefit) {
1880+ int computeCapability, RewritePatternSet &patterns ,
1881+ ModuleAxisInfoAnalysis &axisInfoAnalysis, PatternBenefit benefit) {
18711882 patterns.add <AsyncCopyGlobalToLocalOpConversion, AtomicCASOpConversion,
1872- AtomicRMWOpConversion, LoadOpConversion, StoreOpConversion>(
1873- typeConverter, targetInfo, axisInfoAnalysis, benefit);
1883+ AtomicRMWOpConversion>(typeConverter, targetInfo,
1884+ axisInfoAnalysis, benefit);
1885+ patterns.add <LoadOpConversion, StoreOpConversion>(
1886+ typeConverter, targetInfo, computeCapability, axisInfoAnalysis, benefit);
18741887 patterns.add <AsyncCommitGroupOpConversion>(typeConverter, benefit);
18751888 patterns.add <AsyncWaitOpConversion>(typeConverter, benefit);
18761889 patterns.add <AsyncTMACopyGlobalToLocalOpConversion,
0 commit comments