Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion build/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,16 +746,25 @@ async def main():
python_tag = "py"
else:
python_tag = "cp"

if args.rocm_path:
rocm_tag = args.rocm_path.split("-")[-1]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pretty sure this isn't going to capture the full ROCm version. I've never seen a ROCm install called anything like /opt/rocm-7.1.0devabc123. Does ROCm store the full version string somewhere, like maybe in .info/version?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or you could repurpose the --rocm_version argument for this use-case. Have it accept the full ROCm version and use that as a tag to jaxlib

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

logger.info(f"Using ROCm tag: {rocm_tag}")
else:
rocm_tag = None

utils.copy_individual_files(
bazel_dir,
output_path,
f"{wheel_dir}*{wheel_version_suffix}-{python_tag}*.whl",
)
rocm_tag=rocm_tag,
)
if wheel == "jax":
utils.copy_individual_files(
bazel_dir,
output_path,
f"{wheel_dir}*{wheel_version_suffix}.tar.gz",
rocm_tag=rocm_tag,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not the proper place for this. We don't ship the jax wheel, only jaxlib and the jax-rocm7... wheels. Whatever version changes we make to the jaxlib wheel have to be compatible with upstream jax.

)

# Exit with success if all wheels in the list were built successfully.
Expand Down
11 changes: 9 additions & 2 deletions build/tools/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,13 +265,20 @@ def copy_dir_recursively(src, dst):
logging.info("Editable wheel path: %s" % dst)


def copy_individual_files(src: str, dst: str, glob_pattern: str):
def copy_individual_files(src: str, dst: str, glob_pattern: str, rocm_tag: str):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The typing on this is wrong. It should be Optional[str] and it should take None as a default.

os.makedirs(dst, exist_ok=True)
logging.debug(
f"Copying files matching pattern {glob_pattern!r} from {src!r} to {dst!r}"
)
for f in glob.glob(os.path.join(src, glob_pattern)):
dst_file = os.path.join(dst, os.path.basename(f))
if rocm_tag:
logging.info(f"Adding Rocm tag {rocm_tag} to file {f}")
f_list = f.split("-")
f_list[2] = f_list[2]+"+rocm"+rocm_tag
new_f = "-".join(f_list)
else:
new_f = f
dst_file = os.path.join(dst, os.path.basename(new_f))
if os.path.exists(dst_file):
os.remove(dst_file)
shutil.copy2(f, dst_file)
Expand Down