66#include " mlir/Dialect/LLVMIR/ROCDLDialect.h"
77#include " mlir/Pass/Pass.h"
88#include " triton/Conversion/TritonGPUToLLVM/Utility.h"
9- #include " llvm/TargetParser/TargetParser.h"
109
1110namespace mlir ::triton {
1211#define GEN_PASS_DEF_TRITONAMDGPUINSERTINSTRUCTIONSCHEDHINTS
@@ -221,12 +220,13 @@ struct InstructionSchedHintsRewriter
221220 std::transform (variant.begin (), variant.end (), variant.begin (),
222221 [](unsigned char c) { return std::tolower (c); });
223222
224- this ->schedulingType = llvm::StringSwitch<SchedulingType>(variant)
225- .Case (" none" , SchedulingType::NONE)
226- .Case (" iglp0" , SchedulingType::IGLP0)
227- .Case (" iglp1" , SchedulingType::IGLP1)
228- .Case (" ck_v3" , SchedulingType::CK_V3)
229- .Default (SchedulingType::UNKNOWN);
223+ this ->schedulingType =
224+ llvm::StringSwitch<SchedulingType>(variant)
225+ .Case (" none" , SchedulingType::NONE)
226+ .Case (" llvm-iglp-0" , SchedulingType::LLVM_IGLP_0)
227+ .Case (" llvm-iglp-1" , SchedulingType::LLVM_IGLP_1)
228+ .Case (" local-prefetch" , SchedulingType::LOCAL_PREFETCH)
229+ .Default (SchedulingType::UNKNOWN);
230230
231231 if (this ->numStages < 2 ) {
232232 this ->schedulingType = SchedulingType::NONE;
@@ -237,26 +237,24 @@ struct InstructionSchedHintsRewriter
237237
238238 enum class SchedulingType : uint32_t {
239239 NONE = 0 ,
240- IGLP0 ,
241- IGLP1 ,
242- CK_V3 ,
240+ LLVM_IGLP_0 ,
241+ LLVM_IGLP_1 ,
242+ LOCAL_PREFETCH ,
243243 UNKNOWN
244244 };
245245
246- // This is the implementation of the CK 's V3 pipelining (see
247- // see ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp).
246+ // The following is inspired by ROCm Composable Kernel library's V3 pipelining
247+ // (see ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp).
248248 // This scheduling requires 1x register and 1x LDS buffers combined with the
249249 // local (LDS to registers) and global (HBM to registers) data prefetching.
250- // see:
251- // include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.h
252- void
253- createCKV3Schedule (PatternRewriter &rewriter, Location loc,
254- triton::amdgpu::InstructionSchedHint schedHint) const {
250+ void createLocalPrefetchSchedule (
251+ PatternRewriter &rewriter, Location loc,
252+ triton::amdgpu::InstructionSchedHint schedHint) const {
255253
256254 if (!(schedHint.getIsBufferLoadsAEnabled () &&
257255 schedHint.getIsBufferLoadsBEnabled ())) {
258- LDBG (" Skipping instruction scheduling because `ck_v3 ` "
259- " scheduling can be used only with `buffer_load` instructions. " );
256+ LDBG (" skipping `local-prefetch` scheduling given it needs `buffer_load ` "
257+ " instructions" );
260258 return ;
261259 }
262260
@@ -435,8 +433,8 @@ struct InstructionSchedHintsRewriter
435433 // backend documentation.
436434 const bool limitSchedulingRange =
437435 !(schedulingType == SchedulingType::NONE ||
438- schedulingType == SchedulingType::IGLP0 ||
439- schedulingType == SchedulingType::IGLP1 );
436+ schedulingType == SchedulingType::LLVM_IGLP_0 ||
437+ schedulingType == SchedulingType::LLVM_IGLP_1 );
440438 Location loc = instructionSchedHint->getLoc ();
441439 Block *block = instructionSchedHint->getBlock ();
442440 if (limitSchedulingRange) {
@@ -448,22 +446,17 @@ struct InstructionSchedHintsRewriter
448446 rewriter.setInsertionPoint (block, std::prev (block->end ()));
449447
450448 switch (schedulingType) {
451- case SchedulingType::IGLP0:
452- [[fallthrough]];
453- case SchedulingType::IGLP1: {
449+ case SchedulingType::LLVM_IGLP_0:
450+ case SchedulingType::LLVM_IGLP_1:
454451 createIglpOpt (rewriter, loc, static_cast <int >(schedulingType) - 1 );
455452 break ;
456- }
457- case SchedulingType::CK_V3: {
458- createCKV3Schedule (rewriter, loc, instructionSchedHint);
453+ case SchedulingType::LOCAL_PREFETCH:
454+ createLocalPrefetchSchedule (rewriter, loc, instructionSchedHint);
459455 break ;
460- }
461456 case SchedulingType::NONE:
462- [[fallthrough]];
463- default : {
457+ default :
464458 break ;
465459 }
466- }
467460
468461 if (limitSchedulingRange)
469462 createSchedBarrier (rewriter, loc,
0 commit comments