Skip to content

Commit 024e331

Browse files
Merge pull request jax-ml#25084 from ROCm:ci_rocm_version
PiperOrigin-RevId: 700241231
2 parents f828f2d + e8934b9 commit 024e331

File tree

2 files changed

+12
-2
lines changed

2 files changed

+12
-2
lines changed

jax_plugins/rocm/plugin_setup.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,11 @@
2222
project_name = f"jax-rocm{rocm_version}-plugin"
2323
package_name = f"jax_rocm{rocm_version}_plugin"
2424

25+
# Extract ROCm version from the `ROCM_PATH` environment variable.
26+
default_rocm_path = "/opt/rocm"
27+
rocm_path = os.getenv("ROCM_PATH", default_rocm_path)
28+
rocm_detected_version = rocm_path.split('-')[-1] if '-' in rocm_path else "unknown"
29+
2530
def load_version_module(pkg_path):
2631
spec = importlib.util.spec_from_file_location(
2732
'version', os.path.join(pkg_path, 'version.py'))
@@ -43,7 +48,7 @@ def has_ext_modules(self):
4348
name=project_name,
4449
version=__version__,
4550
cmdclass=_cmdclass,
46-
description="JAX Plugin for AMD GPUs",
51+
description=f"JAX Plugin for AMD GPUs (ROCm:{rocm_detected_version})",
4752
long_description="",
4853
long_description_content_type="text/markdown",
4954
author="Ruturaj4",

jax_plugins/rocm/setup.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,11 @@
2121
project_name = f"jax-rocm{rocm_version}-pjrt"
2222
package_name = f"jax_plugins.xla_rocm{rocm_version}"
2323

24+
# Extract ROCm version from the `ROCM_PATH` environment variable.
25+
default_rocm_path = "/opt/rocm"
26+
rocm_path = os.getenv("ROCM_PATH", default_rocm_path)
27+
rocm_detected_version = rocm_path.split('-')[-1] if '-' in rocm_path else "unknown"
28+
2429
def load_version_module(pkg_path):
2530
spec = importlib.util.spec_from_file_location(
2631
'version', os.path.join(pkg_path, 'version.py'))
@@ -41,7 +46,7 @@ def load_version_module(pkg_path):
4146
setup(
4247
name=project_name,
4348
version=__version__,
44-
description="JAX XLA PJRT Plugin for AMD GPUs",
49+
description=f"JAX XLA PJRT Plugin for AMD GPUs (ROCm:{rocm_detected_version})",
4550
long_description="",
4651
long_description_content_type="text/markdown",
4752
author="Ruturaj4",

0 commit comments

Comments
 (0)