Skip to content

Commit 422e5d3

Browse files
[AMD] Fix slow compilation due to inlining print calls (#5153)
This PR disables inline of print related functions, which speeds up compilation of test_scan_layouts dramatically. --------- Co-authored-by: Lei Zhang <[email protected]>
1 parent 03c6312 commit 422e5d3

File tree

4 files changed

+21
-4
lines changed

4 files changed

+21
-4
lines changed

python/src/llvm.cc

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,8 +139,6 @@ std::string translateLLVMIRToASM(llvm::Module &module,
139139
{
140140
llvm::raw_string_ostream stream(result);
141141
llvm::buffer_ostream pstream(stream);
142-
for (llvm::Function &f : module.functions())
143-
f.addFnAttr(llvm::Attribute::AlwaysInline);
144142
llvm::legacy::PassManager pass;
145143
// emit
146144
auto fileType = isObject ? llvm::CodeGenFileType::ObjectFile

python/test/unit/language/test_core.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2563,8 +2563,6 @@ def kernel(X, Y, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, NUM_PID_N: tl.
25632563
@pytest.mark.parametrize("axis", [0, 1])
25642564
@pytest.mark.parametrize("add_overflow_check", [False, True])
25652565
def test_scan_layouts(M, N, src_layout, axis, add_overflow_check, device, tmp_path: pathlib.Path):
2566-
if add_overflow_check is True and is_hip():
2567-
pytest.skip("overflow check disabled on HIP while fixing issues")
25682566

25692567
overflow_check = """
25702568
%17 = arith.extsi %arg2 : i32 to i64

third_party/amd/backend/compiler.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -342,6 +342,9 @@ def make_llir(src, metadata, options):
342342
metadata["shared"] = src.get_int_attr("triton_gpu.shared")
343343

344344
amd.cleanup_bitcode_metadata(llvm_mod)
345+
# Disable inlining of print related functions,
346+
# because inlining of these function could slow down compilation significantly
347+
amd.disable_print_inline(llvm_mod)
345348
return str(llvm_mod)
346349

347350
@staticmethod

third_party/amd/python/triton_amd.cc

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,24 @@ void init_triton_amd(py::module &&m) {
161161
module->eraseNamedMetadata(openclVersion);
162162
});
163163

164+
m.def("disable_print_inline", [](llvm::Module *module) {
165+
// List of functions name prefixes we want to forbid inline.
166+
std::array<const char *, 2> prefixes = {"__ockl_fprintf", "__ockl_printf"};
167+
168+
for (llvm::Function &f : module->functions()) {
169+
if (!f.hasName())
170+
continue;
171+
llvm::StringRef name = f.getName();
172+
173+
auto isNamePrefixed = [&name](const char *prefix) {
174+
return name.starts_with(prefix);
175+
};
176+
177+
if (llvm::any_of(prefixes, isNamePrefixed))
178+
f.addFnAttr(llvm::Attribute::NoInline);
179+
}
180+
});
181+
164182
m.def(
165183
"assemble_amdgcn",
166184
[](const std::string &assembly, const std::string &arch,

0 commit comments

Comments
 (0)