Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 86 additions & 32 deletions build/rocm/ci_build
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@ import logging
import os
import subprocess
import sys
from typing import List


DEFAULT_GPU_DEVICE_TARGETS = (
"gfx906,gfx908,gfx90a,gfx942,gfx1030,gfx1100,gfx1101,gfx1200,gfx1201"
)


LOG = logging.getLogger("ci_build")
Expand All @@ -37,10 +43,13 @@ def image_by_name(name):
return image_id


def create_manylinux_build_image(rocm_version, rocm_build_job, rocm_build_num):
def create_manylinux_build_image(
rocm_version, rocm_build_job, rocm_build_num, gpu_device_targets: List[str]
) -> str:
image_name = "jax-build-manylinux_2_28_x86_64_rocm%s" % rocm_version.replace(
".", ""
)

cmd = [
"docker",
"build",
Expand All @@ -50,6 +59,7 @@ def create_manylinux_build_image(rocm_version, rocm_build_job, rocm_build_num):
"--build-arg=ROCM_BUILD_JOB=%s" % rocm_build_job,
"--build-arg=ROCM_BUILD_NUM=%s" % rocm_build_num,
"--tag=%s" % image_name,
"--build-arg=GPU_DEVICE_TARGETS=%s" % " ".join(gpu_device_targets),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably shouldn't be a part of this PR, but in the future, could we make GPU_DEVICE_TARGETS a comma-separated list everywhere? It confuses the Python arg parser if you use spaces, and that way we don't have to worry about which script needs what

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's intended to be a comma-separated one, but since folks keep doing weird things I made the parser deal with both.

This part is actually setting an environment variable in the Docker build for the wheel image, so it's not used by any Python down the stack from here. (I believe it just gets printf'd to a file in the Dockerfile.)

".",
]

Expand All @@ -65,11 +75,15 @@ def dist_wheels(
rocm_build_job="",
rocm_build_num="",
compiler="gcc",
gpu_device_targets: List[str] = None,
):
if not gpu_device_targets:
gpu_device_targets = DEFAULT_GPU_DEVICE_TARGETS.split(",")

# We want to make sure the wheels we build are manylinux compliant. We'll
# do the build in a container. Build the image for this.
image_name = create_manylinux_build_image(
rocm_version, rocm_build_job, rocm_build_num
rocm_version, rocm_build_job, rocm_build_num, gpu_device_targets
)

if xla_path:
Expand All @@ -91,6 +105,8 @@ def dist_wheels(
pyver_string,
"--compiler",
compiler,
"--gpu-device-targets",
",".join(gpu_device_targets),
]

if xla_path:
Expand Down Expand Up @@ -127,31 +143,6 @@ def dist_wheels(
]
)

# Add command for unit tests
cmd.extend(
[
"&&",
"bazel",
"test",
"-k",
"--jobs=4",
"--test_verbose_timeout_warnings=true",
"--test_output=all",
"--test_summary=detailed",
"--local_test_jobs=1",
"--test_env=JAX_ACCELERATOR_COUNT=%i" % 4,
"--test_env=JAX_SKIP_SLOW_TESTS=0",
"--verbose_failures=true",
"--config=rocm",
"--action_env=ROCM_PATH=/opt/rocm",
"--action_env=TF_ROCM_AMDGPU_TARGETS=%s" % "gfx90a",
"--test_tag_filters=-multiaccelerator",
"--test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform",
"--test_env=JAX_EXCLUDE_TEST_TARGETS=PmapTest.testSizeOverflow",
"//tests:gpu_tests",
]
)

LOG.info("Running: %s", cmd)
_ = subprocess.run(cmd, check=True)

Expand Down Expand Up @@ -196,10 +187,14 @@ def dist_docker(
tag="rocm/jax-dev",
dockerfile=None,
keep_image=True,
gpu_device_targets: List[str] = None,
):
if not dockerfile:
dockerfile = "build/rocm/Dockerfile.ms"

if not gpu_device_targets:
gpu_device_targets = DEFAULT_GPU_DEVICE_TARGETS.split(",")

python_version = python_versions[0]

md = _fetch_jax_metadata(xla_path)
Expand All @@ -212,6 +207,7 @@ def dist_docker(
"--target",
"rt_build",
"--build-arg=ROCM_VERSION=%s" % rocm_version,
"--build-arg=GPU_DEVICE_TARGETS=%s" % " ".join(gpu_device_targets),
"--build-arg=ROCM_BUILD_JOB=%s" % rocm_build_job,
"--build-arg=ROCM_BUILD_NUM=%s" % rocm_build_num,
"--build-arg=BASE_DOCKER=%s" % base_docker,
Expand Down Expand Up @@ -278,6 +274,55 @@ def test(image_name, test_cmd):
subprocess.check_call(cmd)


def canonicalize_python_versions(versions: List[str]):
    """Reduce each Python version string to its ``major.minor`` form.

    Args:
        versions: List of dotted version strings, e.g. ``"3.10"`` or
            ``"3.10.0"``. Any revision component is discarded.

    Returns:
        A new list holding one ``"major.minor"`` string per input entry, in
        the same order.

    Raises:
        ValueError: If ``versions`` is a bare string instead of a list, or if
            an entry does not contain both a major and a minor component.
    """
    if isinstance(versions, str):
        raise ValueError("'versions' must be a list of strings: versions=%r" % versions)

    cleaned = []
    for v in versions:
        # Entries usually come from splitting a comma separated CLI value, so
        # tolerate stray whitespace around each one.
        parts = v.strip().split(".")

        # Fail with a clear message instead of an IndexError when the minor
        # component is missing (e.g. "3") or a component is empty (e.g. "3.").
        if len(parts) < 2 or not parts[0] or not parts[1]:
            raise ValueError(
                "Invalid Python version, expected 'major.minor[.rev]': %r" % v
            )

        cleaned.append("%s.%s" % (parts[0], parts[1]))

    return cleaned


def parse_gpu_targets(targets_string):
    """Parse a delimited string of AMDGPU targets into a list.

    Commas and whitespace are both accepted as separators, and may be mixed.

    Args:
        targets_string: Target list such as ``"gfx908,gfx90a"`` or
            ``"gfx908 gfx90a"``. A falsy value (empty string or ``None``)
            selects ``DEFAULT_GPU_DEVICE_TARGETS``.

    Returns:
        A non-empty list of target names, each starting with ``"gfx"``.

    Raises:
        ValueError: If any target does not start with ``"gfx"``, or if no
            targets remain after separators are removed.
    """
    # Catch the case where targets_string was empty. None should already be
    # caught by argparse, but it doesn't hurt to check twice.
    if not targets_string:
        targets_string = DEFAULT_GPU_DEVICE_TARGETS

    # Normalise commas to spaces and split on whitespace. This drops empty
    # fields and surrounding whitespace, and a lone target such as "gfx90a"
    # becomes a one-element list instead of being iterated character by
    # character.
    targets = targets_string.replace(",", " ").split()

    res = []
    for t in targets:
        if not t.startswith("gfx"):
            raise ValueError("Invalid GPU architecture target: %r" % t)
        res.append(t)

    if not res:
        raise ValueError("GPU_DEVICE_TARGETS cannot be empty")

    return res


def parse_args():
p = argparse.ArgumentParser()
p.add_argument(
Expand All @@ -289,7 +334,7 @@ def parse_args():
p.add_argument(
"--python-versions",
type=lambda x: x.split(","),
default="3.12",
default=["3.12"],
help="Comma separated list of CPython versions to build wheels for",
)

Expand Down Expand Up @@ -321,6 +366,11 @@ def parse_args():
choices=["gcc", "clang"],
help="Compiler backend to use when compiling jax/jaxlib",
)
p.add_argument(
"--gpu-device-targets",
default=DEFAULT_GPU_DEVICE_TARGETS,
help="List of AMDGPU device targets passed from job",
)

subp = p.add_subparsers(dest="action", required=True)

Expand All @@ -344,15 +394,18 @@ def parse_args():
def main():
logging.basicConfig(level=logging.INFO)
args = parse_args()
gpu_device_targets = parse_gpu_targets(args.gpu_device_targets)
python_versions = canonicalize_python_versions(args.python_versions)

if args.action == "dist_wheels":
dist_wheels(
args.rocm_version,
args.python_versions,
python_versions,
args.xla_source_dir,
args.rocm_build_job,
args.rocm_build_num,
compiler=args.compiler,
gpu_device_targets=gpu_device_targets,
)

elif args.action == "test":
Expand All @@ -361,25 +414,26 @@ def main():
elif args.action == "dist_docker":
dist_wheels(
args.rocm_version,
args.python_versions,
python_versions,
args.xla_source_dir,
args.rocm_build_job,
args.rocm_build_num,
compiler=args.compiler,
gpu_device_targets=gpu_device_targets,
)
dist_docker(
args.rocm_version,
args.base_docker,
args.python_versions,
python_versions,
args.xla_source_dir,
rocm_build_job=args.rocm_build_job,
rocm_build_num=args.rocm_build_num,
tag=args.image_tag,
dockerfile=args.dockerfile,
keep_image=args.keep_image,
gpu_device_targets=gpu_device_targets,
)


if __name__ == "__main__":
main()

14 changes: 14 additions & 0 deletions build/rocm/ci_build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ ROCM_BUILD_NUM=""
BASE_DOCKER="ubuntu:22.04"
CUSTOM_INSTALL=""
JAX_USE_CLANG=""
GPU_DEVICE_TARGETS=""
POSITIONAL_ARGS=()

RUNTIME_FLAG=0
Expand Down Expand Up @@ -98,6 +99,18 @@ while [[ $# -gt 0 ]]; do
JAX_USE_CLANG="$2"
shift 2
;;
--gpu_device_targets)
if [[ "$2" == "--custom_install" ]]; then
GPU_DEVICE_TARGETS=""
shift 2
elif [[ -n "$2" ]]; then
GPU_DEVICE_TARGETS="$2"
shift 2
else
GPU_DEVICE_TARGETS=""
shift 1
fi
;;
*)
POSITIONAL_ARGS+=("$1")
shift
Expand Down Expand Up @@ -164,6 +177,7 @@ fi
--rocm-build-job=$ROCM_BUILD_JOB \
--rocm-build-num=$ROCM_BUILD_NUM \
--compiler=$JAX_COMPILER \
--gpu-device-targets="${GPU_DEVICE_TARGETS}" \
dist_docker \
--dockerfile $DOCKERFILE_PATH \
--image-tag $DOCKER_IMG_NAME
Expand Down
81 changes: 81 additions & 0 deletions build/rocm/test_ci_build.py
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like this. Do we plan on running this in a separate Actions workflow in a future PR?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I thought about that as I was adding to these, but it's mostly just so I can make sure my parser functions actually do what they are supposed to when developing them, i.e. when running locally.

I think it would be easier to do when we have things move over to the plugin repo.

Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
#!/usr/bin/env python3

# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any reason that we're preferring unittest over pytest? pytest is the newer thing to use

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pytest extends unittest, so this works with pytest as well.

All unit test things in Python have extended from unittest since around the Python 2.7 timeframe, so I just use that base stuff as it's the most general, and the fancy features are rarely needed for my test cases.


import importlib.util
import importlib.machinery


def load_ci_build():
    """Import the ``ci_build`` script as a module and return it.

    The script has no ``.py`` extension, so it is loaded through an explicit
    ``SourceFileLoader`` rather than a normal ``import`` statement.

    Returns:
        The executed ``ci_build`` module object.
    """
    import os

    # Resolve the script next to this test file instead of relative to the
    # current working directory, so the tests can be run from anywhere.
    here = os.path.dirname(os.path.abspath(__file__))
    script_path = os.path.join(here, "ci_build")

    loader = importlib.machinery.SourceFileLoader("ci_build", script_path)
    spec = importlib.util.spec_from_loader("ci_build", loader)
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)
    return mod


# Load the script under test once, at import time; the test cases below all
# go through this module object.
ci_build = load_ci_build()


class CIBuildTestCase(unittest.TestCase):
def test_parse_gpu_targets_spaces(self):
    """A space separated target string is split into its targets."""
    expected = ["gfx908", "gfx940", "gfx1201"]
    joined = " ".join(expected)
    self.assertEqual(ci_build.parse_gpu_targets(joined), expected)

def test_parse_gpu_targets_commas(self):
targets = ["gfx908", "gfx940", "gfx1201"]
r = ci_build.parse_gpu_targets(",".join(targets))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I could be convinced otherwise, but it makes the test case easier to interpret if we follow a "one test case, one assertion" rule (unless you need to do multiple checks on the same output, check for None, etc). Could we make the " " and "," cases separate test cases?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I cheated here, but this is the "good path" cases vs the negative path cases which are below, and have specific failure conditions.

I split the difference between one test case for everything vs one test case per assert and ended up with this compromise lol

self.assertEqual(r, targets)

def test_parse_gpu_targets_empty_string(self):
    """An empty string yields the default target list."""
    defaults = ci_build.DEFAULT_GPU_DEVICE_TARGETS.split(",")
    parsed = ci_build.parse_gpu_targets("")
    self.assertEqual(parsed, defaults)

def test_parse_gpu_targets_whitespace_only(self):
    """A string holding only whitespace is rejected."""
    with self.assertRaises(ValueError):
        ci_build.parse_gpu_targets(" ")

def test_parse_gpu_targets_invalid_arch(self):
    """Entries that are not gfx architecture names are rejected."""
    joined = " ".join(["gfx908", "gfx940", "--oops", "/jax"])
    with self.assertRaises(ValueError):
        ci_build.parse_gpu_targets(joined)

def test_canonicalize_python_versions(self):
    """Revision components are dropped from every entry."""
    canonical = ci_build.canonicalize_python_versions(
        ["3.10.0", "3.11.0", "3.12.0"]
    )
    self.assertEqual(canonical, ["3.10", "3.11", "3.12"])

def test_canonicalize_python_versions_scalar(self):
    """A single-element list is canonicalized like any other list."""
    canonical = ci_build.canonicalize_python_versions(["3.10.0"])
    self.assertEqual(canonical, ["3.10"])

def test_canonicalize_python_versions_no_revision_part(self):
    """Entries that are already major.minor come back unchanged."""
    already_canonical = ["3.10", "3.11"]
    self.assertEqual(
        ci_build.canonicalize_python_versions(already_canonical),
        already_canonical,
    )

def test_canonicalize_python_versions_string(self):
    """A bare string, rather than a list of strings, is rejected."""
    with self.assertRaises(ValueError):
        ci_build.canonicalize_python_versions("3.10.0")


if __name__ == "__main__":
unittest.main()
Loading
Loading