-
Notifications
You must be signed in to change notification settings - Fork 5
jaxlib with rocm version extra tag #555
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: rocm-jaxlib-v0.8.0
Are you sure you want to change the base?
jaxlib with rocm version extra tag #555
Conversation
08c44fc to
ffe67f5
Compare
| python_tag = "cp" | ||
|
|
||
| if args.rocm_path: | ||
| rocm_tag = args.rocm_path.split("-")[-1] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pretty sure this isn't going to capture the full ROCm version. I've never seen a ROCm install called anything like /opt/rocm-7.1.0devabc123. Does ROCm store the full version string somewhere, like maybe in .info/version?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Or you could repurpose the --rocm_version argument for this use-case. Have it accept the full ROCm version and use that as a tag to jaxlib
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Then we might have an issue also with https://github.com/ROCm/rocm-jax/pull/179/files#diff-43328ab5df558db4c04acd94cbfa257445eaac24da8d534c2d8163cf4507a6bdR30
| bazel_dir, | ||
| output_path, | ||
| f"{wheel_dir}*{wheel_version_suffix}.tar.gz", | ||
| rocm_tag=rocm_tag, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is not the proper place for this. We don't ship the jax wheel, only jaxlib and the jax-rocm7... wheels. Whatever version changes we make to the jaxlib wheel have to be compatible with upstream jax.
charleshofer
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I did some local testing, and, as is, this change has some problems.
If I pass in --rocm_path=/opt/rocm-7.2.0 I get a wheel called jaxlib-0.8.0.dev0+selfbuilt+rocm7.2.0-cp312-cp312-manylinux_2_27_x86_64.whl, which is not a valid wheel name.
Paths to ROCm that don't contain the version number are also valid, so a path like /opt/rocm/ will break this.
I'd suggest two things:
- Use
--rocm_versioninstead of--rocm_path.--rocm_pathwill not get you full version names like7.1.0dev12345. - Don't add the other version suffix if
--rocm_versionis set.
|
|
||
|
|
||
| def copy_individual_files(src: str, dst: str, glob_pattern: str): | ||
| def copy_individual_files(src: str, dst: str, glob_pattern: str, rocm_tag: str): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The typing on this is wrong. It should be Optional[str] and it should take None as a default.
| python_tag = "cp" | ||
|
|
||
| if args.rocm_path: | ||
| rocm_tag = args.rocm_path.split("-")[-1] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Or you could repurpose the --rocm_version argument for this use-case. Have it accept the full ROCm version and use that as a tag to jaxlib
Motivation - SWDEV-563344
SWDEV-563344
[ROCm QA][TheRock] Build specific JAX whl are not available in artifactory
This PR adds an wheel tagging feature enable rocm version management of JAX ROCm jaxlib wheels in automated build pipelines, which is essential for maintaining consistent versioning across different ROCm builds and integrating with existing S3 upload workflows.
Technical Details
Pattern Transformation:
jaxlib-0.7.1-cp312-cp312-manylinux_2_28_x86_64.whl
→ jaxlib-0.7.1+rocm7.1.0a20251107-cp312-cp312-manylinux_2_28_x86_64.whl
Test Result
Submission Checklist
[*]Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.