@@ -362,7 +362,7 @@ def _get_full_wheel_name(
362362 free_threaded_suffix = "t" if py_freethreaded .lower () == "yes" else "" ,
363363 )
364364
365- def _get_source_distribution_name (package_name , wheel_version ):
365+ def _get_source_package_name (package_name , wheel_version ):
366366 return "{package_name}-{wheel_version}.tar.gz" .format (
367367 package_name = package_name ,
368368 wheel_version = wheel_version ,
@@ -394,37 +394,47 @@ def _jax_wheel_impl(ctx):
394394 no_abi = ctx .attr .no_abi
395395 platform_independent = ctx .attr .platform_independent
396396 build_wheel_only = ctx .attr .build_wheel_only
397+ build_source_package_only = ctx .attr .build_source_package_only
397398 editable = ctx .attr .editable
398399 platform_name = ctx .attr .platform_name
400+
401+ output_dir_path = ""
402+ outputs = []
399403 if editable :
400404 output_dir = ctx .actions .declare_directory (output_path + "/" + ctx .attr .wheel_name )
401- wheel_dir = output_dir .path
405+ output_dir_path = output_dir .path
402406 outputs = [output_dir ]
403407 args .add ("--editable" )
404408 else :
405- wheel_name = _get_full_wheel_name (
406- package_name = ctx .attr .wheel_name ,
407- no_abi = no_abi ,
408- platform_independent = platform_independent ,
409- platform_name = platform_name ,
410- cpu_name = cpu ,
411- wheel_version = full_wheel_version ,
412- py_freethreaded = py_freethreaded ,
413- )
414- wheel_file = ctx .actions .declare_file (output_path +
415- "/" + wheel_name )
416- wheel_dir = wheel_file .path [:wheel_file .path .rfind ("/" )]
417- outputs = [wheel_file ]
418- if not build_wheel_only :
419- source_distribution_name = _get_source_distribution_name (
409+ if build_wheel_only :
410+ wheel_name = _get_full_wheel_name (
420411 package_name = ctx .attr .wheel_name ,
412+ no_abi = no_abi ,
413+ platform_independent = platform_independent ,
414+ platform_name = platform_name ,
415+ cpu_name = cpu ,
421416 wheel_version = full_wheel_version ,
417+ py_freethreaded = py_freethreaded ,
422418 )
423- source_distribution_file = ctx .actions .declare_file (output_path +
424- "/" + source_distribution_name )
425- outputs .append (source_distribution_file )
426-
427- args .add ("--output_path" , wheel_dir ) # required argument
419+ wheel_file = ctx .actions .declare_file (output_path +
420+ "/" + wheel_name )
421+ output_dir_path = wheel_file .path [:wheel_file .path .rfind ("/" )]
422+ outputs = [wheel_file ]
423+ if ctx .attr .wheel_name == "jax" :
424+ args .add ("--build-wheel-only" , "True" )
425+ if build_source_package_only :
426+ source_package_name = _get_source_package_name (
427+ package_name = ctx .attr .wheel_name ,
428+ wheel_version = full_wheel_version ,
429+ )
430+ source_package_file = ctx .actions .declare_file (output_path +
431+ "/" + source_package_name )
432+ output_dir_path = source_package_file .path [:source_package_file .path .rfind ("/" )]
433+ outputs = [source_package_file ]
434+ if ctx .attr .wheel_name == "jax" :
435+ args .add ("--build-source-package-only" , "True" )
436+
437+ args .add ("--output_path" , output_dir_path ) # required argument
428438 if not platform_independent :
429439 args .add ("--cpu" , cpu )
430440 args .add ("--jaxlib_git_hash" , git_hash ) # required argument
@@ -472,16 +482,17 @@ _jax_wheel = rule(
472482 "wheel_name" : attr .string (mandatory = True ),
473483 "no_abi" : attr .bool (default = False ),
474484 "platform_independent" : attr .bool (default = False ),
475- "build_wheel_only" : attr .bool (default = True ),
485+ "build_wheel_only" : attr .bool (mandatory = True , default = True ),
486+ "build_source_package_only" : attr .bool (mandatory = True , default = False ),
476487 "editable" : attr .bool (default = False ),
477- "cpu" : attr .string (mandatory = True ),
478- "platform_name" : attr .string (mandatory = True ),
488+ "cpu" : attr .string (),
489+ "platform_name" : attr .string (),
479490 "git_hash" : attr .label (default = Label ("//jaxlib/tools:jaxlib_git_hash" )),
480491 "source_files" : attr .label_list (allow_files = True ),
481492 "output_path" : attr .label (default = Label ("//jaxlib/tools:output_path" )),
482493 "enable_cuda" : attr .bool (default = False ),
483494 # A cuda/rocm version is required for gpu wheels; for cpu wheels, it can be an empty string.
484- "platform_version" : attr .string (mandatory = True , default = "" ),
495+ "platform_version" : attr .string (),
485496 "skip_gpu_kernels" : attr .bool (default = False ),
486497 "enable_rocm" : attr .bool (default = False ),
487498 "include_cuda_libs" : attr .label (default = Label ("@local_config_cuda//cuda:include_cuda_libs" )),
@@ -498,7 +509,6 @@ def jax_wheel(
498509 wheel_name ,
499510 no_abi = False ,
500511 platform_independent = False ,
501- build_wheel_only = True ,
502512 editable = False ,
503513 enable_cuda = False ,
504514 enable_rocm = False ,
@@ -509,11 +519,10 @@ def jax_wheel(
509519 Common artifact attributes are grouped within a single macro.
510520
511521 Args:
512- name: the name of the wheel
522+ name: the target name
513523 wheel_binary: the binary to use to build the wheel
514524 wheel_name: the name of the wheel
515525 no_abi: whether to build a wheel without ABI
516- build_wheel_only: whether to build a wheel without source distribution
517526 editable: whether to build an editable wheel
518527 platform_independent: whether to build a wheel without platform tag
519528 enable_cuda: whether to build a cuda wheel
@@ -522,15 +531,16 @@ def jax_wheel(
522531 source_files: the source files to include in the wheel
523532
524533 Returns:
525- A directory containing the wheel
534+ A wheel file or a wheel directory.
526535 """
527536 _jax_wheel (
528537 name = name ,
529538 wheel_binary = wheel_binary ,
530539 wheel_name = wheel_name ,
531540 no_abi = no_abi ,
532541 platform_independent = platform_independent ,
533- build_wheel_only = build_wheel_only ,
542+ build_wheel_only = True ,
543+ build_source_package_only = False ,
534544 editable = editable ,
535545 enable_cuda = enable_cuda ,
536546 enable_rocm = enable_rocm ,
@@ -554,6 +564,34 @@ def jax_wheel(
554564 source_files = source_files ,
555565 )
556566
567+ def jax_source_package (
568+ name ,
569+ source_package_binary ,
570+ source_package_name ,
571+ source_files = []):
572+ """Create jax source package.
573+
574+ Common artifact attributes are grouped within a single macro.
575+
576+ Args:
577+ name: the target name
578+ source_package_binary: the binary to use to build the package
579+ source_package_name: the name of the source package
580+ source_files: the source files to include in the package
581+
582+ Returns:
583+ A jax source package file.
584+ """
585+ _jax_wheel (
586+ name = name ,
587+ wheel_binary = source_package_binary ,
588+ wheel_name = source_package_name ,
589+ build_source_package_only = True ,
590+ build_wheel_only = False ,
591+ platform_independent = True ,
592+ source_files = source_files ,
593+ )
594+
557595jax_test_file_visibility = []
558596
559597jax_export_file_visibility = []
0 commit comments