Skip to content

Commit 08b7c08

Browse files
committed
Fix missing defaults for GPU target arch'es (#293)
Move defaults for these up to ci_build python We also add some error catching logic in the python script to clean up weirdness coming from the build scripts. Also clean up some formatting and typos (cherry picked from commit ef1d561)
1 parent ec4b8ee commit 08b7c08

File tree

3 files changed

+132
-28
lines changed

3 files changed

+132
-28
lines changed

build/rocm/ci_build

Lines changed: 60 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@ import logging
2525
import os
2626
import subprocess
2727
import sys
28+
from typing import List
29+
30+
31+
DEFAULT_GPU_DEVICE_TARGETS = "gfx906,gfx908,gfx90a,gfx942,gfx1030,gfx1100,gfx1101,gfx1200,gfx1201"
2832

2933

3034
LOG = logging.getLogger("ci_build")
@@ -37,10 +41,11 @@ def image_by_name(name):
3741
return image_id
3842

3943

40-
def create_manylinux_build_image(rocm_version, rocm_build_job, rocm_build_num):
44+
def create_manylinux_build_image(rocm_version, rocm_build_job, rocm_build_num, gpu_device_targets: List[str]) -> str:
4145
image_name = "jax-build-manylinux_2_28_x86_64_rocm%s" % rocm_version.replace(
4246
".", ""
4347
)
48+
4449
cmd = [
4550
"docker",
4651
"build",
@@ -50,6 +55,7 @@ def create_manylinux_build_image(rocm_version, rocm_build_job, rocm_build_num):
5055
"--build-arg=ROCM_BUILD_JOB=%s" % rocm_build_job,
5156
"--build-arg=ROCM_BUILD_NUM=%s" % rocm_build_num,
5257
"--tag=%s" % image_name,
58+
"--build-arg=GPU_DEVICE_TARGETS=%s" % " ".join(gpu_device_targets),
5359
".",
5460
]
5561

@@ -65,11 +71,15 @@ def dist_wheels(
6571
rocm_build_job="",
6672
rocm_build_num="",
6773
compiler="gcc",
74+
gpu_device_targets : List[str] = None,
6875
):
76+
if not gpu_device_targets:
77+
gpu_device_targets = DEFAULT_GPU_DEVICE_TARGETS.split(",")
78+
6979
# We want to make sure the wheels we build are manylinux compliant. We'll
7080
# do the build in a container. Build the image for this.
7181
image_name = create_manylinux_build_image(
72-
rocm_version, rocm_build_job, rocm_build_num
82+
rocm_version, rocm_build_job, rocm_build_num, gpu_device_targets
7383
)
7484

7585
if xla_path:
@@ -91,6 +101,8 @@ def dist_wheels(
91101
pyver_string,
92102
"--compiler",
93103
compiler,
104+
"--gpu-device-targets",
105+
",".join(gpu_device_targets),
94106
]
95107

96108
if xla_path:
@@ -127,31 +139,6 @@ def dist_wheels(
127139
]
128140
)
129141

130-
# Add command for unit tests
131-
cmd.extend(
132-
[
133-
"&&",
134-
"bazel",
135-
"test",
136-
"-k",
137-
"--jobs=4",
138-
"--test_verbose_timeout_warnings=true",
139-
"--test_output=all",
140-
"--test_summary=detailed",
141-
"--local_test_jobs=1",
142-
"--test_env=JAX_ACCELERATOR_COUNT=%i" % 4,
143-
"--test_env=JAX_SKIP_SLOW_TESTS=0",
144-
"--verbose_failures=true",
145-
"--config=rocm",
146-
"--action_env=ROCM_PATH=/opt/rocm",
147-
"--action_env=TF_ROCM_AMDGPU_TARGETS=%s" % "gfx90a",
148-
"--test_tag_filters=-multiaccelerator",
149-
"--test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform",
150-
"--test_env=JAX_EXCLUDE_TEST_TARGETS=PmapTest.testSizeOverflow",
151-
"//tests:gpu_tests",
152-
]
153-
)
154-
155142
LOG.info("Running: %s", cmd)
156143
_ = subprocess.run(cmd, check=True)
157144

@@ -196,10 +183,14 @@ def dist_docker(
196183
tag="rocm/jax-dev",
197184
dockerfile=None,
198185
keep_image=True,
186+
gpu_device_targets : List[str] = None,
199187
):
200188
if not dockerfile:
201189
dockerfile = "build/rocm/Dockerfile.ms"
202190

191+
if not gpu_device_targets:
192+
gpu_device_targets = DEFAULT_GPU_DEVICE_TARGETS.split(",")
193+
203194
python_version = python_versions[0]
204195

205196
md = _fetch_jax_metadata(xla_path)
@@ -212,6 +203,7 @@ def dist_docker(
212203
"--target",
213204
"rt_build",
214205
"--build-arg=ROCM_VERSION=%s" % rocm_version,
206+
"--build-arg=GPU_DEVICE_TARGETS=%s" % " ".join(gpu_device_targets),
215207
"--build-arg=ROCM_BUILD_JOB=%s" % rocm_build_job,
216208
"--build-arg=ROCM_BUILD_NUM=%s" % rocm_build_num,
217209
"--build-arg=BASE_DOCKER=%s" % base_docker,
@@ -278,6 +270,37 @@ def test(image_name, test_cmd):
278270
subprocess.check_call(cmd)
279271

280272

273+
def parse_gpu_targets(targets_string):
274+
# catch case where targets_string was empty.
275+
# None should already be caught by argparse, but
276+
# it doesn't hurt to check twice
277+
if not targets_string:
278+
targets_string = DEFAULT_GPU_DEVICE_TARGETS
279+
280+
if "," in targets_string:
281+
targets = targets_string.split(",")
282+
elif " " in targets_string:
283+
targets = targets_string.split(" ")
284+
else:
285+
targets = targets_string
286+
287+
res = []
288+
# cleanup and validation
289+
for t in targets:
290+
if not t:
291+
continue
292+
293+
if not t.startswith("gfx"):
294+
raise ValueError("Invalid GPU architecture target: %r" % t)
295+
296+
res.append(t.strip())
297+
298+
if not res:
299+
raise ValueError("GPU_DEVICE_TARGETS cannot be empty")
300+
301+
return res
302+
303+
281304
def parse_args():
282305
p = argparse.ArgumentParser()
283306
p.add_argument(
@@ -289,7 +312,7 @@ def parse_args():
289312
p.add_argument(
290313
"--python-versions",
291314
type=lambda x: x.split(","),
292-
default="3.12",
315+
default=["3.12"],
293316
help="Comma separated list of CPython versions to build wheels for",
294317
)
295318

@@ -321,6 +344,11 @@ def parse_args():
321344
choices=["gcc", "clang"],
322345
help="Compiler backend to use when compiling jax/jaxlib",
323346
)
347+
p.add_argument(
348+
"--gpu-device-targets",
349+
default=DEFAULT_GPU_DEVICE_TARGETS,
350+
help="List of AMDGPU device targets passed from job",
351+
)
324352

325353
subp = p.add_subparsers(dest="action", required=True)
326354

@@ -344,6 +372,7 @@ def parse_args():
344372
def main():
345373
logging.basicConfig(level=logging.INFO)
346374
args = parse_args()
375+
gpu_device_targets = parse_gpu_targets(args.gpu_device_targets)
347376

348377
if args.action == "dist_wheels":
349378
dist_wheels(
@@ -353,6 +382,7 @@ def main():
353382
args.rocm_build_job,
354383
args.rocm_build_num,
355384
compiler=args.compiler,
385+
gpu_device_targets=gpu_device_targets,
356386
)
357387

358388
elif args.action == "test":
@@ -366,6 +396,7 @@ def main():
366396
args.rocm_build_job,
367397
args.rocm_build_num,
368398
compiler=args.compiler,
399+
gpu_device_targets=gpu_device_targets,
369400
)
370401
dist_docker(
371402
args.rocm_version,
@@ -377,6 +408,7 @@ def main():
377408
tag=args.image_tag,
378409
dockerfile=args.dockerfile,
379410
keep_image=args.keep_image,
411+
gpu_device_targets=gpu_device_targets,
380412
)
381413

382414

build/rocm/ci_build.sh

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ ROCM_BUILD_NUM=""
5151
BASE_DOCKER="ubuntu:22.04"
5252
CUSTOM_INSTALL=""
5353
JAX_USE_CLANG=""
54+
GPU_DEVICE_TARGETS=""
5455
POSITIONAL_ARGS=()
5556

5657
RUNTIME_FLAG=0
@@ -98,6 +99,18 @@ while [[ $# -gt 0 ]]; do
9899
JAX_USE_CLANG="$2"
99100
shift 2
100101
;;
102+
--gpu_device_targets)
103+
if [[ "$2" == "--custom_install" ]]; then
104+
GPU_DEVICE_TARGETS=""
105+
shift 2
106+
elif [[ -n "$2" ]]; then
107+
GPU_DEVICE_TARGETS="$2"
108+
shift 2
109+
else
110+
GPU_DEVICE_TARGETS=""
111+
shift 1
112+
fi
113+
;;
101114
*)
102115
POSITIONAL_ARGS+=("$1")
103116
shift
@@ -164,6 +177,7 @@ fi
164177
--rocm-build-job=$ROCM_BUILD_JOB \
165178
--rocm-build-num=$ROCM_BUILD_NUM \
166179
--compiler=$JAX_COMPILER \
180+
--gpu-device-targets="${GPU_DEVICE_TARGETS}" \
167181
dist_docker \
168182
--dockerfile $DOCKERFILE_PATH \
169183
--image-tag $DOCKER_IMG_NAME

build/rocm/test_ci_build.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
#!/usr/bin/env python3
2+
3+
# Copyright 2024 The JAX Authors.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# https://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
import unittest
18+
19+
import importlib.util
20+
import importlib.machinery
21+
22+
23+
def load_ci_build():
24+
spec = importlib.util.spec_from_loader(
25+
"ci_build", importlib.machinery.SourceFileLoader("ci_build", "./ci_build")
26+
)
27+
mod = importlib.util.module_from_spec(spec)
28+
spec.loader.exec_module(mod)
29+
return mod
30+
31+
32+
ci_build = load_ci_build()
33+
34+
35+
class CIBuildTestCase(unittest.TestCase):
36+
def test_parse_gpu_targets(self):
37+
targets = ["gfx908", "gfx940", "gfx1201"]
38+
39+
r = ci_build.parse_gpu_targets(" ".join(targets))
40+
self.assertEqual(r, targets)
41+
42+
r = ci_build.parse_gpu_targets(",".join(targets))
43+
self.assertEqual(r, targets)
44+
45+
def test_parse_gpu_targets_empty_string(self):
46+
expected = ci_build.DEFAULT_GPU_DEVICE_TARGETS.split(",")
47+
r = ci_build.parse_gpu_targets("")
48+
self.assertEqual(r, expected)
49+
50+
self.assertRaises(ValueError, ci_build.parse_gpu_targets, " ")
51+
52+
def test_parse_gpu_targets_invalid_arch(self):
53+
targets = ["gfx908", "gfx940", "--oops", "/jax"]
54+
self.assertRaises(ValueError, ci_build.parse_gpu_targets, " ".join(targets))
55+
56+
57+
if __name__ == "__main__":
58+
unittest.main()

0 commit comments

Comments
 (0)