-
Notifications
You must be signed in to change notification settings - Fork 5
[rocm-main]: Canonicalize python build versions #321
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -25,6 +25,12 @@ import logging | |
| import os | ||
| import subprocess | ||
| import sys | ||
| from typing import List | ||
|
|
||
|
|
||
| DEFAULT_GPU_DEVICE_TARGETS = ( | ||
| "gfx906,gfx908,gfx90a,gfx942,gfx1030,gfx1100,gfx1101,gfx1200,gfx1201" | ||
| ) | ||
|
|
||
|
|
||
| LOG = logging.getLogger("ci_build") | ||
|
|
@@ -37,10 +43,13 @@ def image_by_name(name): | |
| return image_id | ||
|
|
||
|
|
||
| def create_manylinux_build_image(rocm_version, rocm_build_job, rocm_build_num): | ||
| def create_manylinux_build_image( | ||
| rocm_version, rocm_build_job, rocm_build_num, gpu_device_targets: List[str] | ||
| ) -> str: | ||
| image_name = "jax-build-manylinux_2_28_x86_64_rocm%s" % rocm_version.replace( | ||
| ".", "" | ||
| ) | ||
|
|
||
| cmd = [ | ||
| "docker", | ||
| "build", | ||
|
|
@@ -50,6 +59,7 @@ def create_manylinux_build_image(rocm_version, rocm_build_job, rocm_build_num): | |
| "--build-arg=ROCM_BUILD_JOB=%s" % rocm_build_job, | ||
| "--build-arg=ROCM_BUILD_NUM=%s" % rocm_build_num, | ||
| "--tag=%s" % image_name, | ||
| "--build-arg=GPU_DEVICE_TARGETS=%s" % " ".join(gpu_device_targets), | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Probably shouldn't be a part of this PR, but in the future, could we make GPU_DEVICE_TARGETS a comma-separated list everywhere? It confuses the Python arg parser if you use spaces, and that way we don't have to worry about which script needs what
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Its intended to be a comma separated one, but since folks keep doing weird things I made the parser deal with both. This part is actually setting an environment variable in the Docker build for the wheel image, so its not used by any python down the stack from here. (I believe it just gets printf'd to a file in the Dockerfile) |
||
| ".", | ||
| ] | ||
|
|
||
|
|
@@ -65,11 +75,15 @@ def dist_wheels( | |
| rocm_build_job="", | ||
| rocm_build_num="", | ||
| compiler="gcc", | ||
| gpu_device_targets: List[str] = None, | ||
| ): | ||
| if not gpu_device_targets: | ||
| gpu_device_targets = DEFAULT_GPU_DEVICE_TARGETS.split(",") | ||
|
|
||
| # We want to make sure the wheels we build are manylinux compliant. We'll | ||
| # do the build in a container. Build the image for this. | ||
| image_name = create_manylinux_build_image( | ||
| rocm_version, rocm_build_job, rocm_build_num | ||
| rocm_version, rocm_build_job, rocm_build_num, gpu_device_targets | ||
| ) | ||
|
|
||
| if xla_path: | ||
|
|
@@ -91,6 +105,8 @@ def dist_wheels( | |
| pyver_string, | ||
| "--compiler", | ||
| compiler, | ||
| "--gpu-device-targets", | ||
| ",".join(gpu_device_targets), | ||
| ] | ||
|
|
||
| if xla_path: | ||
|
|
@@ -127,31 +143,6 @@ def dist_wheels( | |
| ] | ||
| ) | ||
|
|
||
| # Add command for unit tests | ||
| cmd.extend( | ||
| [ | ||
| "&&", | ||
| "bazel", | ||
| "test", | ||
| "-k", | ||
| "--jobs=4", | ||
| "--test_verbose_timeout_warnings=true", | ||
| "--test_output=all", | ||
| "--test_summary=detailed", | ||
| "--local_test_jobs=1", | ||
| "--test_env=JAX_ACCELERATOR_COUNT=%i" % 4, | ||
| "--test_env=JAX_SKIP_SLOW_TESTS=0", | ||
| "--verbose_failures=true", | ||
| "--config=rocm", | ||
| "--action_env=ROCM_PATH=/opt/rocm", | ||
| "--action_env=TF_ROCM_AMDGPU_TARGETS=%s" % "gfx90a", | ||
| "--test_tag_filters=-multiaccelerator", | ||
| "--test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform", | ||
| "--test_env=JAX_EXCLUDE_TEST_TARGETS=PmapTest.testSizeOverflow", | ||
| "//tests:gpu_tests", | ||
| ] | ||
| ) | ||
|
|
||
| LOG.info("Running: %s", cmd) | ||
| _ = subprocess.run(cmd, check=True) | ||
|
|
||
|
|
@@ -196,10 +187,14 @@ def dist_docker( | |
| tag="rocm/jax-dev", | ||
| dockerfile=None, | ||
| keep_image=True, | ||
| gpu_device_targets: List[str] = None, | ||
| ): | ||
| if not dockerfile: | ||
| dockerfile = "build/rocm/Dockerfile.ms" | ||
|
|
||
| if not gpu_device_targets: | ||
| gpu_device_targets = DEFAULT_GPU_DEVICE_TARGETS.split(",") | ||
|
|
||
| python_version = python_versions[0] | ||
|
|
||
| md = _fetch_jax_metadata(xla_path) | ||
|
|
@@ -212,6 +207,7 @@ def dist_docker( | |
| "--target", | ||
| "rt_build", | ||
| "--build-arg=ROCM_VERSION=%s" % rocm_version, | ||
| "--build-arg=GPU_DEVICE_TARGETS=%s" % " ".join(gpu_device_targets), | ||
| "--build-arg=ROCM_BUILD_JOB=%s" % rocm_build_job, | ||
| "--build-arg=ROCM_BUILD_NUM=%s" % rocm_build_num, | ||
| "--build-arg=BASE_DOCKER=%s" % base_docker, | ||
|
|
@@ -278,6 +274,55 @@ def test(image_name, test_cmd): | |
| subprocess.check_call(cmd) | ||
|
|
||
|
|
||
| def canonicalize_python_versions(versions: List[str]): | ||
| if isinstance(versions, str): | ||
| raise ValueError("'versions' must be a list of strings: versions=%r" % versions) | ||
|
|
||
| cleaned = [] | ||
| for v in versions: | ||
| tup = v.split(".") | ||
| major = tup[0] | ||
| minor = tup[1] | ||
| rev = None | ||
| if len(tup) > 2 and tup[2]: | ||
| rev = tup[2] | ||
|
|
||
| cleaned.append("%s.%s" % (major, minor)) | ||
|
|
||
| return cleaned | ||
|
|
||
|
|
||
| def parse_gpu_targets(targets_string): | ||
| # catch case where targets_string was empty. | ||
| # None should already be caught by argparse, but | ||
| # it doesn't hurt to check twice | ||
| if not targets_string: | ||
| targets_string = DEFAULT_GPU_DEVICE_TARGETS | ||
|
|
||
| if "," in targets_string: | ||
| targets = targets_string.split(",") | ||
| elif " " in targets_string: | ||
| targets = targets_string.split(" ") | ||
| else: | ||
| targets = targets_string | ||
|
|
||
| res = [] | ||
| # cleanup and validation | ||
| for t in targets: | ||
| if not t: | ||
| continue | ||
|
|
||
| if not t.startswith("gfx"): | ||
| raise ValueError("Invalid GPU architecture target: %r" % t) | ||
|
|
||
| res.append(t.strip()) | ||
|
|
||
| if not res: | ||
| raise ValueError("GPU_DEVICE_TARGETS cannot be empty") | ||
|
|
||
| return res | ||
|
|
||
|
|
||
| def parse_args(): | ||
| p = argparse.ArgumentParser() | ||
| p.add_argument( | ||
|
|
@@ -289,7 +334,7 @@ def parse_args(): | |
| p.add_argument( | ||
| "--python-versions", | ||
| type=lambda x: x.split(","), | ||
| default="3.12", | ||
| default=["3.12"], | ||
| help="Comma separated list of CPython versions to build wheels for", | ||
| ) | ||
|
|
||
|
|
@@ -321,6 +366,11 @@ def parse_args(): | |
| choices=["gcc", "clang"], | ||
| help="Compiler backend to use when compiling jax/jaxlib", | ||
| ) | ||
| p.add_argument( | ||
| "--gpu-device-targets", | ||
| default=DEFAULT_GPU_DEVICE_TARGETS, | ||
| help="List of AMDGPU device targets passed from job", | ||
| ) | ||
|
|
||
| subp = p.add_subparsers(dest="action", required=True) | ||
|
|
||
|
|
@@ -344,15 +394,18 @@ def parse_args(): | |
| def main(): | ||
| logging.basicConfig(level=logging.INFO) | ||
| args = parse_args() | ||
| gpu_device_targets = parse_gpu_targets(args.gpu_device_targets) | ||
| python_versions = canonicalize_python_versions(args.python_versions) | ||
mrodden marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| if args.action == "dist_wheels": | ||
| dist_wheels( | ||
| args.rocm_version, | ||
| args.python_versions, | ||
| python_versions, | ||
| args.xla_source_dir, | ||
| args.rocm_build_job, | ||
| args.rocm_build_num, | ||
| compiler=args.compiler, | ||
| gpu_device_targets=gpu_device_targets, | ||
| ) | ||
|
|
||
| elif args.action == "test": | ||
|
|
@@ -361,25 +414,26 @@ def main(): | |
| elif args.action == "dist_docker": | ||
| dist_wheels( | ||
| args.rocm_version, | ||
| args.python_versions, | ||
| python_versions, | ||
| args.xla_source_dir, | ||
| args.rocm_build_job, | ||
| args.rocm_build_num, | ||
| compiler=args.compiler, | ||
| gpu_device_targets=gpu_device_targets, | ||
| ) | ||
| dist_docker( | ||
| args.rocm_version, | ||
| args.base_docker, | ||
| args.python_versions, | ||
| python_versions, | ||
| args.xla_source_dir, | ||
| rocm_build_job=args.rocm_build_job, | ||
| rocm_build_num=args.rocm_build_num, | ||
| tag=args.image_tag, | ||
| dockerfile=args.dockerfile, | ||
| keep_image=args.keep_image, | ||
| gpu_device_targets=gpu_device_targets, | ||
| ) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| main() | ||
|
|
||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I like this. Do we plan on running this in a separate Actions workflow in a future PR?
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah I thought about that as I was adding to these, but its mostly just so I can make sure my parser functions actually do what they are supposed to when developing them. i.e. running locally I think it would be easier to do when we have things move over to the plugin repo. |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,81 @@ | ||
| #!/usr/bin/env python3 | ||
|
|
||
| # Copyright 2024 The JAX Authors. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # https://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| import unittest | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Any reason that we're preferring unittest over pytest? pytest is the newer thing to use
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. pytest extends unittest, so this works with pytest as well. All unit test things in python extend from |
||
|
|
||
| import importlib.util | ||
| import importlib.machinery | ||
|
|
||
|
|
||
| def load_ci_build(): | ||
| spec = importlib.util.spec_from_loader( | ||
| "ci_build", importlib.machinery.SourceFileLoader("ci_build", "./ci_build") | ||
| ) | ||
| mod = importlib.util.module_from_spec(spec) | ||
| spec.loader.exec_module(mod) | ||
| return mod | ||
|
|
||
|
|
||
| ci_build = load_ci_build() | ||
|
|
||
|
|
||
| class CIBuildTestCase(unittest.TestCase): | ||
| def test_parse_gpu_targets_spaces(self): | ||
| targets = ["gfx908", "gfx940", "gfx1201"] | ||
| r = ci_build.parse_gpu_targets(" ".join(targets)) | ||
| self.assertEqual(r, targets) | ||
|
|
||
| def test_parse_gpu_targets_commas(self): | ||
| targets = ["gfx908", "gfx940", "gfx1201"] | ||
| r = ci_build.parse_gpu_targets(",".join(targets)) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I could be convinced otherwise, but it makes the test case easier to interpret if we follow a "one test case, one assertion" rule (unless you need to do multiple checks on the same output, check for None, etc). Could we make the " " and "," cases separate test cases?
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah I cheated here, but this is the "good path" cases vs the negative path cases which are below, and have specific failure conditions. I split the difference between one test case for everything vs one test case per assert and ended up with this compromise lol |
||
| self.assertEqual(r, targets) | ||
|
|
||
| def test_parse_gpu_targets_empty_string(self): | ||
| expected = ci_build.DEFAULT_GPU_DEVICE_TARGETS.split(",") | ||
| r = ci_build.parse_gpu_targets("") | ||
| self.assertEqual(r, expected) | ||
|
|
||
| def test_parse_gpu_targets_whitespace_only(self): | ||
| self.assertRaises(ValueError, ci_build.parse_gpu_targets, " ") | ||
charleshofer marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| def test_parse_gpu_targets_invalid_arch(self): | ||
| targets = ["gfx908", "gfx940", "--oops", "/jax"] | ||
| self.assertRaises(ValueError, ci_build.parse_gpu_targets, " ".join(targets)) | ||
|
|
||
| def test_canonicalize_python_versions(self): | ||
| versions = ["3.10.0", "3.11.0", "3.12.0"] | ||
| exp = ["3.10", "3.11", "3.12"] | ||
| res = ci_build.canonicalize_python_versions(versions) | ||
| self.assertEqual(res, exp) | ||
|
|
||
| def test_canonicalize_python_versions_scalar(self): | ||
| versions = ["3.10.0"] | ||
| exp = ["3.10"] | ||
| res = ci_build.canonicalize_python_versions(versions) | ||
| self.assertEqual(res, exp) | ||
|
|
||
| def test_canonicalize_python_versions_no_revision_part(self): | ||
| versions = ["3.10", "3.11"] | ||
| res = ci_build.canonicalize_python_versions(versions) | ||
| self.assertEqual(res, versions) | ||
|
|
||
| def test_canonicalize_python_versions_string(self): | ||
| versions = "3.10.0" | ||
| self.assertRaises(ValueError, ci_build.canonicalize_python_versions, versions) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| unittest.main() | ||
Uh oh!
There was an error while loading. Please reload this page.