Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ ARG ROCM_BUILD_JOB
ARG ROCM_BUILD_NUM

# Install system GCC and C++ libraries.
RUN yum install -y gcc-c++.x86_64
RUN yum install -y gcc-c++.x86_64 patchelf

RUN --mount=type=cache,target=/var/cache/dnf \
--mount=type=bind,source=build/rocm/tools/get_rocm.py,target=get_rocm.py \
Expand Down
52 changes: 52 additions & 0 deletions jax_plugins/rocm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,62 @@ def _get_library_path():
return None


def set_rocm_paths(path):
    """Export ROCm toolchain locations found in a ROCm wheel install.

    The ROCm ``lib`` directory is looked up first through the importable
    ``rocm`` package and, failing that, at ``rocm/lib`` three directory
    levels above ``path``. If a directory is found, its ``llvm`` subtree is
    walked for ``ocml.bc`` (device bitcode) and ``ld.lld`` (linker), and the
    directories containing them are published via environment variables.

    Environment variables written (only when a ROCm ``lib`` directory exists):
      * ``JAX_ROCM_PLUGIN_INTERNAL_BITCODE_PATH`` -- bitcode directory, or
        ``""`` when ``ocml.bc`` was not found.
      * ``JAX_ROCM_PLUGIN_INTERNAL_LLD_PATH`` -- ``ld.lld`` directory, or
        ``""`` when it was not found.
      * ``HIP_DEVICE_LIB_PATH`` -- bitcode directory; left untouched when
        ``ocml.bc`` was not found so an existing value is not blanked.

    Args:
      path: Location of the plugin library. Only consulted when ``rocm``
        cannot be imported, in which case it must expose ``.parent``
        (presumably a ``pathlib.Path`` from ``_get_library_path()`` --
        confirm against the caller).

    Returns:
      None. Returns early, without touching the environment, when no ROCm
      wheel installation is found.
    """
    rocm_lib = None
    try:
        import rocm
    except ImportError:
        # No importable ``rocm`` package: fall back to a ``rocm/lib``
        # directory three levels above the plugin library (expected to be
        # the Python site-packages directory).
        site_packages = path.parent.parent.parent
        candidate = os.path.join(site_packages, "rocm/lib")
        if os.path.exists(candidate):
            rocm_lib = candidate
    else:
        rocm_lib = os.path.join(rocm.__path__[0], "lib")

    if not rocm_lib:
        logger.info("No ROCm wheel installation found")
        return

    # Pass the argument lazily, consistent with the other log calls below.
    logger.info("ROCm wheel install found at %r", rocm_lib)

    bitcode_path = ""
    lld_path = ""

    for root, _dirs, files in os.walk(os.path.join(rocm_lib, "llvm")):
        if "ocml.bc" in files:
            bitcode_path = root
        if "ld.lld" in files:
            # amd backend needs the directory, not the full path to the binary
            lld_path = root

        # Stop descending as soon as both artifacts have been located.
        if bitcode_path and lld_path:
            break

    if bitcode_path:
        logger.info("jax_rocm60_plugin using bitcode found at %r", bitcode_path)
    else:
        logger.warning("jax_rocm60_plugin couldn't locate amdgpu bitcode")

    if lld_path:
        logger.info("jax_rocm60_plugin using ld.lld found at %r", lld_path)
    else:
        logger.warning("jax_rocm60_plugin couldn't locate amdgpu ld.lld")

    os.environ["JAX_ROCM_PLUGIN_INTERNAL_BITCODE_PATH"] = bitcode_path
    os.environ["JAX_ROCM_PLUGIN_INTERNAL_LLD_PATH"] = lld_path
    if bitcode_path:
        # HIP_DEVICE_LIB_PATH is a user-facing variable: only override it
        # when the wheel actually provides bitcode, so a value the user
        # exported is not replaced by an empty string.
        os.environ["HIP_DEVICE_LIB_PATH"] = bitcode_path


def initialize():
path = _get_library_path()
if path is None:
return

set_rocm_paths(path)

options = xla_client.generate_pjrt_gpu_plugin_options()
options["platform_name"] = "ROCM"
c_api = xb.register_plugin(
Expand Down
16 changes: 16 additions & 0 deletions jax_plugins/rocm/plugin_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,4 +67,20 @@ def has_ext_modules(self):
},
zip_safe=False,
distclass=BinaryDistribution,
extras_require={
"with_rocm": [
"amd_rocm_hip_runtime_devel_instinct",
"amd_rocm_hip_runtime_instinct",
"amd_hipblas_instinct",
"amd_hipsparse_instinct",
"amd_hipsolver_instinct",
"amd_miopen_hip_instinct",
"amd_rocm_llvm_instinct",
"amd_rocm_language_runtime_instinct",
"amd_rccl_instinct",
"amd_hipfft_instinct",
"amd_rocm_device_libs_instinct",
"amd_hipsparselt_instinct",
],
},
)
34 changes: 34 additions & 0 deletions jaxlib/tools/build_gpu_kernels_wheel.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import functools
import os
import pathlib
import subprocess
import tempfile

from bazel_tools.tools.python.runfiles import runfiles
Expand Down Expand Up @@ -151,6 +152,39 @@ def prepare_wheel_rocm(
],
)

# NOTE(mrodden): this is a hack to change/set rpath values
# in the shared objects that are produced by the bazel build
# before they get pulled into the wheel build process.
# we have to do this change here because setting rpath
# using bazel requires the rpath to be valid during the build
# which won't be correct until we make changes to
# the xla/tsl/jax plugin build

try:
subprocess.check_output(["which", "patchelf"])
except subprocess.CalledProcessError as ex:
mesg = (
"rocm plugin and kernel wheel builds require patchelf. "
"please install 'patchelf' and run again"
)
raise Exception(mesg) from ex

files = [
f"_blas.{pyext}",
f"_linalg.{pyext}",
f"_prng.{pyext}",
f"_solver.{pyext}",
f"_sparse.{pyext}",
f"_rnn.{pyext}",
f"_triton.{pyext}",
f"rocm_plugin_extension.{pyext}",
]
runpath = '$ORIGIN/../rocm/lib:$ORIGIN/../../rocm/lib'
# patchelf --force-rpath --set-rpath $RUNPATH $so
for f in files:
so_path = os.path.join(plugin_dir, f)
subprocess.check_call(["patchelf", "--force-rpath", "--set-rpath", runpath, so_path])

# Build wheel for rocm kernels
if args.enable_rocm:
tmpdir = tempfile.TemporaryDirectory(prefix="jax_rocm_plugin")
Expand Down
24 changes: 24 additions & 0 deletions jaxlib/tools/build_gpu_plugin_wheel.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import functools
import os
import pathlib
import subprocess
import tempfile

from bazel_tools.tools.python.runfiles import runfiles
Expand Down Expand Up @@ -141,6 +142,29 @@ def prepare_rocm_plugin_wheel(sources_path: pathlib.Path, *, cpu, rocm_version):
dst_filename="xla_rocm_plugin.so",
)

# NOTE(mrodden): this is a hack to change/set rpath values
# in the shared objects that are produced by the bazel build
# before they get pulled into the wheel build process.
# we have to do this change here because setting rpath
# using bazel requires the rpath to be valid during the build
# which won't be correct until we make changes to
# the xla/tsl/jax plugin build

try:
subprocess.check_output(["which", "patchelf"])
except subprocess.CalledProcessError as ex:
mesg = (
"rocm plugin and kernel wheel builds require patchelf. "
"please install 'patchelf' and run again"
)
raise Exception(mesg) from ex

shared_obj_path = os.path.join(plugin_dir, "xla_rocm_plugin.so")
runpath = '$ORIGIN/../rocm/lib:$ORIGIN/../../rocm/lib'
# patchelf --force-rpath --set-rpath $RUNPATH $so
subprocess.check_call(["patchelf", "--force-rpath", "--set-rpath", runpath, shared_obj_path])



tmpdir = None
sources_path = args.sources_path
Expand Down
Loading