@@ -28,7 +28,9 @@ import sys
2828from typing import List
2929
3030
31- DEFAULT_GPU_DEVICE_TARGETS = "gfx906,gfx908,gfx90a,gfx942,gfx1030,gfx1100,gfx1101,gfx1200,gfx1201"
31+ DEFAULT_GPU_DEVICE_TARGETS = (
32+ "gfx906,gfx908,gfx90a,gfx942,gfx1030,gfx1100,gfx1101,gfx1200,gfx1201"
33+ )
3234
3335
3436LOG = logging .getLogger ("ci_build" )
@@ -41,7 +43,9 @@ def image_by_name(name):
4143 return image_id
4244
4345
44- def create_manylinux_build_image (rocm_version , rocm_build_job , rocm_build_num , gpu_device_targets : List [str ]) -> str :
46+ def create_manylinux_build_image (
47+ rocm_version , rocm_build_job , rocm_build_num , gpu_device_targets : List [str ]
48+ ) -> str :
4549 image_name = "jax-build-manylinux_2_28_x86_64_rocm%s" % rocm_version .replace (
4650 "." , ""
4751 )
@@ -71,7 +75,7 @@ def dist_wheels(
7175 rocm_build_job = "" ,
7276 rocm_build_num = "" ,
7377 compiler = "gcc" ,
74- gpu_device_targets : List [str ] = None ,
78+ gpu_device_targets : List [str ] = None ,
7579):
7680 if not gpu_device_targets :
7781 gpu_device_targets = DEFAULT_GPU_DEVICE_TARGETS .split ("," )
@@ -183,7 +187,7 @@ def dist_docker(
183187 tag = "rocm/jax-dev" ,
184188 dockerfile = None ,
185189 keep_image = True ,
186- gpu_device_targets : List [str ] = None ,
190+ gpu_device_targets : List [str ] = None ,
187191):
188192 if not dockerfile :
189193 dockerfile = "build/rocm/Dockerfile.ms"
@@ -270,6 +274,24 @@ def test(image_name, test_cmd):
270274 subprocess .check_call (cmd )
271275
272276
277+ def canonicalize_python_versions (versions : List [str ]):
278+ if isinstance (versions , str ):
279+ raise ValueError ("'versions' must be a list of strings: versions=%r" % versions )
280+
281+ cleaned = []
282+ for v in versions :
283+ tup = v .split ("." )
284+ major = tup [0 ]
285+ minor = tup [1 ]
286+ rev = None
287+ if tup [2 ]:
288+ rev = tup [2 ]
289+
290+ cleaned .append ("%s.%s" % (major , minor ))
291+
292+ return cleaned
293+
294+
273295def parse_gpu_targets (targets_string ):
274296 # catch case where targets_string was empty.
275297 # None should already be caught by argparse, but
@@ -373,11 +395,12 @@ def main():
373395 logging .basicConfig (level = logging .INFO )
374396 args = parse_args ()
375397 gpu_device_targets = parse_gpu_targets (args .gpu_device_targets )
398+ python_versions = canonicalize_python_versions (args .python_versions )
376399
377400 if args .action == "dist_wheels" :
378401 dist_wheels (
379402 args .rocm_version ,
380- args . python_versions ,
403+ python_versions ,
381404 args .xla_source_dir ,
382405 args .rocm_build_job ,
383406 args .rocm_build_num ,
@@ -391,7 +414,7 @@ def main():
391414 elif args .action == "dist_docker" :
392415 dist_wheels (
393416 args .rocm_version ,
394- args . python_versions ,
417+ python_versions ,
395418 args .xla_source_dir ,
396419 gpu_device_targets ,
397420 args .rocm_build_job ,
@@ -402,7 +425,7 @@ def main():
402425 dist_docker (
403426 args .rocm_version ,
404427 args .base_docker ,
405- args . python_versions ,
428+ python_versions ,
406429 args .xla_source_dir ,
407430 rocm_build_job = args .rocm_build_job ,
408431 rocm_build_num = args .rocm_build_num ,
@@ -415,4 +438,3 @@ def main():
415438
416439if __name__ == "__main__" :
417440 main ()
418-
0 commit comments