From 8afc41c1b603fa435ef080e5229d5b8f1200bd4f Mon Sep 17 00:00:00 2001 From: Mathew Odden Date: Fri, 14 Mar 2025 14:11:55 -0500 Subject: [PATCH] Add support for ROCm wheel based install Requires 0e599b747dbd3fc6 from rocm/xla. This also requires changes on the XLA side for the exported paths to have any effect. The plugin init now looks for an installed `rocm` Python package and extracts the ROCm toolkit paths from the Python packages that it finds. We pass these to the XLA portions of the plugin via environment variables to avoid changing any interfaces like protobuf or the PJRT C APIs. We also have to patch the rpath in the shared object files included in the plugin and kernel wheels so that they resolve libraries relative to their install path, just like the CUDA-based plugin objects do. Other changes fix missing dynamic link libraries and add an optional `with_rocm` extras target that pulls in the ROCm Python dependencies for the plugin. (cherry picked from commit 61b90463f0bdba0ff19ee1bcc44db754cdc5c152) --- .../Dockerfile.manylinux_2_28_x86_64.rocm | 2 +- jax_plugins/rocm/__init__.py | 52 +++++++++++++++++++ jax_plugins/rocm/plugin_setup.py | 16 ++++++ jaxlib/tools/build_gpu_kernels_wheel.py | 34 ++++++++++++ jaxlib/tools/build_gpu_plugin_wheel.py | 24 +++++++++ 5 files changed, 127 insertions(+), 1 deletion(-) diff --git a/build/rocm/build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm b/build/rocm/build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm index 8665c0a470a3..99f7b9fe9b52 100644 --- a/build/rocm/build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm +++ b/build/rocm/build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm @@ -5,7 +5,7 @@ ARG ROCM_BUILD_JOB ARG ROCM_BUILD_NUM # Install system GCC and C++ libraries. 
-RUN yum install -y gcc-c++.x86_64 +RUN yum install -y gcc-c++.x86_64 patchelf RUN --mount=type=cache,target=/var/cache/dnf \ --mount=type=bind,source=build/rocm/tools/get_rocm.py,target=get_rocm.py \ diff --git a/jax_plugins/rocm/__init__.py b/jax_plugins/rocm/__init__.py index 8b176b675b88..ba50bbd147cd 100644 --- a/jax_plugins/rocm/__init__.py +++ b/jax_plugins/rocm/__init__.py @@ -76,10 +76,62 @@ def _get_library_path(): return None +def set_rocm_paths(path): + rocm_lib = None + try: + import rocm + rocm_lib = os.path.join(rocm.__path__[0], "lib") + except ImportError: + # find python site-packages + sp = path.parent.parent.parent + maybe_rocm_lib = os.path.join(sp, "rocm/lib") + if os.path.exists(maybe_rocm_lib): + rocm_lib = maybe_rocm_lib + + if not rocm_lib: + logger.info("No ROCm wheel installation found") + return + else: + logger.info("ROCm wheel install found at %r" % rocm_lib) + + bitcode_path = "" + lld_path = "" + + for root, dirs, files in os.walk(os.path.join(rocm_lib, "llvm")): + # look for ld.lld and ocml.bc + for f in files: + if f == "ocml.bc": + bitcode_path = root + if f == "ld.lld": + # amd backend needs the directory not the full path to binary + lld_path = root + + if bitcode_path and lld_path: + break + + + if not bitcode_path: + logger.warning("jax_rocm60_plugin couldn't locate amdgpu bitcode") + else: + logger.info("jax_rocm60_plugin using bitcode found at %r", bitcode_path) + + if not lld_path: + logger.warning("jax_rocm60_plugin couldn't locate amdgpu ld.lld") + else: + logger.info("jax_rocm60_plugin using ld.lld found at %r", lld_path) + + os.environ["JAX_ROCM_PLUGIN_INTERNAL_BITCODE_PATH"] = bitcode_path + os.environ["HIP_DEVICE_LIB_PATH"] = bitcode_path + os.environ["JAX_ROCM_PLUGIN_INTERNAL_LLD_PATH"] = lld_path + + def initialize(): path = _get_library_path() if path is None: return + + set_rocm_paths(path) + options = xla_client.generate_pjrt_gpu_plugin_options() options["platform_name"] = "ROCM" c_api = xb.register_plugin( diff 
--git a/jax_plugins/rocm/plugin_setup.py b/jax_plugins/rocm/plugin_setup.py index a84a6b34ea48..100e0224bf09 100644 --- a/jax_plugins/rocm/plugin_setup.py +++ b/jax_plugins/rocm/plugin_setup.py @@ -67,4 +67,20 @@ def has_ext_modules(self): }, zip_safe=False, distclass=BinaryDistribution, + extras_require={ + "with_rocm": [ + "amd_rocm_hip_runtime_devel_instinct", + "amd_rocm_hip_runtime_instinct", + "amd_hipblas_instinct", + "amd_hipsparse_instinct", + "amd_hipsolver_instinct", + "amd_miopen_hip_instinct", + "amd_rocm_llvm_instinct", + "amd_rocm_language_runtime_instinct", + "amd_rccl_instinct", + "amd_hipfft_instinct", + "amd_rocm_device_libs_instinct", + "amd_hipsparselt_instinct", + ], + }, ) diff --git a/jaxlib/tools/build_gpu_kernels_wheel.py b/jaxlib/tools/build_gpu_kernels_wheel.py index 99334dca0162..0f81f286afb8 100644 --- a/jaxlib/tools/build_gpu_kernels_wheel.py +++ b/jaxlib/tools/build_gpu_kernels_wheel.py @@ -21,6 +21,7 @@ import functools import os import pathlib +import subprocess import tempfile from bazel_tools.tools.python.runfiles import runfiles @@ -151,6 +152,39 @@ def prepare_wheel_rocm( ], ) + # NOTE(mrodden): this is a hack to change/set rpath values + # in the shared objects that are produced by the bazel build + # before they get pulled into the wheel build process. + # we have to do this change here because setting rpath + # using bazel requires the rpath to be valid during the build + # which won't be correct until we make changes to + # the xla/tsl/jax plugin build + + try: + subprocess.check_output(["which", "patchelf"]) + except subprocess.CalledProcessError as ex: + mesg = ( + "rocm plugin and kernel wheel builds require patchelf. 
" + "please install 'patchelf' and run again" + ) + raise Exception(mesg) from ex + + files = [ + f"_blas.{pyext}", + f"_linalg.{pyext}", + f"_prng.{pyext}", + f"_solver.{pyext}", + f"_sparse.{pyext}", + f"_rnn.{pyext}", + f"_triton.{pyext}", + f"rocm_plugin_extension.{pyext}", + ] + runpath = '$ORIGIN/../rocm/lib:$ORIGIN/../../rocm/lib' + # patchelf --force-rpath --set-rpath $RUNPATH $so + for f in files: + so_path = os.path.join(plugin_dir, f) + subprocess.check_call(["patchelf", "--force-rpath", "--set-rpath", runpath, so_path]) + # Build wheel for cuda kernels if args.enable_rocm: tmpdir = tempfile.TemporaryDirectory(prefix="jax_rocm_plugin") diff --git a/jaxlib/tools/build_gpu_plugin_wheel.py b/jaxlib/tools/build_gpu_plugin_wheel.py index 0e2bba0c74d0..e8e3e9cc2d0c 100644 --- a/jaxlib/tools/build_gpu_plugin_wheel.py +++ b/jaxlib/tools/build_gpu_plugin_wheel.py @@ -21,6 +21,7 @@ import functools import os import pathlib +import subprocess import tempfile from bazel_tools.tools.python.runfiles import runfiles @@ -141,6 +142,29 @@ def prepare_rocm_plugin_wheel(sources_path: pathlib.Path, *, cpu, rocm_version): dst_filename="xla_rocm_plugin.so", ) + # NOTE(mrodden): this is a hack to change/set rpath values + # in the shared objects that are produced by the bazel build + # before they get pulled into the wheel build process. + # we have to do this change here because setting rpath + # using bazel requires the rpath to be valid during the build + # which won't be correct until we make changes to + # the xla/tsl/jax plugin build + + try: + subprocess.check_output(["which", "patchelf"]) + except subprocess.CalledProcessError as ex: + mesg = ( + "rocm plugin and kernel wheel builds require patchelf. 
" + "please install 'patchelf' and run again" + ) + raise Exception(mesg) from ex + + shared_obj_path = os.path.join(plugin_dir, "xla_rocm_plugin.so") + runpath = '$ORIGIN/../rocm/lib:$ORIGIN/../../rocm/lib' + # patchelf --force-rpath --set-rpath $RUNPATH $so + subprocess.check_call(["patchelf", "--force-rpath", "--set-rpath", runpath, shared_obj_path]) + + tmpdir = None sources_path = args.sources_path