diff --git a/build/build.py b/build/build.py index 4954ac4a9800..17a4d309bb71 100755 --- a/build/build.py +++ b/build/build.py @@ -746,16 +746,25 @@ async def main(): python_tag = "py" else: python_tag = "cp" + + if args.rocm_path: + rocm_tag = args.rocm_path.split("-")[-1] + logger.info(f"Using ROCm tag: {rocm_tag}") + else: + rocm_tag = None + utils.copy_individual_files( bazel_dir, output_path, f"{wheel_dir}*{wheel_version_suffix}-{python_tag}*.whl", - ) + rocm_tag=rocm_tag, + ) if wheel == "jax": utils.copy_individual_files( bazel_dir, output_path, f"{wheel_dir}*{wheel_version_suffix}.tar.gz", + rocm_tag=rocm_tag, ) # Exit with success if all wheels in the list were built successfully. diff --git a/build/tools/utils.py b/build/tools/utils.py index 90654b24cf7b..067f19a3df2e 100644 --- a/build/tools/utils.py +++ b/build/tools/utils.py @@ -265,13 +265,20 @@ def copy_dir_recursively(src, dst): logging.info("Editable wheel path: %s" % dst) -def copy_individual_files(src: str, dst: str, glob_pattern: str): +def copy_individual_files(src: str, dst: str, glob_pattern: str, rocm_tag: str): os.makedirs(dst, exist_ok=True) logging.debug( f"Copying files matching pattern {glob_pattern!r} from {src!r} to {dst!r}" ) for f in glob.glob(os.path.join(src, glob_pattern)): - dst_file = os.path.join(dst, os.path.basename(f)) + if rocm_tag: + logging.info(f"Adding Rocm tag {rocm_tag} to file {f}") + f_list = f.split("-") + f_list[2] = f_list[2]+"+rocm"+rocm_tag + new_f = "-".join(f_list) + else: + new_f = f + dst_file = os.path.join(dst, os.path.basename(new_f)) if os.path.exists(dst_file): os.remove(dst_file) shutil.copy2(f, dst_file)