Skip to content

Commit 4348109

Browse files
authored
[AMD] Add basic instruction scheduling control (#4770)
LLVM AMDGPU backend supports special intrinsics (https://llvm.org/docs/AMDGPUUsage.html#llvm-ir-intrinsics) as hints to influence instruction scheduling. This PR adds basic scaffolding for utilizing those intrinsics to better control instructions generated from the backend. It is meant to only target `tt.dot` operations which are often the most intensive ones and may demand fine-tuning to achieve better performance. Facilities added here are experimental and we need to iterate on it until to a good state.
1 parent 16c5b26 commit 4348109

File tree

8 files changed

+269
-0
lines changed

8 files changed

+269
-0
lines changed

third_party/amd/backend/compiler.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,13 @@ class HIPOptions:
4646
max_num_imprecise_acc_default: int = 0
4747
backend_name: str = 'hip'
4848

49+
# The following option provides hints to the AMDGPU backend regarding instruction scheduling
50+
# for all `tt.dot` operations in a kernel. The "default" variant preserves the default
51+
# instruction scheduling of the AMDGPU backend which aims at maximizing occupancy.
52+
# The option is experimental and may change at any time regarding its semantics and/or may
53+
# be gone entirely anytime.
54+
instruction_sched_variant: str = 'default'
55+
4956
def __post_init__(self):
5057
default_libdir = Path(__file__).parent / 'lib'
5158
extern_libs = {} if self.extern_libs is None else dict(self.extern_libs)
@@ -174,6 +181,7 @@ def make_ttgir(mod, metadata, options):
174181
if options.num_stages == 0:
175182
amd.passes.ttgpuir.add_stream_pipeline(pm)
176183
passes.common.add_canonicalizer(pm)
184+
amd.passes.ttgpuir.insert_instruction_sched_hints(pm)
177185
passes.ttgpuir.add_optimize_dot_operands(pm, True)
178186
passes.ttgpuir.add_remove_layout_conversions(pm)
179187
passes.ttgpuir.add_reduce_data_duplication(pm)
@@ -221,6 +229,7 @@ def make_llir(src, metadata, options):
221229
passes.common.add_canonicalizer(pm)
222230
passes.common.add_cse(pm)
223231
passes.common.add_symbol_dce(pm)
232+
amd.passes.ttgpuir.lower_instruction_sched_hints(pm, options.instruction_sched_variant)
224233
if os.environ.get("TRITON_DISABLE_LINE_INFO", "0") == "0":
225234
passes.llvmir.add_di_scope(pm)
226235
# This pass (`add_builtin_func_to_llvmir`) serves as a temporary workaround to address the issue of excessive basic block

third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,4 +32,24 @@ include "mlir/Interfaces/InferTypeOpInterface.td"
3232
include "TritonAMDGPUDialect.td"
3333
include "TritonAMDGPUAttrDefs.td"
3434

35+
class TT_AMDGPU_Op<string mnemonic, list<Trait> traits = []> :
36+
Op<TritonAMDGPU_Dialect, mnemonic, !listconcat(traits, [])> {
37+
}
38+
39+
def InstructionSchedHint : TT_AMDGPU_Op<"instruction_sched_hint", []> {
40+
let summary = "A placeholder op for instruction scheduling hints within a basic block";
41+
let description = [{
42+
A placeholder op for instruction scheduling hints applied to instructions within
43+
a basic block where the placeholder op is located. This op is primarily intended
44+
to be used to adjust instruction scheduling inside the resulting main loop
45+
of a `tt.dot` operation. It's easier to identify dot ops at a high level and, thus,
46+
to mark intended scheduling regions. The hint ops are eventually lowered
47+
into LLVM AMDGPU instruction scheduling primitives, which are meant to control
48+
how different kinds of instructions (valu/mfma, global/shared memory, etc.) should
49+
interleave for better instruction level parallelism.
50+
}];
51+
52+
let assemblyFormat = [{attr-dict}];
53+
}
54+
3555
#endif

third_party/amd/include/TritonAMDGPUToLLVM/Passes.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,10 @@ createOptimizeLDSUsagePass(StringRef arch, int32_t customLDSLimit = 0);
3434
std::unique_ptr<OperationPass<ModuleOp>>
3535
createConvertTritonAMDGPUToLLVMPass(StringRef targetArch, bool ftz);
3636
std::unique_ptr<OperationPass<ModuleOp>> createConvertBuiltinFuncToLLVMPass();
37+
std::unique_ptr<OperationPass<ModuleOp>>
38+
createInsertInstructionSchedHintsPass();
39+
std::unique_ptr<OperationPass<ModuleOp>>
40+
createLowerInstructionSchedHintsPass(std::string variant);
3741

3842
#define GEN_PASS_REGISTRATION
3943
#include "TritonAMDGPUToLLVM/Passes.h.inc"

third_party/amd/include/TritonAMDGPUToLLVM/Passes.td

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,4 +55,24 @@ def ConvertBuiltinFuncToLLVM : Pass<"convert-builtin-func-to-llvm", "mlir::Modul
5555

5656
}
5757

58+
def InsertInstructionSchedHints : Pass<"insert-instruction-sched-hints", "mlir::ModuleOp"> {
59+
let summary = "Insert instruction scheduling hints after the dot ops in the main loop";
60+
let constructor = "mlir::triton::createInsertInstructionSchedHintsPass()";
61+
62+
let dependentDialects = ["mlir::LLVM::LLVMDialect"];
63+
}
64+
65+
def LowerInstructionSchedHints : Pass<"lower-insert-instruction-sched-hints", "mlir::ModuleOp"> {
66+
let summary = "Lower instruction scheduling hints to LLVM intrinsics";
67+
let constructor = "mlir::triton::createLowerInstructionSchedHintsPass(\"\")";
68+
69+
let dependentDialects = ["mlir::LLVM::LLVMDialect"];
70+
71+
let options = [
72+
Option<"variant", "variant", "std::string", /*default*/"\"default\"",
73+
"instruction scheduling variant">,
74+
];
75+
}
76+
77+
5878
#endif

third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,12 @@ add_triton_library(TritonAMDGPUToLLVM
1818
OptimizeLDSUsage.cpp
1919
OptimizeLDSUtility.cpp
2020
SPMDOpToLLVM.cpp
21+
SchedInstructions.cpp
2122

2223
DEPENDS
2324
TritonAMDGPUConversionPassIncGen
2425

2526
LINK_LIBS PUBLIC
2627
TritonGPUToLLVM
28+
TritonAMDGPUIR
2729
)
Lines changed: 205 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,205 @@
1+
#include "TritonAMDGPUToLLVM/Passes.h"
2+
3+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
4+
#include "mlir/Pass/Pass.h"
5+
#include "third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h"
6+
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
7+
#include "triton/Dialect/Triton/IR/Dialect.h"
8+
9+
namespace mlir::triton {
10+
#define GEN_PASS_DEF_INSERTINSTRUCTIONSCHEDHINTS
11+
#define GEN_PASS_DEF_LOWERINSTRUCTIONSCHEDHINTS
12+
#include "TritonAMDGPUToLLVM/Passes.h.inc"
13+
} // namespace mlir::triton
14+
15+
using namespace mlir;
16+
17+
namespace {
18+
19+
// The bitmask that encodes kinds of the instructions from AMD ISA.
20+
// The bitmask is used for providing instruction scheduling hints.
21+
enum InstructionKindMask {
22+
NONE = 0x0000000,
23+
ALL_ALU = 0x00000001,
24+
VALU = 0x00000002,
25+
SALU = 0x00000004,
26+
MFMA = 0x00000008,
27+
ALL_VMEM = 0x00000010,
28+
VMEM_READ = 0x00000020,
29+
VMEM_WRITE = 0x00000040,
30+
ALL_DS = 0x00000080,
31+
DS_READ = 0x00000100,
32+
DS_WRITE = 0x00000200
33+
};
34+
35+
// Create an intrinsic to control how different instruction kinds should
36+
// interleave for better ILP.
37+
void createSchedGroupBarrier(PatternRewriter &rewriter, Location loc,
38+
InstructionKindMask maskValue, int sizeValue,
39+
int groupIdValue) {
40+
MLIRContext *ctx = rewriter.getContext();
41+
auto intrinsicName = str_attr("llvm.amdgcn.sched.group.barrier");
42+
43+
Value mask =
44+
LLVM::createConstantI32(loc, rewriter, static_cast<int32_t>(maskValue));
45+
Value size =
46+
LLVM::createConstantI32(loc, rewriter, static_cast<int32_t>(sizeValue));
47+
Value groupId = LLVM::createConstantI32(loc, rewriter,
48+
static_cast<int32_t>(groupIdValue));
49+
50+
LLVM::FastmathFlagsAttr defaultFlags{};
51+
rewriter.create<LLVM::CallIntrinsicOp>(loc, TypeRange{}, intrinsicName,
52+
ValueRange{mask, size, groupId},
53+
defaultFlags);
54+
}
55+
56+
// Insert intrinsic that controls the types of instructions that may be
57+
// allowed to cross the intrinsic during instruction scheduling
58+
Operation *createSchedBarrier(PatternRewriter &rewriter, Location loc,
59+
int64_t maskValue) {
60+
MLIRContext *ctx = rewriter.getContext();
61+
auto intrinsicName = str_attr("llvm.amdgcn.sched.barrier");
62+
LLVM::FastmathFlagsAttr defaultFlags{};
63+
64+
Value mask =
65+
LLVM::createConstantI32(loc, rewriter, static_cast<int32_t>(maskValue));
66+
return rewriter.create<LLVM::CallIntrinsicOp>(loc, TypeRange{}, intrinsicName,
67+
ValueRange{mask}, defaultFlags);
68+
}
69+
70+
// Insert an experimental intrinsic for instruction group level parallelism.
71+
// The intrinsic takes a value that specifies the strategy.
72+
Operation *createIglpOpt(PatternRewriter &rewriter, Location loc, int value) {
73+
MLIRContext *ctx = rewriter.getContext();
74+
auto intrinsicName = str_attr("llvm.amdgcn.iglp.opt");
75+
LLVM::FastmathFlagsAttr defaultFlags{};
76+
Value iglpValue =
77+
LLVM::createConstantI32(loc, rewriter, static_cast<int32_t>(value));
78+
return rewriter.create<LLVM::CallIntrinsicOp>(
79+
loc, TypeRange{}, intrinsicName, ValueRange{iglpValue}, defaultFlags);
80+
}
81+
82+
struct InstructionSchedHintsRewriter
83+
: public OpRewritePattern<triton::amdgpu::InstructionSchedHint> {
84+
85+
InstructionSchedHintsRewriter(mlir::MLIRContext *ctx, std::string variant)
86+
: OpRewritePattern(ctx) {
87+
std::transform(variant.begin(), variant.end(), variant.begin(),
88+
[](unsigned char c) { return std::tolower(c); });
89+
90+
this->schedulingType = llvm::StringSwitch<SchedulingType>(variant)
91+
.Case("default", SchedulingType::NONE)
92+
.Case("iglp0", SchedulingType::IGLP0)
93+
.Case("iglp1", SchedulingType::IGLP1)
94+
.Default(SchedulingType::UNKNOWN);
95+
}
96+
97+
enum class SchedulingType : uint32_t { NONE = 0, IGLP0, IGLP1, UNKNOWN };
98+
99+
LogicalResult
100+
matchAndRewrite(triton::amdgpu::InstructionSchedHint instructionSchedHint,
101+
PatternRewriter &rewriter) const override {
102+
103+
if (this->schedulingType == SchedulingType::UNKNOWN) {
104+
llvm::dbgs()
105+
<< "[" << getDebugName() << "]: "
106+
<< "unknown instruction scheduling variant has been provided\n";
107+
return mlir::failure();
108+
}
109+
110+
// The switch controls whether instructions are allowed to cross the basic
111+
// block boundaries at the very top and at the very bottom. Note, this is
112+
// not supposed to be used together with IGLP OPT according to the AMDGPU
113+
// backend documentation.
114+
const bool limitSchedulingRange =
115+
!(schedulingType == SchedulingType::IGLP0 ||
116+
schedulingType == SchedulingType::IGLP1);
117+
Location loc = instructionSchedHint->getLoc();
118+
Block *block = instructionSchedHint->getBlock();
119+
if (limitSchedulingRange) {
120+
rewriter.setInsertionPointToStart(block);
121+
createSchedBarrier(rewriter, loc, InstructionKindMask::NONE);
122+
}
123+
124+
rewriter.setInsertionPoint(block, std::prev(block->end()));
125+
126+
switch (schedulingType) {
127+
case SchedulingType::IGLP0:
128+
[[fallthrough]];
129+
case SchedulingType::IGLP1: {
130+
createIglpOpt(rewriter, loc, static_cast<int>(schedulingType) - 1);
131+
break;
132+
}
133+
case SchedulingType::NONE:
134+
[[fallthrough]];
135+
default: {
136+
break;
137+
}
138+
}
139+
140+
if (limitSchedulingRange)
141+
createSchedBarrier(rewriter, loc, InstructionKindMask::NONE);
142+
143+
rewriter.eraseOp(instructionSchedHint);
144+
return mlir::success();
145+
}
146+
147+
private:
148+
SchedulingType schedulingType;
149+
};
150+
151+
struct LowerInstructionSchedHints
152+
: public triton::impl::LowerInstructionSchedHintsBase<
153+
LowerInstructionSchedHints> {
154+
155+
explicit LowerInstructionSchedHints(std::string variant) {
156+
this->variant = variant;
157+
}
158+
159+
void runOnOperation() override {
160+
MLIRContext *ctx = &getContext();
161+
ModuleOp mod = getOperation();
162+
163+
ConversionTarget target(*ctx);
164+
target.addLegalDialect<LLVM::LLVMDialect>();
165+
target.addIllegalOp<triton::amdgpu::InstructionSchedHint>();
166+
167+
RewritePatternSet patterns(ctx);
168+
patterns.add<InstructionSchedHintsRewriter>(ctx, this->variant);
169+
170+
if (failed(applyPartialConversion(getOperation(), target,
171+
std::move(patterns)))) {
172+
signalPassFailure();
173+
}
174+
}
175+
};
176+
177+
struct InsertInstructionSchedHints
178+
: public triton::impl::InsertInstructionSchedHintsBase<
179+
InsertInstructionSchedHints> {
180+
void runOnOperation() override {
181+
MLIRContext *ctx = &getContext();
182+
ModuleOp mod = getOperation();
183+
184+
mod->walk([ctx](triton::DotOp dot) {
185+
if (dyn_cast<mlir::scf::ForOp>(dot->getParentOp())) {
186+
mlir::OpBuilder rewriter(ctx);
187+
rewriter.setInsertionPointAfter(dot);
188+
rewriter.create<triton::amdgpu::InstructionSchedHint>(dot->getLoc());
189+
}
190+
});
191+
}
192+
};
193+
} // namespace
194+
195+
namespace mlir::triton {
196+
std::unique_ptr<OperationPass<ModuleOp>>
197+
createLowerInstructionSchedHintsPass(std::string variant) {
198+
return std::make_unique<LowerInstructionSchedHints>(variant);
199+
}
200+
201+
std::unique_ptr<OperationPass<ModuleOp>>
202+
createInsertInstructionSchedHintsPass() {
203+
return std::make_unique<InsertInstructionSchedHints>();
204+
}
205+
} // namespace mlir::triton

third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
1414
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
1515
#include "mlir/Pass/Pass.h"
16+
#include "third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h"
1617
#include "triton/Analysis/Allocation.h"
1718
#include "triton/Analysis/AxisInfo.h"
1819
#include "triton/Analysis/Membar.h"
@@ -57,6 +58,7 @@ class TritonLLVMConversionTarget : public ConversionTarget {
5758
addIllegalDialect<triton::nvidia_gpu::TritonNvidiaGPUDialect>();
5859
addIllegalDialect<mlir::gpu::GPUDialect>();
5960
addLegalOp<mlir::UnrealizedConversionCastOp>();
61+
addLegalOp<triton::amdgpu::InstructionSchedHint>();
6062
}
6163
};
6264

third_party/amd/python/triton_amd.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,13 @@ void init_triton_amd_passes_ttgpuir(py::module &&m) {
4444
m.def("add_builtin_func_to_llvmir", [](mlir::PassManager &pm) {
4545
pm.addPass(createConvertBuiltinFuncToLLVMPass());
4646
});
47+
m.def("insert_instruction_sched_hints", [](mlir::PassManager &pm) {
48+
pm.addPass(createInsertInstructionSchedHintsPass());
49+
});
50+
m.def("lower_instruction_sched_hints",
51+
[](mlir::PassManager &pm, std::string variant) {
52+
pm.addPass(createLowerInstructionSchedHintsPass(variant));
53+
});
4754
m.def("add_decompose_unsupported_conversions", [](mlir::PassManager &pm,
4855
const std::string &arch) {
4956
pm.addPass(

0 commit comments

Comments
 (0)