
Commit 9aa114a

[AMD] Refactor instruction scheduling hints (#5144)
- Renamed instruction scheduling variants
- Enabled `buffer-ops` for `local-prefetch`
- Added documentation regarding current variants

---------

Co-authored-by: Lei Zhang <[email protected]>
1 parent efd4465 commit 9aa114a
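For context, the renamed variants sit behind the existing `instruction_sched_variant` knob. Below is a hypothetical usage sketch (not part of this commit), assuming the option is forwarded to the HIP backend as a kernel-launch keyword argument like other backend options; the kernel body is a throwaway placeholder.

# Hypothetical usage sketch: selecting the renamed `local-prefetch` variant at
# kernel-launch time. Assumes the kwarg is routed into HIPOptions; the kernel
# body is a placeholder unrelated to this commit.
import torch
import triton
import triton.language as tl


@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    y = tl.load(y_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x + y, mask=mask)


x = torch.randn(1024, device="cuda")  # ROCm builds of PyTorch expose HIP devices as "cuda"
y = torch.randn(1024, device="cuda")
out = torch.empty_like(x)
# Valid values after this change: 'none', 'llvm-iglp-0', 'llvm-iglp-1', 'local-prefetch'.
add_kernel[(4,)](x, y, out, 1024, BLOCK=256, instruction_sched_variant="local-prefetch")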

File tree

3 files changed: +50 -37 lines changed


test/TritonGPU/amd/amd-instruction-sched.mlir

Lines changed: 5 additions & 5 deletions
@@ -1,8 +1,8 @@
-// RUN: triton-opt %s -split-input-file -triton-amdgpu-insert-instruction-sched-hints -triton-amdgpu-lower-insert-instruction-sched-hints='variant=iglp0' -verify-diagnostics | FileCheck %s -check-prefix=INSERT_IGLP0
-// RUN: triton-opt %s -split-input-file -triton-amdgpu-insert-instruction-sched-hints -triton-amdgpu-lower-insert-instruction-sched-hints='variant=iglp1' -verify-diagnostics | FileCheck %s -check-prefix=INSERT_IGLP1
+// RUN: triton-opt %s -split-input-file -triton-amdgpu-insert-instruction-sched-hints -triton-amdgpu-lower-insert-instruction-sched-hints='variant=llvm-iglp-0' -verify-diagnostics | FileCheck %s -check-prefix=INSERT_IGLP0
+// RUN: triton-opt %s -split-input-file -triton-amdgpu-insert-instruction-sched-hints -triton-amdgpu-lower-insert-instruction-sched-hints='variant=llvm-iglp-1' -verify-diagnostics | FileCheck %s -check-prefix=INSERT_IGLP1
 // RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=32 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline-v2='num_stages=1' -triton-amdgpu-insert-instruction-sched-hints -decompose-unsupported-amd-conversions -optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' -verify-diagnostics | FileCheck %s -check-prefix=INSTR_COUNT_NS1
 // RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=32 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline-v2='num_stages=2' -triton-amdgpu-insert-instruction-sched-hints -decompose-unsupported-amd-conversions -optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' -verify-diagnostics | FileCheck %s -check-prefix=INSTR_COUNT_NS2
-// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=32 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline-v2='num_stages=2' -triton-amdgpu-insert-instruction-sched-hints -decompose-unsupported-amd-conversions -optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' -triton-amdgpu-lower-insert-instruction-sched-hints='variant=ck_v3' -debug-only='lower-insert-instruction-sched-hints' -verify-diagnostics 2>&1 | FileCheck %s -check-prefix=USE_CKV3_GLOBAL_LOAD
+// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=32 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline-v2='num_stages=2' -triton-amdgpu-insert-instruction-sched-hints -decompose-unsupported-amd-conversions -optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' -triton-amdgpu-lower-insert-instruction-sched-hints='variant=local-prefetch' -debug-only='lower-insert-instruction-sched-hints' -verify-diagnostics 2>&1 | FileCheck %s -check-prefix=USE_LOCAL_PREFETCH_GLOBAL_LOAD
 // RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline-v2='num_stages=1' | FileCheck %s -check-prefix=LABELING_PS_1
 // RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline-v2='num_stages=2' | FileCheck %s -check-prefix=LABELING_PS_2

@@ -68,8 +68,8 @@ module {
 // INSTR_COUNT_NS2-SAME: numGlobalLoadsB = #amdgpu.InstCounter<4, vector<4xf16>>
 // INSTR_COUNT_NS2-SAME: numMMAs = #amdgpu.InstCounter<16, tensor<32x32x8xf16>>

-// USE_CKV3_GLOBAL_LOAD: [lower-insert-instruction-sched-hints]
-// USE_CKV3_GLOBAL_LOAD-SAME: Skipping instruction scheduling because `ck_v3` scheduling can be used only with `buffer_load` instructions.
+// USE_LOCAL_PREFETCH_GLOBAL_LOAD: [lower-insert-instruction-sched-hints]
+// USE_LOCAL_PREFETCH_GLOBAL_LOAD-SAME: skipping `local-prefetch` scheduling given it needs `buffer_load` instructions

 // LABELING_PS_1: scf.for
 // LABELING_PS_1: %[[REG0_OP0:.+]] = tt.load {{.*}} {OpIdx = #amdgpu.OpIdx<0>}

third_party/amd/backend/compiler.py

Lines changed: 21 additions & 1 deletion
@@ -52,6 +52,18 @@ class HIPOptions:
     # instruction scheduling of the AMDGPU backend which aims at maximizing occupancy.
     # The option is experimental and may change at any time regarding its semantics and/or may
     # be gone entirely anytime.
+    #
+    # Current experimental scheduling variants:
+    #
+    # llvm-iglp-0: injects `llvm.amdgcn.iglp_opt` intrinsic call with value `0` to the GEMM's
+    #              k-loop; i.e., "interleave DS and MFMA instructions for small GEMM kernels".
+    # llvm-iglp-1: injects `llvm.amdgcn.iglp_opt` intrinsic call with value `1` to the GEMM's
+    #              k-loop; i.e., "interleave DS and MFMA instructions for single wave small
+    #              GEMM kernels.".
+    # local-prefetch: implements instruction scheduling similar to the one from the ROCm Composable
+    #                 Kernel library. Note, this variant requires the use of buffer load/store ops
+    #                 and a special software pipelining style - i.e., 1x LDS and 1x register
+    #                 prefetch buffers for each GEMM tile.
     instruction_sched_variant: str = 'none'

     def __post_init__(self):
@@ -215,6 +227,7 @@ def make_ttgir(mod, metadata, options):
     passes.ttgpuir.add_remove_layout_conversions(pm)
     amd.passes.ttgpuir.add_optimize_epilogue(pm)
     passes.ttgpuir.add_optimize_dot_operands(pm, True)
+
     if amd.has_matrix_core_feature(options.arch):
         assert options.num_stages != 0, ("Triton AMD backend pipeliner has been updated. "
                                          "We used to trigger software pipelining with "
@@ -229,7 +242,14 @@ def make_ttgir(mod, metadata, options):
     passes.ttgpuir.add_reduce_data_duplication(pm)
     if amd.has_matrix_core_feature(options.arch):
         amd.passes.ttgpuir.add_reorder_instructions(pm)
-        if os.environ.get("AMDGCN_USE_BUFFER_OPS", "0") == "1":
+
+        use_buffer_ops = os.environ.get("AMDGCN_USE_BUFFER_OPS", "0") == "1"
+
+        # The `local-prefetch` scheduling variant requires turning on buffer ops.
+        if options.instruction_sched_variant == "local-prefetch":
+            use_buffer_ops = True
+
+        if use_buffer_ops:
             amd.passes.ttgpuir.add_canonicalize_pointers(pm)
             passes.common.add_canonicalizer(pm)
             amd.passes.ttgpuir.add_convert_to_buffer_ops(pm)
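Read on its own, the buffer-ops gating added above amounts to the small sketch below (an illustrative restatement; `resolve_use_buffer_ops` is a hypothetical helper name, not code from the repository):

import os


def resolve_use_buffer_ops(instruction_sched_variant: str) -> bool:
    # Buffer ops are enabled either explicitly via AMDGCN_USE_BUFFER_OPS or
    # implicitly because `local-prefetch` cannot work without buffer_load/store.
    use_buffer_ops = os.environ.get("AMDGCN_USE_BUFFER_OPS", "0") == "1"
    if instruction_sched_variant == "local-prefetch":
        use_buffer_ops = True
    return use_buffer_ops


assert resolve_use_buffer_ops("local-prefetch") is True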

third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp

Lines changed: 24 additions & 31 deletions
@@ -6,7 +6,6 @@
 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
 #include "mlir/Pass/Pass.h"
 #include "triton/Conversion/TritonGPUToLLVM/Utility.h"
-#include "llvm/TargetParser/TargetParser.h"

 namespace mlir::triton {
 #define GEN_PASS_DEF_TRITONAMDGPUINSERTINSTRUCTIONSCHEDHINTS
@@ -221,12 +220,13 @@ struct InstructionSchedHintsRewriter
     std::transform(variant.begin(), variant.end(), variant.begin(),
                    [](unsigned char c) { return std::tolower(c); });

-    this->schedulingType = llvm::StringSwitch<SchedulingType>(variant)
-                               .Case("none", SchedulingType::NONE)
-                               .Case("iglp0", SchedulingType::IGLP0)
-                               .Case("iglp1", SchedulingType::IGLP1)
-                               .Case("ck_v3", SchedulingType::CK_V3)
-                               .Default(SchedulingType::UNKNOWN);
+    this->schedulingType =
+        llvm::StringSwitch<SchedulingType>(variant)
+            .Case("none", SchedulingType::NONE)
+            .Case("llvm-iglp-0", SchedulingType::LLVM_IGLP_0)
+            .Case("llvm-iglp-1", SchedulingType::LLVM_IGLP_1)
+            .Case("local-prefetch", SchedulingType::LOCAL_PREFETCH)
+            .Default(SchedulingType::UNKNOWN);

     if (this->numStages < 2) {
       this->schedulingType = SchedulingType::NONE;
@@ -237,26 +237,24 @@ struct InstructionSchedHintsRewriter
   enum class SchedulingType : uint32_t {
     NONE = 0,
-    IGLP0,
-    IGLP1,
-    CK_V3,
+    LLVM_IGLP_0,
+    LLVM_IGLP_1,
+    LOCAL_PREFETCH,
     UNKNOWN
   };

-  // This is the implementation of the CK's V3 pipelining (see
-  // see ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp).
+  // The following is inspired by ROCm Composable Kernel library's V3 pipelining
+  // (see ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp).
   // This scheduling requires 1x register and 1x LDS buffers combined with the
   // local (LDS to registers) and global (HBM to registers) data prefetching.
-  // see:
-  // include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.h
-  void
-  createCKV3Schedule(PatternRewriter &rewriter, Location loc,
-                     triton::amdgpu::InstructionSchedHint schedHint) const {
+  void createLocalPrefetchSchedule(
+      PatternRewriter &rewriter, Location loc,
+      triton::amdgpu::InstructionSchedHint schedHint) const {

     if (!(schedHint.getIsBufferLoadsAEnabled() &&
           schedHint.getIsBufferLoadsBEnabled())) {
-      LDBG("Skipping instruction scheduling because `ck_v3` "
-           "scheduling can be used only with `buffer_load` instructions.");
+      LDBG("skipping `local-prefetch` scheduling given it needs `buffer_load` "
+           "instructions");
       return;
     }
@@ -435,8 +433,8 @@ struct InstructionSchedHintsRewriter
     // backend documentation.
     const bool limitSchedulingRange =
         !(schedulingType == SchedulingType::NONE ||
-          schedulingType == SchedulingType::IGLP0 ||
-          schedulingType == SchedulingType::IGLP1);
+          schedulingType == SchedulingType::LLVM_IGLP_0 ||
+          schedulingType == SchedulingType::LLVM_IGLP_1);
     Location loc = instructionSchedHint->getLoc();
     Block *block = instructionSchedHint->getBlock();
     if (limitSchedulingRange) {
@@ -448,22 +446,17 @@ struct InstructionSchedHintsRewriter
     rewriter.setInsertionPoint(block, std::prev(block->end()));

     switch (schedulingType) {
-    case SchedulingType::IGLP0:
-      [[fallthrough]];
-    case SchedulingType::IGLP1: {
+    case SchedulingType::LLVM_IGLP_0:
+    case SchedulingType::LLVM_IGLP_1:
       createIglpOpt(rewriter, loc, static_cast<int>(schedulingType) - 1);
       break;
-    }
-    case SchedulingType::CK_V3: {
-      createCKV3Schedule(rewriter, loc, instructionSchedHint);
+    case SchedulingType::LOCAL_PREFETCH:
+      createLocalPrefetchSchedule(rewriter, loc, instructionSchedHint);
       break;
-    }
     case SchedulingType::NONE:
-      [[fallthrough]];
-    default: {
+    default:
       break;
     }
-    }

     if (limitSchedulingRange)
       createSchedBarrier(rewriter, loc,
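For readers skimming the C++ above, the renamed variant parsing and the value it feeds into `createIglpOpt` can be restated as a short Python sketch. This is illustrative only; the enum ordering mirrors the `SchedulingType` declaration in this file, so `value - 1` is the `iglp_opt` argument for the two LLVM variants.

from enum import IntEnum


class SchedulingType(IntEnum):
    # Mirrors the C++ enum ordering above.
    NONE = 0
    LLVM_IGLP_0 = 1
    LLVM_IGLP_1 = 2
    LOCAL_PREFETCH = 3
    UNKNOWN = 4


def parse_variant(variant: str) -> SchedulingType:
    # Python restatement of the llvm::StringSwitch in the rewriter constructor.
    return {
        "none": SchedulingType.NONE,
        "llvm-iglp-0": SchedulingType.LLVM_IGLP_0,
        "llvm-iglp-1": SchedulingType.LLVM_IGLP_1,
        "local-prefetch": SchedulingType.LOCAL_PREFETCH,
    }.get(variant.lower(), SchedulingType.UNKNOWN)


# As in `createIglpOpt(rewriter, loc, static_cast<int>(schedulingType) - 1)`:
# llvm-iglp-0 -> iglp_opt(0), llvm-iglp-1 -> iglp_opt(1).
assert int(parse_variant("LLVM-IGLP-1")) - 1 == 1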
