diff --git a/build/rocm/ci_build b/build/rocm/ci_build index ee2c8698d346..ee2fec73b7f9 100755 --- a/build/rocm/ci_build +++ b/build/rocm/ci_build @@ -25,6 +25,12 @@ import logging import os import subprocess import sys +from typing import List + + +DEFAULT_GPU_DEVICE_TARGETS = ( + "gfx906,gfx908,gfx90a,gfx942,gfx1030,gfx1100,gfx1101,gfx1200,gfx1201" +) LOG = logging.getLogger("ci_build") @@ -37,10 +43,13 @@ def image_by_name(name): return image_id -def create_manylinux_build_image(rocm_version, rocm_build_job, rocm_build_num): +def create_manylinux_build_image( + rocm_version, rocm_build_job, rocm_build_num, gpu_device_targets: List[str] +) -> str: image_name = "jax-build-manylinux_2_28_x86_64_rocm%s" % rocm_version.replace( ".", "" ) + cmd = [ "docker", "build", @@ -50,6 +59,7 @@ def create_manylinux_build_image(rocm_version, rocm_build_job, rocm_build_num): "--build-arg=ROCM_BUILD_JOB=%s" % rocm_build_job, "--build-arg=ROCM_BUILD_NUM=%s" % rocm_build_num, "--tag=%s" % image_name, + "--build-arg=GPU_DEVICE_TARGETS=%s" % " ".join(gpu_device_targets), ".", ] @@ -65,11 +75,15 @@ def dist_wheels( rocm_build_job="", rocm_build_num="", compiler="gcc", + gpu_device_targets: List[str] = None, ): + if not gpu_device_targets: + gpu_device_targets = DEFAULT_GPU_DEVICE_TARGETS.split(",") + # We want to make sure the wheels we build are manylinux compliant. We'll # do the build in a container. Build the image for this. image_name = create_manylinux_build_image( - rocm_version, rocm_build_job, rocm_build_num + rocm_version, rocm_build_job, rocm_build_num, gpu_device_targets ) if xla_path: @@ -91,6 +105,8 @@ def dist_wheels( pyver_string, "--compiler", compiler, + "--gpu-device-targets", + ",".join(gpu_device_targets), ] if xla_path: @@ -127,31 +143,6 @@ def dist_wheels( ] ) - # Add command for unit tests - cmd.extend( - [ - "&&", - "bazel", - "test", - "-k", - "--jobs=4", - "--test_verbose_timeout_warnings=true", - "--test_output=all", - "--test_summary=detailed", - "--local_test_jobs=1", - "--test_env=JAX_ACCELERATOR_COUNT=%i" % 4, - "--test_env=JAX_SKIP_SLOW_TESTS=0", - "--verbose_failures=true", - "--config=rocm", - "--action_env=ROCM_PATH=/opt/rocm", - "--action_env=TF_ROCM_AMDGPU_TARGETS=%s" % "gfx90a", - "--test_tag_filters=-multiaccelerator", - "--test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform", - "--test_env=JAX_EXCLUDE_TEST_TARGETS=PmapTest.testSizeOverflow", - "//tests:gpu_tests", - ] - ) - LOG.info("Running: %s", cmd) _ = subprocess.run(cmd, check=True) @@ -196,10 +187,14 @@ def dist_docker( tag="rocm/jax-dev", dockerfile=None, keep_image=True, + gpu_device_targets: List[str] = None, ): if not dockerfile: dockerfile = "build/rocm/Dockerfile.ms" + if not gpu_device_targets: + gpu_device_targets = DEFAULT_GPU_DEVICE_TARGETS.split(",") + python_version = python_versions[0] md = _fetch_jax_metadata(xla_path) @@ -212,6 +207,7 @@ def dist_docker( "--target", "rt_build", "--build-arg=ROCM_VERSION=%s" % rocm_version, + "--build-arg=GPU_DEVICE_TARGETS=%s" % " ".join(gpu_device_targets), "--build-arg=ROCM_BUILD_JOB=%s" % rocm_build_job, "--build-arg=ROCM_BUILD_NUM=%s" % rocm_build_num, "--build-arg=BASE_DOCKER=%s" % base_docker, @@ -278,6 +274,55 @@ def test(image_name, test_cmd): subprocess.check_call(cmd) +def canonicalize_python_versions(versions: List[str]): + if isinstance(versions, str): + raise ValueError("'versions' must be a list of strings: versions=%r" % versions) + + cleaned = [] + for v in versions: + tup = v.split(".") + major = tup[0] + minor = tup[1] + rev = None + if len(tup) > 2 and tup[2]: + rev = tup[2] + + cleaned.append("%s.%s" % (major, minor)) + + return cleaned + + +def parse_gpu_targets(targets_string): + # catch case where targets_string was empty. + # None should already be caught by argparse, but + # it doesn't hurt to check twice + if not targets_string: + targets_string = DEFAULT_GPU_DEVICE_TARGETS + + if "," in targets_string: + targets = targets_string.split(",") + elif " " in targets_string: + targets = targets_string.split(" ") + else: + targets = targets_string + + res = [] + # cleanup and validation + for t in targets: + if not t: + continue + + if not t.startswith("gfx"): + raise ValueError("Invalid GPU architecture target: %r" % t) + + res.append(t.strip()) + + if not res: + raise ValueError("GPU_DEVICE_TARGETS cannot be empty") + + return res + + def parse_args(): p = argparse.ArgumentParser() p.add_argument( @@ -289,7 +334,7 @@ def parse_args(): p.add_argument( "--python-versions", type=lambda x: x.split(","), - default="3.12", + default=["3.12"], help="Comma separated list of CPython versions to build wheels for", ) @@ -321,6 +366,11 @@ def parse_args(): choices=["gcc", "clang"], help="Compiler backend to use when compiling jax/jaxlib", ) + p.add_argument( + "--gpu-device-targets", + default=DEFAULT_GPU_DEVICE_TARGETS, + help="List of AMDGPU device targets passed from job", + ) subp = p.add_subparsers(dest="action", required=True) @@ -344,15 +394,18 @@ def parse_args(): def main(): logging.basicConfig(level=logging.INFO) args = parse_args() + gpu_device_targets = parse_gpu_targets(args.gpu_device_targets) + python_versions = canonicalize_python_versions(args.python_versions) if args.action == "dist_wheels": dist_wheels( args.rocm_version, - args.python_versions, + python_versions, args.xla_source_dir, args.rocm_build_job, args.rocm_build_num, compiler=args.compiler, + gpu_device_targets=gpu_device_targets, ) elif args.action == "test": @@ -361,25 +414,26 @@ def main(): elif args.action == "dist_docker": dist_wheels( args.rocm_version, - args.python_versions, + python_versions, args.xla_source_dir, args.rocm_build_job, args.rocm_build_num, compiler=args.compiler, + gpu_device_targets=gpu_device_targets, ) dist_docker( args.rocm_version, args.base_docker, - args.python_versions, + python_versions, args.xla_source_dir, rocm_build_job=args.rocm_build_job, rocm_build_num=args.rocm_build_num, tag=args.image_tag, dockerfile=args.dockerfile, keep_image=args.keep_image, + gpu_device_targets=gpu_device_targets, ) if __name__ == "__main__": main() - diff --git a/build/rocm/ci_build.sh b/build/rocm/ci_build.sh index 386f70ee1a96..d8b98490b076 100755 --- a/build/rocm/ci_build.sh +++ b/build/rocm/ci_build.sh @@ -51,6 +51,7 @@ ROCM_BUILD_NUM="" BASE_DOCKER="ubuntu:22.04" CUSTOM_INSTALL="" JAX_USE_CLANG="" +GPU_DEVICE_TARGETS="" POSITIONAL_ARGS=() RUNTIME_FLAG=0 @@ -98,6 +99,18 @@ while [[ $# -gt 0 ]]; do JAX_USE_CLANG="$2" shift 2 ;; + --gpu_device_targets) + if [[ "$2" == "--custom_install" ]]; then + GPU_DEVICE_TARGETS="" + shift 2 + elif [[ -n "$2" ]]; then + GPU_DEVICE_TARGETS="$2" + shift 2 + else + GPU_DEVICE_TARGETS="" + shift 1 + fi + ;; *) POSITIONAL_ARGS+=("$1") shift @@ -164,6 +177,7 @@ fi --rocm-build-job=$ROCM_BUILD_JOB \ --rocm-build-num=$ROCM_BUILD_NUM \ --compiler=$JAX_COMPILER \ + --gpu-device-targets="${GPU_DEVICE_TARGETS}" \ dist_docker \ --dockerfile $DOCKERFILE_PATH \ --image-tag $DOCKER_IMG_NAME diff --git a/build/rocm/test_ci_build.py b/build/rocm/test_ci_build.py new file mode 100644 index 000000000000..354da937d3e5 --- /dev/null +++ b/build/rocm/test_ci_build.py @@ -0,0 +1,81 @@ +#!/usr/bin/env python3 + +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import importlib.util +import importlib.machinery + + +def load_ci_build(): + spec = importlib.util.spec_from_loader( + "ci_build", importlib.machinery.SourceFileLoader("ci_build", "./ci_build") + ) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + return mod + + +ci_build = load_ci_build() + + +class CIBuildTestCase(unittest.TestCase): + def test_parse_gpu_targets_spaces(self): + targets = ["gfx908", "gfx940", "gfx1201"] + r = ci_build.parse_gpu_targets(" ".join(targets)) + self.assertEqual(r, targets) + + def test_parse_gpu_targets_commas(self): + targets = ["gfx908", "gfx940", "gfx1201"] + r = ci_build.parse_gpu_targets(",".join(targets)) + self.assertEqual(r, targets) + + def test_parse_gpu_targets_empty_string(self): + expected = ci_build.DEFAULT_GPU_DEVICE_TARGETS.split(",") + r = ci_build.parse_gpu_targets("") + self.assertEqual(r, expected) + + def test_parse_gpu_targets_whitespace_only(self): + self.assertRaises(ValueError, ci_build.parse_gpu_targets, " ") + + def test_parse_gpu_targets_invalid_arch(self): + targets = ["gfx908", "gfx940", "--oops", "/jax"] + self.assertRaises(ValueError, ci_build.parse_gpu_targets, " ".join(targets)) + + def test_canonicalize_python_versions(self): + versions = ["3.10.0", "3.11.0", "3.12.0"] + exp = ["3.10", "3.11", "3.12"] + res = ci_build.canonicalize_python_versions(versions) + self.assertEqual(res, exp) + + def test_canonicalize_python_versions_scalar(self): + versions = ["3.10.0"] + exp = ["3.10"] + res = ci_build.canonicalize_python_versions(versions) + self.assertEqual(res, exp) + + def test_canonicalize_python_versions_no_revision_part(self): + versions = ["3.10", "3.11"] + res = ci_build.canonicalize_python_versions(versions) + self.assertEqual(res, versions) + + def test_canonicalize_python_versions_string(self): + versions = "3.10.0" + self.assertRaises(ValueError, ci_build.canonicalize_python_versions, versions) + + +if __name__ == "__main__": + unittest.main() diff --git a/build/rocm/tools/build_wheels.py b/build/rocm/tools/build_wheels.py index 139a1fdd8fa4..b714219b786d 100644 --- a/build/rocm/tools/build_wheels.py +++ b/build/rocm/tools/build_wheels.py @@ -30,12 +30,15 @@ import subprocess import shutil import sys +from typing import List LOG = logging.getLogger(__name__) -GPU_DEVICE_TARGETS = "gfx900 gfx906 gfx908 gfx90a gfx942 gfx1030 gfx1100 gfx1101 gfx1200 gfx1201" +DEFAULT_GPU_DEVICE_TARGETS = ( + "gfx906,gfx908,gfx90a,gfx942,gfx1030,gfx1100,gfx1101,gfx1200,gfx1201" +) def build_rocm_path(rocm_version_str): @@ -46,11 +49,11 @@ def build_rocm_path(rocm_version_str): return os.path.realpath("/opt/rocm") -def update_rocm_targets(rocm_path, targets): +def update_rocm_targets(rocm_path: str, targets: List[str]): target_fp = os.path.join(rocm_path, "bin/target.lst") version_fp = os.path.join(rocm_path, ".info/version") with open(target_fp, "w") as fd: - fd.write("%s\n" % targets) + fd.write("%s\n" % " ".join(targets)) # mimic touch open(version_fp, "a").close() @@ -251,7 +254,7 @@ def parse_args(): ) p.add_argument( "--python-versions", - default=["3.10.19,3.12"], + default="3.10.19,3.12", help="Comma separated CPython versions that wheels will be built and output for", ) p.add_argument( @@ -266,6 +269,11 @@ def parse_args(): default="gcc", help="Compiler backend to use when compiling jax/jaxlib", ) + p.add_argument( + "--gpu-device-targets", + default=DEFAULT_GPU_DEVICE_TARGETS, + help="Comma separated list of GPU device targets passed from job", + ) p.add_argument("jax_path", help="Directory where JAX source directory is located") @@ -286,6 +294,7 @@ def find_wheels(path): def main(): args = parse_args() python_versions = args.python_versions.split(",") + gpu_device_targets = args.gpu_device_targets.split(",") print("ROCM_VERSION=%s" % args.rocm_version) print("PYTHON_VERSIONS=%r" % python_versions) @@ -295,7 +304,7 @@ def main(): rocm_path = build_rocm_path(args.rocm_version) - update_rocm_targets(rocm_path, GPU_DEVICE_TARGETS) + update_rocm_targets(rocm_path, gpu_device_targets) for py in python_versions: build_jaxlib_wheel(args.jax_path, rocm_path, py, args.xla_path, args.compiler)