Skip to content

Commit 16dc0ad

Browse files
Add jax_source_package macros and target to generate a source package .tar.gz.
Refactor `jax_wheel` macros, so it outputs a `.whl` file only. When the macros returns one output object only, it allows all downstream dependencies consume it easily without the need to filter the macros outputs. The previous implementation design (when `jax_wheel` returned `.tar.gz` and `.whl` files) required one of two options: either create a new target that produces `.whl` only, or to implement filename filtering in the downstream rules. With the new implementation we can just depend on `//:jax_wheel` target that produces the `.whl`. PiperOrigin-RevId: 738547491
1 parent 29e90a3 commit 16dc0ad

File tree

5 files changed

+110
-36
lines changed

5 files changed

+110
-36
lines changed

BUILD.bazel

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
load("@xla//third_party/py:python_wheel.bzl", "collect_data_files", "transitive_py_deps")
1616
load(
1717
"//jaxlib:jax.bzl",
18+
"jax_source_package",
1819
"jax_wheel",
1920
)
2021

@@ -67,7 +68,6 @@ py_binary(
6768

6869
jax_wheel(
6970
name = "jax_wheel",
70-
build_wheel_only = False,
7171
platform_independent = True,
7272
source_files = [
7373
":transitive_py_data",
@@ -82,3 +82,19 @@ jax_wheel(
8282
wheel_binary = ":build_wheel",
8383
wheel_name = "jax",
8484
)
85+
86+
jax_source_package(
87+
name = "jax_source_package",
88+
source_files = [
89+
":transitive_py_data",
90+
":transitive_py_deps",
91+
"//jax:py.typed",
92+
"AUTHORS",
93+
"LICENSE",
94+
"README.md",
95+
"pyproject.toml",
96+
"setup.py",
97+
],
98+
source_package_binary = ":build_wheel",
99+
source_package_name = "jax",
100+
)

build/build.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@
6868
# rule as the default.
6969
WHEEL_BUILD_TARGET_DICT_NEW = {
7070
"jax": "//:jax_wheel",
71+
"jax_source_package": "//:jax_source_package",
7172
"jaxlib": "//jaxlib/tools:jaxlib_wheel",
7273
"jax-cuda-plugin": "//jaxlib/tools:jax_cuda_plugin_wheel",
7374
"jax-cuda-pjrt": "//jaxlib/tools:jax_cuda_pjrt_wheel",
@@ -661,6 +662,8 @@ async def main():
661662
# Append the build target to the Bazel command.
662663
build_target = wheel_build_targets[wheel]
663664
wheel_build_command.append(build_target)
665+
if args.use_new_wheel_build_rule and wheel == "jax":
666+
wheel_build_command.append(wheel_build_targets["jax_source_package"])
664667

665668
if not args.use_new_wheel_build_rule:
666669
wheel_build_command.append("--")

build_wheel.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,20 @@
4747
parser.add_argument(
4848
"--srcs", help="source files for the wheel", action="append"
4949
)
50+
parser.add_argument(
51+
"--build-wheel-only",
52+
default=False,
53+
help=(
54+
"Whether to build the wheel only. Optional."
55+
),
56+
)
57+
parser.add_argument(
58+
"--build-source-package-only",
59+
default=False,
60+
help=(
61+
"Whether to build the source package only. Optional."
62+
),
63+
)
5064
args = parser.parse_args()
5165

5266

@@ -94,7 +108,8 @@ def prepare_srcs(deps: list[str], srcs_dir: str) -> None:
94108
args.output_path,
95109
package_name="jax",
96110
git_hash=args.jaxlib_git_hash,
97-
build_wheel_only=False,
111+
build_wheel_only=args.build_wheel_only,
112+
build_source_package_only=args.build_source_package_only,
98113
)
99114
finally:
100115
if tmpdir:

jaxlib/jax.bzl

Lines changed: 69 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -362,7 +362,7 @@ def _get_full_wheel_name(
362362
free_threaded_suffix = "t" if py_freethreaded.lower() == "yes" else "",
363363
)
364364

365-
def _get_source_distribution_name(package_name, wheel_version):
365+
def _get_source_package_name(package_name, wheel_version):
366366
return "{package_name}-{wheel_version}.tar.gz".format(
367367
package_name = package_name,
368368
wheel_version = wheel_version,
@@ -394,37 +394,47 @@ def _jax_wheel_impl(ctx):
394394
no_abi = ctx.attr.no_abi
395395
platform_independent = ctx.attr.platform_independent
396396
build_wheel_only = ctx.attr.build_wheel_only
397+
build_source_package_only = ctx.attr.build_source_package_only
397398
editable = ctx.attr.editable
398399
platform_name = ctx.attr.platform_name
400+
401+
output_dir_path = ""
402+
outputs = []
399403
if editable:
400404
output_dir = ctx.actions.declare_directory(output_path + "/" + ctx.attr.wheel_name)
401-
wheel_dir = output_dir.path
405+
output_dir_path = output_dir.path
402406
outputs = [output_dir]
403407
args.add("--editable")
404408
else:
405-
wheel_name = _get_full_wheel_name(
406-
package_name = ctx.attr.wheel_name,
407-
no_abi = no_abi,
408-
platform_independent = platform_independent,
409-
platform_name = platform_name,
410-
cpu_name = cpu,
411-
wheel_version = full_wheel_version,
412-
py_freethreaded = py_freethreaded,
413-
)
414-
wheel_file = ctx.actions.declare_file(output_path +
415-
"/" + wheel_name)
416-
wheel_dir = wheel_file.path[:wheel_file.path.rfind("/")]
417-
outputs = [wheel_file]
418-
if not build_wheel_only:
419-
source_distribution_name = _get_source_distribution_name(
409+
if build_wheel_only:
410+
wheel_name = _get_full_wheel_name(
420411
package_name = ctx.attr.wheel_name,
412+
no_abi = no_abi,
413+
platform_independent = platform_independent,
414+
platform_name = platform_name,
415+
cpu_name = cpu,
421416
wheel_version = full_wheel_version,
417+
py_freethreaded = py_freethreaded,
422418
)
423-
source_distribution_file = ctx.actions.declare_file(output_path +
424-
"/" + source_distribution_name)
425-
outputs.append(source_distribution_file)
426-
427-
args.add("--output_path", wheel_dir) # required argument
419+
wheel_file = ctx.actions.declare_file(output_path +
420+
"/" + wheel_name)
421+
output_dir_path = wheel_file.path[:wheel_file.path.rfind("/")]
422+
outputs = [wheel_file]
423+
if ctx.attr.wheel_name == "jax":
424+
args.add("--build-wheel-only", "True")
425+
if build_source_package_only:
426+
source_package_name = _get_source_package_name(
427+
package_name = ctx.attr.wheel_name,
428+
wheel_version = full_wheel_version,
429+
)
430+
source_package_file = ctx.actions.declare_file(output_path +
431+
"/" + source_package_name)
432+
output_dir_path = source_package_file.path[:source_package_file.path.rfind("/")]
433+
outputs = [source_package_file]
434+
if ctx.attr.wheel_name == "jax":
435+
args.add("--build-source-package-only", "True")
436+
437+
args.add("--output_path", output_dir_path) # required argument
428438
if not platform_independent:
429439
args.add("--cpu", cpu)
430440
args.add("--jaxlib_git_hash", git_hash) # required argument
@@ -472,16 +482,17 @@ _jax_wheel = rule(
472482
"wheel_name": attr.string(mandatory = True),
473483
"no_abi": attr.bool(default = False),
474484
"platform_independent": attr.bool(default = False),
475-
"build_wheel_only": attr.bool(default = True),
485+
"build_wheel_only": attr.bool(mandatory = True, default = True),
486+
"build_source_package_only": attr.bool(mandatory = True, default = False),
476487
"editable": attr.bool(default = False),
477-
"cpu": attr.string(mandatory = True),
478-
"platform_name": attr.string(mandatory = True),
488+
"cpu": attr.string(),
489+
"platform_name": attr.string(),
479490
"git_hash": attr.label(default = Label("//jaxlib/tools:jaxlib_git_hash")),
480491
"source_files": attr.label_list(allow_files = True),
481492
"output_path": attr.label(default = Label("//jaxlib/tools:output_path")),
482493
"enable_cuda": attr.bool(default = False),
483494
# A cuda/rocm version is required for gpu wheels; for cpu wheels, it can be an empty string.
484-
"platform_version": attr.string(mandatory = True, default = ""),
495+
"platform_version": attr.string(),
485496
"skip_gpu_kernels": attr.bool(default = False),
486497
"enable_rocm": attr.bool(default = False),
487498
"include_cuda_libs": attr.label(default = Label("@local_config_cuda//cuda:include_cuda_libs")),
@@ -498,7 +509,6 @@ def jax_wheel(
498509
wheel_name,
499510
no_abi = False,
500511
platform_independent = False,
501-
build_wheel_only = True,
502512
editable = False,
503513
enable_cuda = False,
504514
enable_rocm = False,
@@ -509,11 +519,10 @@ def jax_wheel(
509519
Common artifact attributes are grouped within a single macro.
510520
511521
Args:
512-
name: the name of the wheel
522+
name: the target name
513523
wheel_binary: the binary to use to build the wheel
514524
wheel_name: the name of the wheel
515525
no_abi: whether to build a wheel without ABI
516-
build_wheel_only: whether to build a wheel without source distribution
517526
editable: whether to build an editable wheel
518527
platform_independent: whether to build a wheel without platform tag
519528
enable_cuda: whether to build a cuda wheel
@@ -522,15 +531,16 @@ def jax_wheel(
522531
source_files: the source files to include in the wheel
523532
524533
Returns:
525-
A directory containing the wheel
534+
A wheel file or a wheel directory.
526535
"""
527536
_jax_wheel(
528537
name = name,
529538
wheel_binary = wheel_binary,
530539
wheel_name = wheel_name,
531540
no_abi = no_abi,
532541
platform_independent = platform_independent,
533-
build_wheel_only = build_wheel_only,
542+
build_wheel_only = True,
543+
build_source_package_only = False,
534544
editable = editable,
535545
enable_cuda = enable_cuda,
536546
enable_rocm = enable_rocm,
@@ -554,6 +564,34 @@ def jax_wheel(
554564
source_files = source_files,
555565
)
556566

567+
def jax_source_package(
568+
name,
569+
source_package_binary,
570+
source_package_name,
571+
source_files = []):
572+
"""Create jax source package.
573+
574+
Common artifact attributes are grouped within a single macro.
575+
576+
Args:
577+
name: the target name
578+
source_package_binary: the binary to use to build the package
579+
source_package_name: the name of the source package
580+
source_files: the source files to include in the package
581+
582+
Returns:
583+
A jax source package file.
584+
"""
585+
_jax_wheel(
586+
name = name,
587+
wheel_binary = source_package_binary,
588+
wheel_name = source_package_name,
589+
build_source_package_only = True,
590+
build_wheel_only = False,
591+
platform_independent = True,
592+
source_files = source_files,
593+
)
594+
557595
jax_test_file_visibility = []
558596

559597
jax_export_file_visibility = []

jaxlib/tools/build_utils.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ def build_wheel(
6565
package_name: str,
6666
git_hash: str = "",
6767
build_wheel_only: bool = True,
68+
build_source_package_only: bool = False,
6869
) -> None:
6970
"""Builds a wheel in `output_path` using the source tree in `sources_path`."""
7071
env = dict(os.environ)
@@ -78,7 +79,8 @@ def build_wheel(
7879
env["USERPROFILE"] = env.get("SYSTEMDRIVE", "C:")
7980
subprocess.run(
8081
[sys.executable, "-m", "build", "-n"]
81-
+ (["-w"] if build_wheel_only else []),
82+
+ (["-w"] if build_wheel_only else [])
83+
+ (["-s"] if build_source_package_only else []),
8284
check=True,
8385
cwd=sources_path,
8486
env=env,
@@ -97,10 +99,10 @@ def build_wheel(
9799
sys.stderr.write(" bazel run //build:requirements.update" +
98100
f" --repo_env=HERMETIC_PYTHON_VERSION={py_version}\n\n")
99101
shutil.copy(wheel, output_path)
100-
if not build_wheel_only:
102+
if build_source_package_only:
101103
for dist in glob.glob(os.path.join(sources_path, "dist", "*.tar.gz")):
102104
output_file = os.path.join(output_path, os.path.basename(dist))
103-
sys.stderr.write(f"Output source distribution: {output_file}\n\n")
105+
sys.stderr.write(f"Output source package: {output_file}\n\n")
104106
shutil.copy(dist, output_path)
105107

106108

0 commit comments

Comments
 (0)