Skip to content

Commit 3bd3e32

Browse files
PMylonmakslevental
andauthored
[AMD] Use lld library API for linking (#7548)
This commit replaces shell out to ROCm's lld by directly calling the lld library API. --------- Co-authored-by: Maksim Levental <[email protected]>
1 parent 0daeb4f commit 3bd3e32

File tree

3 files changed

+41
-29
lines changed

3 files changed

+41
-29
lines changed

third_party/amd/CMakeLists.txt

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,11 @@ include_directories(${CMAKE_CURRENT_BINARY_DIR}/include)
33
add_subdirectory(include)
44
add_subdirectory(lib)
55
if(TRITON_BUILD_PYTHON_MODULE)
6+
find_package(LLD REQUIRED CONFIG PATHS "${MLIR_DIR}/../lld" NO_DEFAULT_PATH)
7+
include_directories(${LLD_INCLUDE_DIRS})
8+
message(STATUS "Found LLD distro-package @ ${LLD_DIR} and LLD include dirs @ ${LLD_INCLUDE_DIRS}")
69
add_triton_plugin(TritonAMD ${CMAKE_CURRENT_SOURCE_DIR}/python/triton_amd.cc LINK_LIBS TritonAMDGPUToLLVM TritonAMDGPUTransforms TritonAMDGPUDialectToLLVM)
7-
target_link_libraries(TritonAMD PRIVATE Python3::Module pybind11::headers)
10+
target_link_libraries(TritonAMD PRIVATE Python3::Module pybind11::headers lldCommon lldELF)
811
endif()
912
if(TRITON_BUILD_UT)
1013
add_subdirectory(unittest)

third_party/amd/backend/compiler.py

Lines changed: 4 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import hashlib
88
import tempfile
99
import re
10-
import subprocess
1110
import functools
1211
import warnings
1312
from pathlib import Path
@@ -176,26 +175,6 @@ def get_arg_specialization(arg, ty, **kwargs):
176175
ret += "S"
177176
return ret
178177

179-
@staticmethod
180-
def path_to_rocm_lld():
181-
# Check env path for ld.lld
182-
lld_env_path = knobs.amd.lld_path
183-
if lld_env_path is not None:
184-
lld = Path(lld_env_path)
185-
if lld.is_file():
186-
return lld
187-
# Check backend for ld.lld (used for pytorch wheels)
188-
lld = Path(__file__).parent / "llvm/bin/ld.lld"
189-
if lld.is_file():
190-
return lld
191-
lld = Path("/opt/rocm/llvm/bin/ld.lld")
192-
if lld.is_file():
193-
return lld
194-
lld = Path("/usr/bin/ld.lld")
195-
if lld.is_file():
196-
return lld
197-
raise Exception("ROCm linker /opt/rocm/llvm/bin/ld.lld not found. Set 'TRITON_HIP_LLD_PATH' to its path.")
198-
199178
@staticmethod
200179
def make_ttir(mod, metadata, options):
201180
pm = ir.pass_manager(mod.context)
@@ -434,14 +413,12 @@ def make_hsaco(src, metadata, options):
434413
if knobs.compilation.enable_asan:
435414
target_features = '+xnack'
436415
hsaco = amd.assemble_amdgcn(src, options.arch, target_features)
437-
438-
rocm_path = HIPBackend.path_to_rocm_lld()
439416
with tempfile.NamedTemporaryFile() as tmp_out:
440417
with tempfile.NamedTemporaryFile() as tmp_in:
441-
with open(tmp_in.name, 'wb') as fd_in:
418+
with open(tmp_in.name, "wb") as fd_in:
442419
fd_in.write(hsaco)
443-
subprocess.check_call([rocm_path, '-flavor', 'gnu', '-shared', tmp_in.name, '-o', tmp_out.name])
444-
with open(tmp_out.name, 'rb') as fd_out:
420+
amd.link_hsaco(tmp_in.name, tmp_out.name)
421+
with open(tmp_out.name, "rb") as fd_out:
445422
ret = fd_out.read()
446423
return ret
447424

@@ -457,5 +434,4 @@ def add_stages(self, stages, options, language):
457434

458435
@functools.lru_cache()
459436
def hash(self):
460-
version = subprocess.check_output([HIPBackend.path_to_rocm_lld(), "--version"], encoding='utf-8')
461-
return f'{version}-{self.target}'
437+
return f'{self.target}'

third_party/amd/python/triton_amd.cc

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#include "TritonAMDGPUToLLVM/Passes.h"
33
#include "TritonAMDGPUToLLVM/TargetUtils.h"
44
#include "TritonAMDGPUTransforms/Passes.h"
5+
#include "lld/Common/Driver.h"
56
#include "mlir/Pass/PassManager.h"
67
#include "mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h"
78
#include "passes.h"
@@ -27,6 +28,7 @@
2728
#include "llvm/Support/FileSystem.h"
2829
#include "llvm/Support/SourceMgr.h"
2930
#include "llvm/TargetParser/TargetParser.h"
31+
#include <array>
3032
#include <pybind11/pybind11.h>
3133
#include <stdexcept>
3234

@@ -110,8 +112,31 @@ void addControlConstant(llvm::Module *module, const char *name,
110112
constant->setUnnamedAddr(GlobalVariable::UnnamedAddr::Local);
111113
constant->setVisibility(GlobalVariable::VisibilityTypes::ProtectedVisibility);
112114
}
115+
113116
} // namespace
114117

118+
LLD_HAS_DRIVER(elf)
119+
120+
static std::optional<std::string> lldInvoke(const char *inPath,
121+
const char *outPath) {
122+
// Workaround: Disable parallelism to avoid hangs caused by LLVM's thread pool
123+
// when the following code is executed in a forked child process.
124+
// Context: lld::elf::LinkerDriver::link uses parallelFor which uses the
125+
// LLVM's thread pool. During cleanup at ~TaskGroup() the child process hangs
126+
// waiting.
127+
std::array args{"ld.lld", "--threads=1", "-shared", inPath, "-o", outPath};
128+
std::string errString;
129+
llvm::raw_string_ostream errStream(errString);
130+
auto lldRes = lld::lldMain(args, llvm::outs(), llvm::errs(),
131+
{{lld::Gnu, &lld::elf::link}});
132+
bool noErrors = (!lldRes.retCode && lldRes.canRunAgain);
133+
if (!noErrors) {
134+
errStream.flush();
135+
return errString;
136+
}
137+
return {};
138+
}
139+
115140
void init_triton_amd(py::module &&m) {
116141
m.doc() = "Python bindings to the AMD Triton backend";
117142

@@ -305,6 +330,14 @@ void init_triton_amd(py::module &&m) {
305330
}
306331
});
307332

333+
m.def("link_hsaco",
334+
[](const std::string &inPath, const std::string &outPath) {
335+
if (auto errString = lldInvoke(inPath.c_str(), outPath.c_str()))
336+
throw std::runtime_error("LLD failed to link hsaco source " +
337+
inPath + " into object file " + outPath +
338+
" because " + errString.value());
339+
});
340+
308341
m.def("add_scalarize_packed_fops_llvm_pass", [](llvm::Function *fn) {
309342
mlir::triton::AMD::runScalarizePackedFOpsPass(*fn);
310343
});

0 commit comments

Comments
 (0)