@@ -25,6 +25,10 @@ import logging
2525import os
2626import subprocess
2727import sys
28+ from typing import List
29+
30+
31+ DEFAULT_GPU_DEVICE_TARGETS = "gfx906,gfx908,gfx90a,gfx942,gfx1030,gfx1100,gfx1101,gfx1200,gfx1201"
2832
2933
3034LOG = logging .getLogger ("ci_build" )
@@ -37,10 +41,11 @@ def image_by_name(name):
3741 return image_id
3842
3943
40- def create_manylinux_build_image (rocm_version , rocm_build_job , rocm_build_num ) :
44+ def create_manylinux_build_image (rocm_version , rocm_build_job , rocm_build_num , gpu_device_targets : List [ str ]) -> str :
4145 image_name = "jax-build-manylinux_2_28_x86_64_rocm%s" % rocm_version .replace (
4246 "." , ""
4347 )
48+
4449 cmd = [
4550 "docker" ,
4651 "build" ,
@@ -50,6 +55,7 @@ def create_manylinux_build_image(rocm_version, rocm_build_job, rocm_build_num):
5055 "--build-arg=ROCM_BUILD_JOB=%s" % rocm_build_job ,
5156 "--build-arg=ROCM_BUILD_NUM=%s" % rocm_build_num ,
5257 "--tag=%s" % image_name ,
58+ "--build-arg=GPU_DEVICE_TARGETS=%s" % " " .join (gpu_device_targets ),
5359 "." ,
5460 ]
5561
@@ -65,11 +71,15 @@ def dist_wheels(
6571 rocm_build_job = "" ,
6672 rocm_build_num = "" ,
6773 compiler = "gcc" ,
74+ gpu_device_targets : List [str ] = None ,
6875):
76+ if not gpu_device_targets :
77+ gpu_device_targets = DEFAULT_GPU_DEVICE_TARGETS .split ("," )
78+
6979 # We want to make sure the wheels we build are manylinux compliant. We'll
7080 # do the build in a container. Build the image for this.
7181 image_name = create_manylinux_build_image (
72- rocm_version , rocm_build_job , rocm_build_num
82+ rocm_version , rocm_build_job , rocm_build_num , gpu_device_targets
7383 )
7484
7585 if xla_path :
@@ -91,6 +101,8 @@ def dist_wheels(
91101 pyver_string ,
92102 "--compiler" ,
93103 compiler ,
104+ "--gpu-device-targets" ,
105+ "," .join (gpu_device_targets ),
94106 ]
95107
96108 if xla_path :
@@ -127,31 +139,6 @@ def dist_wheels(
127139 ]
128140 )
129141
130- # Add command for unit tests
131- cmd .extend (
132- [
133- "&&" ,
134- "bazel" ,
135- "test" ,
136- "-k" ,
137- "--jobs=4" ,
138- "--test_verbose_timeout_warnings=true" ,
139- "--test_output=all" ,
140- "--test_summary=detailed" ,
141- "--local_test_jobs=1" ,
142- "--test_env=JAX_ACCELERATOR_COUNT=%i" % 4 ,
143- "--test_env=JAX_SKIP_SLOW_TESTS=0" ,
144- "--verbose_failures=true" ,
145- "--config=rocm" ,
146- "--action_env=ROCM_PATH=/opt/rocm" ,
147- "--action_env=TF_ROCM_AMDGPU_TARGETS=%s" % "gfx90a" ,
148- "--test_tag_filters=-multiaccelerator" ,
149- "--test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform" ,
150- "--test_env=JAX_EXCLUDE_TEST_TARGETS=PmapTest.testSizeOverflow" ,
151- "//tests:gpu_tests" ,
152- ]
153- )
154-
155142 LOG .info ("Running: %s" , cmd )
156143 _ = subprocess .run (cmd , check = True )
157144
@@ -196,10 +183,14 @@ def dist_docker(
196183 tag = "rocm/jax-dev" ,
197184 dockerfile = None ,
198185 keep_image = True ,
186+ gpu_device_targets : List [str ] = None ,
199187):
200188 if not dockerfile :
201189 dockerfile = "build/rocm/Dockerfile.ms"
202190
191+ if not gpu_device_targets :
192+ gpu_device_targets = DEFAULT_GPU_DEVICE_TARGETS .split ("," )
193+
203194 python_version = python_versions [0 ]
204195
205196 md = _fetch_jax_metadata (xla_path )
@@ -212,6 +203,7 @@ def dist_docker(
212203 "--target" ,
213204 "rt_build" ,
214205 "--build-arg=ROCM_VERSION=%s" % rocm_version ,
206+ "--build-arg=GPU_DEVICE_TARGETS=%s" % " " .join (gpu_device_targets ),
215207 "--build-arg=ROCM_BUILD_JOB=%s" % rocm_build_job ,
216208 "--build-arg=ROCM_BUILD_NUM=%s" % rocm_build_num ,
217209 "--build-arg=BASE_DOCKER=%s" % base_docker ,
@@ -278,6 +270,37 @@ def test(image_name, test_cmd):
278270 subprocess .check_call (cmd )
279271
280272
def parse_gpu_targets(targets_string):
    """Parse a string of AMDGPU device targets into a validated list.

    Targets may be separated by commas, whitespace, or a mix of both
    (e.g. "gfx906,gfx908", "gfx906 gfx908" or "gfx906, gfx908"). A single
    target with no separator (e.g. "gfx90a") is also accepted.

    Args:
      targets_string: Separator-delimited target names. A falsy value
        (empty string or None) selects DEFAULT_GPU_DEVICE_TARGETS.

    Returns:
      A non-empty list of target names, each starting with "gfx", in the
      order they were given.

    Raises:
      ValueError: If any target does not start with "gfx", or if no
        targets remain after discarding empty entries.
    """
    # Catch the case where targets_string was empty. None should already be
    # handled by the argparse default, but it doesn't hurt to check twice.
    if not targets_string:
        targets_string = DEFAULT_GPU_DEVICE_TARGETS

    # Normalise commas to spaces and split on whitespace. This always yields
    # a list (never a bare string that would be iterated per character),
    # drops empty entries, and removes padding around each name before the
    # prefix check below.
    candidates = targets_string.replace(",", " ").split()

    res = []
    for t in candidates:
        if not t.startswith("gfx"):
            raise ValueError("Invalid GPU architecture target: %r" % t)
        res.append(t)

    if not res:
        raise ValueError("GPU_DEVICE_TARGETS cannot be empty")

    return res
302+
303+
281304def parse_args ():
282305 p = argparse .ArgumentParser ()
283306 p .add_argument (
@@ -289,7 +312,7 @@ def parse_args():
289312 p .add_argument (
290313 "--python-versions" ,
291314 type = lambda x : x .split ("," ),
292- default = "3.12" ,
315+ default = [ "3.12" ] ,
293316 help = "Comma separated list of CPython versions to build wheels for" ,
294317 )
295318
@@ -321,6 +344,11 @@ def parse_args():
321344 choices = ["gcc" , "clang" ],
322345 help = "Compiler backend to use when compiling jax/jaxlib" ,
323346 )
347+ p .add_argument (
348+ "--gpu-device-targets" ,
349+ default = DEFAULT_GPU_DEVICE_TARGETS ,
350+ help = "List of AMDGPU device targets passed from job" ,
351+ )
324352
325353 subp = p .add_subparsers (dest = "action" , required = True )
326354
@@ -344,6 +372,7 @@ def parse_args():
344372def main ():
345373 logging .basicConfig (level = logging .INFO )
346374 args = parse_args ()
375+ gpu_device_targets = parse_gpu_targets (args .gpu_device_targets )
347376
348377 if args .action == "dist_wheels" :
349378 dist_wheels (
@@ -353,6 +382,7 @@ def main():
353382 args .rocm_build_job ,
354383 args .rocm_build_num ,
355384 compiler = args .compiler ,
385+ gpu_device_targets = gpu_device_targets ,
356386 )
357387
358388 elif args .action == "test" :
@@ -363,9 +393,11 @@ def main():
363393 args .rocm_version ,
364394 args .python_versions ,
365395 args .xla_source_dir ,
396+ gpu_device_targets ,
366397 args .rocm_build_job ,
367398 args .rocm_build_num ,
368399 compiler = args .compiler ,
400+ gpu_device_targets = gpu_device_targets ,
369401 )
370402 dist_docker (
371403 args .rocm_version ,
@@ -377,6 +409,7 @@ def main():
377409 tag = args .image_tag ,
378410 dockerfile = args .dockerfile ,
379411 keep_image = args .keep_image ,
412+ gpu_device_targets = gpu_device_targets ,
380413 )
381414
382415
0 commit comments