From d0f4d31967cd0caba96d096989376f93262e804a Mon Sep 17 00:00:00 2001 From: Mathew Odden <1471252+mrodden@users.noreply.github.com> Date: Mon, 17 Mar 2025 17:15:53 -0500 Subject: [PATCH 1/3] Fix missing defaults for GPU target arch'es (#293) Move defaults for these up to ci_build python We also add some error catching logic in the python script to clean up weirdness coming from the build scripts. Also clean up some formatting and typos (cherry picked from commit ef1d56184b851ce5de4e797640e2bd606c8632ff) (cherry picked from commit ba2e112869258db0e1b6ebc07f7c5342b3090a95) --- build/rocm/ci_build | 58 +++++++++++++++++++++++++++++++- build/rocm/ci_build.sh | 14 ++++++++ build/rocm/test_ci_build.py | 58 ++++++++++++++++++++++++++++++++ build/rocm/tools/build_wheels.py | 19 ++++++++--- 4 files changed, 143 insertions(+), 6 deletions(-) create mode 100644 build/rocm/test_ci_build.py diff --git a/build/rocm/ci_build b/build/rocm/ci_build index 5f79f11502d5..18759ae03571 100755 --- a/build/rocm/ci_build +++ b/build/rocm/ci_build @@ -24,6 +24,10 @@ import argparse import os import subprocess import sys +from typing import List + + +DEFAULT_GPU_DEVICE_TARGETS = "gfx906,gfx908,gfx90a,gfx942,gfx1030,gfx1100,gfx1101,gfx1200,gfx1201" def image_by_name(name): @@ -40,7 +44,11 @@ def dist_wheels( rocm_build_job="", rocm_build_num="", compiler="gcc", + gpu_device_targets : List[str] = None, ): + if not gpu_device_targets: + gpu_device_targets = DEFAULT_GPU_DEVICE_TARGETS.split(",") + if xla_path: xla_path = os.path.abspath(xla_path) @@ -63,6 +71,7 @@ def dist_wheels( "--build-arg=ROCM_BUILD_JOB=%s" % rocm_build_job, "--build-arg=ROCM_BUILD_NUM=%s" % rocm_build_num, "--tag=%s" % image, + "--build-arg=GPU_DEVICE_TARGETS=%s" % " ".join(gpu_device_targets), ".", ] @@ -85,6 +94,8 @@ def dist_wheels( pyver_string, "--compiler", compiler, + "--gpu-device-targets", + ",".join(gpu_device_targets), ] if xla_path: @@ -158,10 +169,14 @@ def dist_docker( tag="rocm/jax-dev", dockerfile=None, keep_image=True, + gpu_device_targets : List[str] = None, ): if not dockerfile: dockerfile = "build/rocm/Dockerfile.ms" + if not gpu_device_targets: + gpu_device_targets = DEFAULT_GPU_DEVICE_TARGETS.split(",") + python_version = python_versions[0] md = _fetch_jax_metadata(xla_path) @@ -174,6 +189,7 @@ def dist_docker( "--target", "rt_build", "--build-arg=ROCM_VERSION=%s" % rocm_version, + "--build-arg=GPU_DEVICE_TARGETS=%s" % " ".join(gpu_device_targets), "--build-arg=ROCM_BUILD_JOB=%s" % rocm_build_job, "--build-arg=ROCM_BUILD_NUM=%s" % rocm_build_num, "--build-arg=BASE_DOCKER=%s" % base_docker, @@ -238,6 +254,37 @@ def test(image_name): subprocess.check_call(cmd) +def parse_gpu_targets(targets_string): + # catch case where targets_string was empty. + # None should already be caught by argparse, but + # it doesn't hurt to check twice + if not targets_string: + targets_string = DEFAULT_GPU_DEVICE_TARGETS + + if "," in targets_string: + targets = targets_string.split(",") + elif " " in targets_string: + targets = targets_string.split(" ") + else: + targets = targets_string + + res = [] + # cleanup and validation + for t in targets: + if not t: + continue + + if not t.startswith("gfx"): + raise ValueError("Invalid GPU architecture target: %r" % t) + + res.append(t.strip()) + + if not res: + raise ValueError("GPU_DEVICE_TARGETS cannot be empty") + + return res + + def parse_args(): p = argparse.ArgumentParser() p.add_argument( @@ -249,7 +296,7 @@ def parse_args(): p.add_argument( "--python-versions", type=lambda x: x.split(","), - default="3.12", + default=["3.12"], help="Comma separated list of CPython versions to build wheels for", ) @@ -281,6 +328,11 @@ def parse_args(): choices=["gcc", "clang"], help="Compiler backend to use when compiling jax/jaxlib", ) + p.add_argument( + "--gpu-device-targets", + default=DEFAULT_GPU_DEVICE_TARGETS, + help="List of AMDGPU device targets passed from job", + ) subp = p.add_subparsers(dest="action", required=True) @@ -299,6 +351,7 @@ def parse_args(): def main(): args = parse_args() + gpu_device_targets = parse_gpu_targets(args.gpu_device_targets) if args.action == "dist_wheels": dist_wheels( @@ -308,6 +361,7 @@ def main(): args.rocm_build_job, args.rocm_build_num, compiler=args.compiler, + gpu_device_targets=gpu_device_targets, ) elif args.action == "test": @@ -321,6 +375,7 @@ def main(): args.rocm_build_job, args.rocm_build_num, compiler=args.compiler, + gpu_device_targets=gpu_device_targets, ) dist_docker( args.rocm_version, @@ -332,6 +387,7 @@ def main(): tag=args.image_tag, dockerfile=args.dockerfile, keep_image=args.keep_image, + gpu_device_targets=gpu_device_targets, ) diff --git a/build/rocm/ci_build.sh b/build/rocm/ci_build.sh index 386f70ee1a96..d8b98490b076 100755 --- a/build/rocm/ci_build.sh +++ b/build/rocm/ci_build.sh @@ -51,6 +51,7 @@ ROCM_BUILD_NUM="" BASE_DOCKER="ubuntu:22.04" CUSTOM_INSTALL="" JAX_USE_CLANG="" +GPU_DEVICE_TARGETS="" POSITIONAL_ARGS=() RUNTIME_FLAG=0 @@ -98,6 +99,18 @@ while [[ $# -gt 0 ]]; do JAX_USE_CLANG="$2" shift 2 ;; + --gpu_device_targets) + if [[ "$2" == "--custom_install" ]]; then + GPU_DEVICE_TARGETS="" + shift 2 + elif [[ -n "$2" ]]; then + GPU_DEVICE_TARGETS="$2" + shift 2 + else + GPU_DEVICE_TARGETS="" + shift 1 + fi + ;; *) POSITIONAL_ARGS+=("$1") shift @@ -164,6 +177,7 @@ fi --rocm-build-job=$ROCM_BUILD_JOB \ --rocm-build-num=$ROCM_BUILD_NUM \ --compiler=$JAX_COMPILER \ + --gpu-device-targets="${GPU_DEVICE_TARGETS}" \ dist_docker \ --dockerfile $DOCKERFILE_PATH \ --image-tag $DOCKER_IMG_NAME diff --git a/build/rocm/test_ci_build.py b/build/rocm/test_ci_build.py new file mode 100644 index 000000000000..03501c568619 --- /dev/null +++ b/build/rocm/test_ci_build.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python3 + +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import importlib.util +import importlib.machinery + + +def load_ci_build(): + spec = importlib.util.spec_from_loader( + "ci_build", importlib.machinery.SourceFileLoader("ci_build", "./ci_build") + ) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + return mod + + +ci_build = load_ci_build() + + +class CIBuildTestCase(unittest.TestCase): + def test_parse_gpu_targets(self): + targets = ["gfx908", "gfx940", "gfx1201"] + + r = ci_build.parse_gpu_targets(" ".join(targets)) + self.assertEqual(r, targets) + + r = ci_build.parse_gpu_targets(",".join(targets)) + self.assertEqual(r, targets) + + def test_parse_gpu_targets_empty_string(self): + expected = ci_build.DEFAULT_GPU_DEVICE_TARGETS.split(",") + r = ci_build.parse_gpu_targets("") + self.assertEqual(r, expected) + + self.assertRaises(ValueError, ci_build.parse_gpu_targets, " ") + + def test_parse_gpu_targets_invalid_arch(self): + targets = ["gfx908", "gfx940", "--oops", "/jax"] + self.assertRaises(ValueError, ci_build.parse_gpu_targets, " ".join(targets)) + + +if __name__ == "__main__": + unittest.main() diff --git a/build/rocm/tools/build_wheels.py b/build/rocm/tools/build_wheels.py index 36d2c35d2f36..65398cdb4da1 100644 --- a/build/rocm/tools/build_wheels.py +++ b/build/rocm/tools/build_wheels.py @@ -30,12 +30,15 @@ import subprocess import shutil import sys +from typing import List LOG = logging.getLogger(__name__) -GPU_DEVICE_TARGETS = "gfx906 gfx908 gfx90a gfx942 gfx1030 gfx1100 gfx1101 gfx1200 gfx1201" +DEFAULT_GPU_DEVICE_TARGETS = ( + "gfx906,gfx908,gfx90a,gfx942,gfx1030,gfx1100,gfx1101,gfx1200,gfx1201" +) def build_rocm_path(rocm_version_str): @@ -46,11 +49,11 @@ def build_rocm_path(rocm_version_str): return os.path.realpath("/opt/rocm") -def update_rocm_targets(rocm_path, targets): +def update_rocm_targets(rocm_path: str, targets: List[str]): target_fp = os.path.join(rocm_path, "bin/target.lst") version_fp = os.path.join(rocm_path, ".info/version") with open(target_fp, "w") as fd: - fd.write("%s\n" % targets) + fd.write("%s\n" % " ".join(targets)) # mimic touch open(version_fp, "a").close() @@ -250,7 +253,7 @@ def parse_args(): ) p.add_argument( "--python-versions", - default=["3.10.19,3.12"], + default="3.10.19,3.12", help="Comma separated CPython versions that wheels will be built and output for", ) p.add_argument( @@ -265,6 +268,11 @@ def parse_args(): default="gcc", help="Compiler backend to use when compiling jax/jaxlib", ) + p.add_argument( + "--gpu-device-targets", + default=DEFAULT_GPU_DEVICE_TARGETS, + help="Comma separated list of GPU device targets passed from job", + ) p.add_argument("jax_path", help="Directory where JAX source directory is located") @@ -285,6 +293,7 @@ def find_wheels(path): def main(): args = parse_args() python_versions = args.python_versions.split(",") + gpu_device_targets = args.gpu_device_targets.split(",") print("ROCM_VERSION=%s" % args.rocm_version) print("PYTHON_VERSIONS=%r" % python_versions) @@ -294,7 +303,7 @@ def main(): rocm_path = build_rocm_path(args.rocm_version) - update_rocm_targets(rocm_path, GPU_DEVICE_TARGETS) + update_rocm_targets(rocm_path, gpu_device_targets) for py in python_versions: build_jaxlib_wheel(args.jax_path, rocm_path, py, args.xla_path, args.compiler) From 3a2fec78e5f53f1fbd929d74c48aa2f9bc2f233b Mon Sep 17 00:00:00 2001 From: Mathew Odden Date: Fri, 28 Mar 2025 12:41:06 -0500 Subject: [PATCH 2/3] Split parse_gpu_targets test cases (cherry picked from commit 5024cf8aef13ae049276093716621343b0e41742) --- build/rocm/test_ci_build.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/build/rocm/test_ci_build.py b/build/rocm/test_ci_build.py index 03501c568619..99ab75b2a999 100644 --- a/build/rocm/test_ci_build.py +++ b/build/rocm/test_ci_build.py @@ -33,12 +33,13 @@ def load_ci_build(): class CIBuildTestCase(unittest.TestCase): - def test_parse_gpu_targets(self): + def test_parse_gpu_targets_spaces(self): targets = ["gfx908", "gfx940", "gfx1201"] - r = ci_build.parse_gpu_targets(" ".join(targets)) self.assertEqual(r, targets) + def test_parse_gpu_targets_commas(self): + targets = ["gfx908", "gfx940", "gfx1201"] r = ci_build.parse_gpu_targets(",".join(targets)) self.assertEqual(r, targets) @@ -47,6 +48,7 @@ def test_parse_gpu_targets_empty_string(self): r = ci_build.parse_gpu_targets("") self.assertEqual(r, expected) + def test_parse_gpu_targets_whitespace_only(self): self.assertRaises(ValueError, ci_build.parse_gpu_targets, " ") def test_parse_gpu_targets_invalid_arch(self): From 8c1eb0a2ef987e0d9f6f1eb92685ee33c78edba1 Mon Sep 17 00:00:00 2001 From: Mathew Odden Date: Wed, 26 Mar 2025 17:27:31 -0500 Subject: [PATCH 3/3] Add canonicalization of python versions The ci_build parsing will now drop the revision portion of the requested python version targets to build for, since users were passing in old buggy versions. Using only major.minor will result in the latest of the minor series being used, and won't affect any build outputs like bytecode or ABI bindings. Those are tied to major.minor versions of CPython. (cherry picked from commit 5bf08753d6da9ef46d2c6b5388dfa68952212816) --- build/rocm/ci_build | 33 +++++++++++++++++++++++++++------ build/rocm/test_ci_build.py | 21 +++++++++++++++++++++ 2 files changed, 48 insertions(+), 6 deletions(-) diff --git a/build/rocm/ci_build b/build/rocm/ci_build index 18759ae03571..7c76fba83803 100755 --- a/build/rocm/ci_build +++ b/build/rocm/ci_build @@ -27,7 +27,9 @@ import sys from typing import List -DEFAULT_GPU_DEVICE_TARGETS = "gfx906,gfx908,gfx90a,gfx942,gfx1030,gfx1100,gfx1101,gfx1200,gfx1201" +DEFAULT_GPU_DEVICE_TARGETS = ( + "gfx906,gfx908,gfx90a,gfx942,gfx1030,gfx1100,gfx1101,gfx1200,gfx1201" +) def image_by_name(name): @@ -44,7 +46,7 @@ def dist_wheels( rocm_build_job="", rocm_build_num="", compiler="gcc", - gpu_device_targets : List[str] = None, + gpu_device_targets: List[str] = None, ): if not gpu_device_targets: gpu_device_targets = DEFAULT_GPU_DEVICE_TARGETS.split(",") @@ -169,7 +171,7 @@ def dist_docker( tag="rocm/jax-dev", dockerfile=None, keep_image=True, - gpu_device_targets : List[str] = None, + gpu_device_targets: List[str] = None, ): if not dockerfile: dockerfile = "build/rocm/Dockerfile.ms" @@ -254,6 +256,24 @@ def test(image_name): subprocess.check_call(cmd) +def canonicalize_python_versions(versions: List[str]): + if isinstance(versions, str): + raise ValueError("'versions' must be a list of strings: versions=%r" % versions) + + cleaned = [] + for v in versions: + tup = v.split(".") + major = tup[0] + minor = tup[1] + rev = None + if len(tup) > 2 and tup[2]: + rev = tup[2] + + cleaned.append("%s.%s" % (major, minor)) + + return cleaned + + def parse_gpu_targets(targets_string): # catch case where targets_string was empty. # None should already be caught by argparse, but @@ -352,11 +372,12 @@ def parse_args(): def main(): args = parse_args() gpu_device_targets = parse_gpu_targets(args.gpu_device_targets) + python_versions = canonicalize_python_versions(args.python_versions) if args.action == "dist_wheels": dist_wheels( args.rocm_version, - args.python_versions, + python_versions, args.xla_source_dir, args.rocm_build_job, args.rocm_build_num, @@ -370,7 +391,7 @@ def main(): elif args.action == "dist_docker": dist_wheels( args.rocm_version, - args.python_versions, + python_versions, args.xla_source_dir, args.rocm_build_job, args.rocm_build_num, @@ -380,7 +401,7 @@ def main(): dist_docker( args.rocm_version, args.base_docker, - args.python_versions, + python_versions, args.xla_source_dir, rocm_build_job=args.rocm_build_job, rocm_build_num=args.rocm_build_num, diff --git a/build/rocm/test_ci_build.py b/build/rocm/test_ci_build.py index 99ab75b2a999..354da937d3e5 100644 --- a/build/rocm/test_ci_build.py +++ b/build/rocm/test_ci_build.py @@ -55,6 +55,27 @@ def test_parse_gpu_targets_invalid_arch(self): targets = ["gfx908", "gfx940", "--oops", "/jax"] self.assertRaises(ValueError, ci_build.parse_gpu_targets, " ".join(targets)) + def test_canonicalize_python_versions(self): + versions = ["3.10.0", "3.11.0", "3.12.0"] + exp = ["3.10", "3.11", "3.12"] + res = ci_build.canonicalize_python_versions(versions) + self.assertEqual(res, exp) + + def test_canonicalize_python_versions_scalar(self): + versions = ["3.10.0"] + exp = ["3.10"] + res = ci_build.canonicalize_python_versions(versions) + self.assertEqual(res, exp) + + def test_canonicalize_python_versions_no_revision_part(self): + versions = ["3.10", "3.11"] + res = ci_build.canonicalize_python_versions(versions) + self.assertEqual(res, versions) + + def test_canonicalize_python_versions_string(self): + versions = "3.10.0" + self.assertRaises(ValueError, ci_build.canonicalize_python_versions, versions) + if __name__ == "__main__": unittest.main()