Skip to content

Commit 732af71

Browse files
committed
Add support for ROCm wheel based install
Requires 0e599b747dbd3fc6 from rocm/xla. This also requires some changes on the XLA side for the paths and such to have any effect. The plugin init now looks for a `rocm` Python package install and extracts the ROCm toolkit paths from the Python packages that it finds. We hand these into the XLA portions of the plugin via environment variables to avoid changing any interfaces like protobuf or the PJRT C APIs. We also have to patch the rpath in the shared object files included in the plugin and kernel wheels so that they look relative to their install path, just like the CUDA-based plugin objects do. Other changes fix missing dynamic link libraries and add an optional feature target that pulls in the ROCm Python dependencies for the plugin. (cherry picked from commit 61b9046)
1 parent ef1d561 commit 732af71

File tree

5 files changed

+128
-1
lines changed

5 files changed

+128
-1
lines changed

build/rocm/build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ ARG ROCM_BUILD_JOB
55
ARG ROCM_BUILD_NUM
66

77
# Install system GCC and C++ libraries.
8-
RUN yum install -y gcc-c++.x86_64
8+
RUN yum install -y gcc-c++.x86_64 patchelf
99

1010
RUN --mount=type=cache,target=/var/cache/dnf \
1111
--mount=type=bind,source=build/rocm/tools/get_rocm.py,target=get_rocm.py \

jax_plugins/rocm/__init__.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,10 +76,62 @@ def _get_library_path():
7676
return None
7777

7878

79+
def set_rocm_paths(path):
  """Export ROCm toolkit locations found in a wheel-based ROCm install.

  The ROCm ``lib`` directory is taken from the importable ``rocm`` package
  when there is one; otherwise a ``rocm/lib`` directory three levels above
  ``path`` is tried. The ``llvm`` subtree of that directory is then searched
  for ``ocml.bc`` (device bitcode) and ``ld.lld`` (linker), and the
  directories containing them are published through environment variables.

  Args:
    path: location of the plugin shared library. Only ``path.parent`` is
      used, so it must be a ``pathlib.Path``-like object.

  Returns:
    None. When no ROCm wheel install is found the environment is left
    untouched. Otherwise ``JAX_ROCM_PLUGIN_INTERNAL_BITCODE_PATH`` and
    ``JAX_ROCM_PLUGIN_INTERNAL_LLD_PATH`` are always set (to an empty string
    when the corresponding file was not located), and ``HIP_DEVICE_LIB_PATH``
    is set only when the bitcode directory was located.
  """
  rocm_lib = None
  try:
    import rocm
    rocm_lib = os.path.join(rocm.__path__[0], "lib")
  except ImportError:
    # No importable `rocm` package: look for a `rocm/lib` directory next to
    # the plugin install (three levels up is expected to be site-packages).
    site_packages = path.parent.parent.parent
    maybe_rocm_lib = os.path.join(site_packages, "rocm/lib")
    if os.path.exists(maybe_rocm_lib):
      rocm_lib = maybe_rocm_lib

  if not rocm_lib:
    logger.info("No ROCm wheel installation found")
    return

  logger.info("ROCm wheel install found at %r", rocm_lib)

  bitcode_path = ""
  lld_path = ""

  for root, _dirs, files in os.walk(os.path.join(rocm_lib, "llvm")):
    # look for ld.lld and ocml.bc
    for name in files:
      if name == "ocml.bc":
        bitcode_path = root
      if name == "ld.lld":
        # amd backend needs the directory not the full path to binary
        lld_path = root

    # Stop walking as soon as both locations are known.
    if bitcode_path and lld_path:
      break

  if not bitcode_path:
    logger.warning("jax_rocm60_plugin couldn't locate amdgpu bitcode")
  else:
    logger.info("jax_rocm60_plugin using bitcode found at %r", bitcode_path)

  if not lld_path:
    logger.warning("jax_rocm60_plugin couldn't locate amdgpu ld.lld")
  else:
    logger.info("jax_rocm60_plugin using ld.lld found at %r", lld_path)

  os.environ["JAX_ROCM_PLUGIN_INTERNAL_BITCODE_PATH"] = bitcode_path
  if bitcode_path:
    # Only override HIP_DEVICE_LIB_PATH when we actually found the bitcode;
    # writing an empty string would clobber a valid user-provided setting.
    os.environ["HIP_DEVICE_LIB_PATH"] = bitcode_path
  os.environ["JAX_ROCM_PLUGIN_INTERNAL_LLD_PATH"] = lld_path
126+
127+
79128
def initialize():
80129
path = _get_library_path()
81130
if path is None:
82131
return
132+
133+
set_rocm_paths(path)
134+
83135
options = xla_client.generate_pjrt_gpu_plugin_options()
84136
options["platform_name"] = "ROCM"
85137
c_api = xb.register_plugin(

jax_plugins/rocm/plugin_setup.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,4 +67,20 @@ def has_ext_modules(self):
6767
},
6868
zip_safe=False,
6969
distclass=BinaryDistribution,
70+
extras_require={
71+
"with_rocm": [
72+
"amd_rocm_hip_runtime_devel_instinct",
73+
"amd_rocm_hip_runtime_instinct",
74+
"amd_hipblas_instinct",
75+
"amd_hipsparse_instinct",
76+
"amd_hipsolver_instinct",
77+
"amd_miopen_hip_instinct",
78+
"amd_rocm_llvm_instinct",
79+
"amd_rocm_language_runtime_instinct",
80+
"amd_rccl_instinct",
81+
"amd_hipfft_instinct",
82+
"amd_rocm_device_libs_instinct",
83+
"amd_hipsparselt_instinct",
84+
],
85+
},
7086
)

jaxlib/tools/build_gpu_kernels_wheel.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import functools
2222
import os
2323
import pathlib
24+
import subprocess
2425
import tempfile
2526

2627
from bazel_tools.tools.python.runfiles import runfiles
@@ -151,6 +152,40 @@ def prepare_wheel_rocm(
151152
],
152153
)
153154

155+
# NOTE(mrodden): this is a hack to change/set rpath values
156+
# in the shared objects that are produced by the bazel build
157+
# before they get pulled into the wheel build process.
158+
# we have to do this change here because setting rpath
159+
# using bazel requires the rpath to be valid during the build
160+
# which won't be correct until we make changes to
161+
# the xla/tsl/jax plugin build
162+
163+
try:
164+
subprocess.check_output(["which", "patchelf"])
165+
except subprocess.CalledProcessError as ex:
166+
mesg = (
167+
"rocm plugin and kernel wheel builds require patchelf. "
168+
"please install 'patchelf' and run again"
169+
)
170+
raise Exception(mesg) from ex
171+
172+
files = [
173+
f"_blas.{pyext}",
174+
f"_linalg.{pyext}",
175+
f"_prng.{pyext}",
176+
f"_solver.{pyext}",
177+
f"_sparse.{pyext}",
178+
f"_hybrid.{pyext}",
179+
f"_rnn.{pyext}",
180+
f"_triton.{pyext}",
181+
f"rocm_plugin_extension.{pyext}",
182+
]
183+
runpath = '$ORIGIN/../rocm/lib:$ORIGIN/../../rocm/lib'
184+
# patchelf --force-rpath --set-rpath $RUNPATH $so
185+
for f in files:
186+
so_path = os.path.join(plugin_dir, f)
187+
subprocess.check_call(["patchelf", "--force-rpath", "--set-rpath", runpath, so_path])
188+
154189
# Build wheel for cuda kernels
155190
if args.enable_rocm:
156191
tmpdir = tempfile.TemporaryDirectory(prefix="jax_rocm_plugin")

jaxlib/tools/build_gpu_plugin_wheel.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import functools
2222
import os
2323
import pathlib
24+
import subprocess
2425
import tempfile
2526

2627
from bazel_tools.tools.python.runfiles import runfiles
@@ -141,6 +142,29 @@ def prepare_rocm_plugin_wheel(sources_path: pathlib.Path, *, cpu, rocm_version):
141142
dst_filename="xla_rocm_plugin.so",
142143
)
143144

145+
# NOTE(mrodden): this is a hack to change/set rpath values
146+
# in the shared objects that are produced by the bazel build
147+
# before they get pulled into the wheel build process.
148+
# we have to do this change here because setting rpath
149+
# using bazel requires the rpath to be valid during the build
150+
# which won't be correct until we make changes to
151+
# the xla/tsl/jax plugin build
152+
153+
try:
154+
subprocess.check_output(["which", "patchelf"])
155+
except subprocess.CalledProcessError as ex:
156+
mesg = (
157+
"rocm plugin and kernel wheel builds require patchelf. "
158+
"please install 'patchelf' and run again"
159+
)
160+
raise Exception(mesg) from ex
161+
162+
shared_obj_path = os.path.join(plugin_dir, "xla_rocm_plugin.so")
163+
runpath = '$ORIGIN/../rocm/lib:$ORIGIN/../../rocm/lib'
164+
# patchelf --force-rpath --set-rpath $RUNPATH $so
165+
subprocess.check_call(["patchelf", "--force-rpath", "--set-rpath", runpath, shared_obj_path])
166+
167+
144168

145169
tmpdir = None
146170
sources_path = args.sources_path

0 commit comments

Comments
 (0)