diff --git a/build/rocm/build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm b/build/rocm/build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm index 8665c0a470a3..99f7b9fe9b52 100644 --- a/build/rocm/build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm +++ b/build/rocm/build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm @@ -5,7 +5,7 @@ ARG ROCM_BUILD_JOB ARG ROCM_BUILD_NUM # Install system GCC and C++ libraries. -RUN yum install -y gcc-c++.x86_64 +RUN yum install -y gcc-c++.x86_64 patchelf RUN --mount=type=cache,target=/var/cache/dnf \ --mount=type=bind,source=build/rocm/tools/get_rocm.py,target=get_rocm.py \ diff --git a/jax_plugins/rocm/__init__.py b/jax_plugins/rocm/__init__.py index 8b176b675b88..ba50bbd147cd 100644 --- a/jax_plugins/rocm/__init__.py +++ b/jax_plugins/rocm/__init__.py @@ -76,10 +76,62 @@ def _get_library_path(): return None +def set_rocm_paths(path): + rocm_lib = None + try: + import rocm + rocm_lib = os.path.join(rocm.__path__[0], "lib") + except ImportError: + # find python site-packages + sp = path.parent.parent.parent + maybe_rocm_lib = os.path.join(sp, "rocm/lib") + if os.path.exists(maybe_rocm_lib): + rocm_lib = maybe_rocm_lib + + if not rocm_lib: + logger.info("No ROCm wheel installation found") + return + else: + logger.info("ROCm wheel install found at %r" % rocm_lib) + + bitcode_path = "" + lld_path = "" + + for root, dirs, files in os.walk(os.path.join(rocm_lib, "llvm")): + # look for ld.lld and ocml.bc + for f in files: + if f == "ocml.bc": + bitcode_path = root + if f == "ld.lld": + # amd backend needs the directory not the full path to binary + lld_path = root + + if bitcode_path and lld_path: + break + + + if not bitcode_path: + logger.warning("jax_rocm60_plugin couldn't locate amdgpu bitcode") + else: + logger.info("jax_rocm60_plugin using bitcode found at %r", bitcode_path) + + if not lld_path: + logger.warning("jax_rocm60_plugin couldn't locate amdgpu ld.lld") + else: + logger.info("jax_rocm60_plugin using ld.lld found at %r", lld_path) + + os.environ["JAX_ROCM_PLUGIN_INTERNAL_BITCODE_PATH"] = bitcode_path + os.environ["HIP_DEVICE_LIB_PATH"] = bitcode_path + os.environ["JAX_ROCM_PLUGIN_INTERNAL_LLD_PATH"] = lld_path + + def initialize(): path = _get_library_path() if path is None: return + + set_rocm_paths(path) + options = xla_client.generate_pjrt_gpu_plugin_options() options["platform_name"] = "ROCM" c_api = xb.register_plugin( diff --git a/jax_plugins/rocm/plugin_setup.py b/jax_plugins/rocm/plugin_setup.py index a84a6b34ea48..100e0224bf09 100644 --- a/jax_plugins/rocm/plugin_setup.py +++ b/jax_plugins/rocm/plugin_setup.py @@ -67,4 +67,20 @@ def has_ext_modules(self): }, zip_safe=False, distclass=BinaryDistribution, + extras_require={ + "with_rocm": [ + "amd_rocm_hip_runtime_devel_instinct", + "amd_rocm_hip_runtime_instinct", + "amd_hipblas_instinct", + "amd_hipsparse_instinct", + "amd_hipsolver_instinct", + "amd_miopen_hip_instinct", + "amd_rocm_llvm_instinct", + "amd_rocm_language_runtime_instinct", + "amd_rccl_instinct", + "amd_hipfft_instinct", + "amd_rocm_device_libs_instinct", + "amd_hipsparselt_instinct", + ], + }, ) diff --git a/jaxlib/tools/build_gpu_kernels_wheel.py b/jaxlib/tools/build_gpu_kernels_wheel.py index 99334dca0162..0f81f286afb8 100644 --- a/jaxlib/tools/build_gpu_kernels_wheel.py +++ b/jaxlib/tools/build_gpu_kernels_wheel.py @@ -21,6 +21,7 @@ import functools import os import pathlib +import subprocess import tempfile from bazel_tools.tools.python.runfiles import runfiles @@ -151,6 +152,39 @@ def prepare_wheel_rocm( ], ) + # NOTE(mrodden): this is a hack to change/set rpath values + # in the shared objects that are produced by the bazel build + # before they get pulled into the wheel build process. + # we have to do this change here because setting rpath + # using bazel requires the rpath to be valid during the build + # which won't be correct until we make changes to + # the xla/tsl/jax plugin build + + try: + subprocess.check_output(["which", "patchelf"]) + except subprocess.CalledProcessError as ex: + mesg = ( + "rocm plugin and kernel wheel builds require patchelf. " + "please install 'patchelf' and run again" + ) + raise Exception(mesg) from ex + + files = [ + f"_blas.{pyext}", + f"_linalg.{pyext}", + f"_prng.{pyext}", + f"_solver.{pyext}", + f"_sparse.{pyext}", + f"_rnn.{pyext}", + f"_triton.{pyext}", + f"rocm_plugin_extension.{pyext}", + ] + runpath = '$ORIGIN/../rocm/lib:$ORIGIN/../../rocm/lib' + # patchelf --force-rpath --set-rpath $RUNPATH $so + for f in files: + so_path = os.path.join(plugin_dir, f) + subprocess.check_call(["patchelf", "--force-rpath", "--set-rpath", runpath, so_path]) + # Build wheel for cuda kernels if args.enable_rocm: tmpdir = tempfile.TemporaryDirectory(prefix="jax_rocm_plugin") diff --git a/jaxlib/tools/build_gpu_plugin_wheel.py b/jaxlib/tools/build_gpu_plugin_wheel.py index 0e2bba0c74d0..e8e3e9cc2d0c 100644 --- a/jaxlib/tools/build_gpu_plugin_wheel.py +++ b/jaxlib/tools/build_gpu_plugin_wheel.py @@ -21,6 +21,7 @@ import functools import os import pathlib +import subprocess import tempfile from bazel_tools.tools.python.runfiles import runfiles @@ -141,6 +142,29 @@ def prepare_rocm_plugin_wheel(sources_path: pathlib.Path, *, cpu, rocm_version): dst_filename="xla_rocm_plugin.so", ) + # NOTE(mrodden): this is a hack to change/set rpath values + # in the shared objects that are produced by the bazel build + # before they get pulled into the wheel build process. + # we have to do this change here because setting rpath + # using bazel requires the rpath to be valid during the build + # which won't be correct until we make changes to + # the xla/tsl/jax plugin build + + try: + subprocess.check_output(["which", "patchelf"]) + except subprocess.CalledProcessError as ex: + mesg = ( + "rocm plugin and kernel wheel builds require patchelf. " + "please install 'patchelf' and run again" + ) + raise Exception(mesg) from ex + + shared_obj_path = os.path.join(plugin_dir, "xla_rocm_plugin.so") + runpath = '$ORIGIN/../rocm/lib:$ORIGIN/../../rocm/lib' + # patchelf --force-rpath --set-rpath $RUNPATH $so + subprocess.check_call(["patchelf", "--force-rpath", "--set-rpath", runpath, shared_obj_path]) + + tmpdir = None sources_path = args.sources_path