Skip to content

Commit eb912ad

Browse files
Create jax wheel build target.
This change introduces a uniform way of building the artifacts and controlling the filename version suffixes (see the changes for `jaxlib`, `jax-cuda-plugin` and `jax-cuda-pjrt` in jax-ml#25126) Previously `jax` wheel was built via `python3 -m build` command. The resulting wheel contained the python packages files in `jax` folder (e.g. the files in the subdirs that have `__init__.py` file). You can still build the `jax` wheel with `python3 -m build` command. Bazel `jax` wheel target: `//:jax_wheel` Environment variables combinations for creating wheels with different versions: * self-built wheel (default build rule behavior): `--repo_env=ML_WHEEL_TYPE=snapshot` * release: `--repo_env=ML_WHEEL_TYPE=release` * release candidate: `--repo_env=ML_WHEEL_TYPE=release --repo_env=ML_WHEEL_VERSION_SUFFIX=-rc1` * nightly build: `--repo_env=ML_WHEEL_TYPE=custom --repo_env=ML_WHEEL_BUILD_DATE=<YYYYmmdd> --repo_env=ML_WHEEL_GIT_HASH=$(git rev-parse HEAD)` PiperOrigin-RevId: 730916743
1 parent 7a162f2 commit eb912ad

File tree

18 files changed

+453
-20
lines changed

18 files changed

+453
-20
lines changed

BUILD

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
# Copyright 2025 The JAX Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
load("@tsl//third_party/py:python_wheel.bzl", "collect_data_files", "transitive_py_deps")
16+
load(
17+
"//jaxlib:jax.bzl",
18+
"jax_wheel",
19+
)
20+
21+
collect_data_files(
22+
name = "transitive_py_data",
23+
deps = ["//jax"],
24+
)
25+
26+
transitive_py_deps(
27+
name = "transitive_py_deps",
28+
deps = [
29+
"//jax",
30+
"//jax:compilation_cache",
31+
"//jax:experimental",
32+
"//jax:experimental_colocated_python",
33+
"//jax:experimental_sparse",
34+
"//jax:internal_export_back_compat_test_util",
35+
"//jax:internal_test_harnesses",
36+
"//jax:internal_test_util",
37+
"//jax:lax_reference",
38+
"//jax:pallas_experimental_gpu_ops",
39+
"//jax:pallas_gpu_ops",
40+
"//jax:pallas_mosaic_gpu",
41+
"//jax:pallas_tpu_ops",
42+
"//jax:pallas_triton",
43+
"//jax:source_mapper",
44+
"//jax:sparse_test_util",
45+
"//jax:test_util",
46+
"//jax/_src/lib",
47+
"//jax/_src/pallas/mosaic_gpu",
48+
"//jax/experimental/array_serialization:serialization",
49+
"//jax/experimental/jax2tf",
50+
"//jax/extend",
51+
"//jax/extend:ifrt_programs",
52+
"//jax/extend/mlir",
53+
"//jax/extend/mlir/dialects",
54+
"//jax/tools:colab_tpu",
55+
"//jax/tools:jax_to_ir",
56+
"//jax/tools:pgo_nsys_converter",
57+
],
58+
)
59+
60+
py_binary(
61+
name = "build_wheel",
62+
srcs = ["build_wheel.py"],
63+
deps = [
64+
"//jaxlib/tools:build_utils",
65+
"@pypi_build//:pkg",
66+
"@pypi_setuptools//:pkg",
67+
"@pypi_wheel//:pkg",
68+
],
69+
)
70+
71+
jax_wheel(
72+
name = "jax_wheel",
73+
platform_independent = True,
74+
source_files = [
75+
":transitive_py_data",
76+
":transitive_py_deps",
77+
"//jax:py.typed",
78+
"AUTHORS",
79+
"LICENSE",
80+
"README.md",
81+
"pyproject.toml",
82+
"setup.py",
83+
],
84+
wheel_binary = ":build_wheel",
85+
wheel_name = "jax",
86+
)

build/requirements_lock_3_10.txt

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -566,6 +566,29 @@ sortedcontainers==2.4.0 \
566566
--hash=sha256:25caa5a06cc30b6b83d11423433f65d1f9d76c4c6a0c90e3379eaa43b9bfdb88 \
567567
--hash=sha256:a163dcaede0f1c021485e957a39245190e74249897e2ae4b2aa38595db237ee0
568568
# via hypothesis
569+
tensorstore==0.1.72 \
570+
--hash=sha256:08c5318535aac5e20e247c6e9b43f5887b2293f548de7279650bc73804ccf3ed \
571+
--hash=sha256:0cd951e593a17babbbde1410cfadb4a04e1cddfa5ace0de5ccb41029223f96b9 \
572+
--hash=sha256:170172b698fefb4b5507c6cb339ca0b75d56d12ba6a43d9569c61800c1eeb121 \
573+
--hash=sha256:2fdfa0118be0721c110bcbe7e464758f78d3e14ee8c30a911eb8f4465e6c2e81 \
574+
--hash=sha256:4a6825cdb6751663ca0bd9abd528ea354ad2199f549bf1f36feac79a6c06efe2 \
575+
--hash=sha256:599cc7b26b0c96373e89ff5bcf9b76e832802169229680bef985b10011f9bae7 \
576+
--hash=sha256:5d410c879dc4b34036ec38e20ff05c7e3b0ad5d1eb595412b27a9dbb5e435035 \
577+
--hash=sha256:5ed6fe937b0433b573c3d6805d0759d33ccc24aa2aba720e4b8ba689c2f9775f \
578+
--hash=sha256:66c0658689243af0825fff222fb56fdf05a8553bcb3b471dbf18830161302986 \
579+
--hash=sha256:721d599db0113d75ab6ba1365989bbaf2ab752d7a6268f975c8bfd3a8eb6084b \
580+
--hash=sha256:763d7f6898711783f199c8226a9c0b259546f5c6d9b4dc0ad3c9e39627060022 \
581+
--hash=sha256:7c9413f8318a4fa259ec5325f569c0759bccee936df44bd2f7bb35c8afdcdfc8 \
582+
--hash=sha256:9113d3fcf78c1366688aa90ee7efdc86b57962ea72276944cc57e916a6180749 \
583+
--hash=sha256:92fac5e2cbc90e5ca8fc72c5bf112816d981e266a3cf9fb1681ba8b3f59537ef \
584+
--hash=sha256:9c3a36f681ffcc104ba931d471447e8901e64e8cc6913b61792870ff59529961 \
585+
--hash=sha256:a41b4fe0603943d23472619a8ada70b8d2c9458747fad88b0ce7b29f1ccf4e74 \
586+
--hash=sha256:a7e7b02da26ca5c95b3c613efd0fe10c082dfa4dc3e9818fefc69e30fe70ea1e \
587+
--hash=sha256:b71134b85f540e17a1ae65da1fb906781b7470ef0ed71d98d29459325897f574 \
588+
--hash=sha256:c0f722218f494b1631dbec451b9863f579054e27da2f39aab418db4493694abe \
589+
--hash=sha256:d5dced3f367308e9fa8e7b72e9e57a4c491fa47c066e035ac33421e2b2408e3f \
590+
--hash=sha256:ed916b9aeca242a3f367679f65ba376149251ebb28b873becd76c73b688399b6
591+
# via -r build/test-requirements.txt
569592
tomli==2.0.1 \
570593
--hash=sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc \
571594
--hash=sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f

build/requirements_lock_3_11.txt

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -561,6 +561,29 @@ sortedcontainers==2.4.0 \
561561
--hash=sha256:25caa5a06cc30b6b83d11423433f65d1f9d76c4c6a0c90e3379eaa43b9bfdb88 \
562562
--hash=sha256:a163dcaede0f1c021485e957a39245190e74249897e2ae4b2aa38595db237ee0
563563
# via hypothesis
564+
tensorstore==0.1.72 \
565+
--hash=sha256:08c5318535aac5e20e247c6e9b43f5887b2293f548de7279650bc73804ccf3ed \
566+
--hash=sha256:0cd951e593a17babbbde1410cfadb4a04e1cddfa5ace0de5ccb41029223f96b9 \
567+
--hash=sha256:170172b698fefb4b5507c6cb339ca0b75d56d12ba6a43d9569c61800c1eeb121 \
568+
--hash=sha256:2fdfa0118be0721c110bcbe7e464758f78d3e14ee8c30a911eb8f4465e6c2e81 \
569+
--hash=sha256:4a6825cdb6751663ca0bd9abd528ea354ad2199f549bf1f36feac79a6c06efe2 \
570+
--hash=sha256:599cc7b26b0c96373e89ff5bcf9b76e832802169229680bef985b10011f9bae7 \
571+
--hash=sha256:5d410c879dc4b34036ec38e20ff05c7e3b0ad5d1eb595412b27a9dbb5e435035 \
572+
--hash=sha256:5ed6fe937b0433b573c3d6805d0759d33ccc24aa2aba720e4b8ba689c2f9775f \
573+
--hash=sha256:66c0658689243af0825fff222fb56fdf05a8553bcb3b471dbf18830161302986 \
574+
--hash=sha256:721d599db0113d75ab6ba1365989bbaf2ab752d7a6268f975c8bfd3a8eb6084b \
575+
--hash=sha256:763d7f6898711783f199c8226a9c0b259546f5c6d9b4dc0ad3c9e39627060022 \
576+
--hash=sha256:7c9413f8318a4fa259ec5325f569c0759bccee936df44bd2f7bb35c8afdcdfc8 \
577+
--hash=sha256:9113d3fcf78c1366688aa90ee7efdc86b57962ea72276944cc57e916a6180749 \
578+
--hash=sha256:92fac5e2cbc90e5ca8fc72c5bf112816d981e266a3cf9fb1681ba8b3f59537ef \
579+
--hash=sha256:9c3a36f681ffcc104ba931d471447e8901e64e8cc6913b61792870ff59529961 \
580+
--hash=sha256:a41b4fe0603943d23472619a8ada70b8d2c9458747fad88b0ce7b29f1ccf4e74 \
581+
--hash=sha256:a7e7b02da26ca5c95b3c613efd0fe10c082dfa4dc3e9818fefc69e30fe70ea1e \
582+
--hash=sha256:b71134b85f540e17a1ae65da1fb906781b7470ef0ed71d98d29459325897f574 \
583+
--hash=sha256:c0f722218f494b1631dbec451b9863f579054e27da2f39aab418db4493694abe \
584+
--hash=sha256:d5dced3f367308e9fa8e7b72e9e57a4c491fa47c066e035ac33421e2b2408e3f \
585+
--hash=sha256:ed916b9aeca242a3f367679f65ba376149251ebb28b873becd76c73b688399b6
586+
# via -r build/test-requirements.txt
564587
typing-extensions==4.12.0rc1 \
565588
--hash=sha256:be199d06d8f09ca2c9425e3aa04a9afba33e892fe079dea959e72df7f8442343 \
566589
--hash=sha256:f933a7b288a919ca97adbff656e52ff81f7ff25d98a2aabb9355ca4090f772fe

build/requirements_lock_3_12.txt

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -561,6 +561,29 @@ sortedcontainers==2.4.0 \
561561
--hash=sha256:25caa5a06cc30b6b83d11423433f65d1f9d76c4c6a0c90e3379eaa43b9bfdb88 \
562562
--hash=sha256:a163dcaede0f1c021485e957a39245190e74249897e2ae4b2aa38595db237ee0
563563
# via hypothesis
564+
tensorstore==0.1.72 \
565+
--hash=sha256:08c5318535aac5e20e247c6e9b43f5887b2293f548de7279650bc73804ccf3ed \
566+
--hash=sha256:0cd951e593a17babbbde1410cfadb4a04e1cddfa5ace0de5ccb41029223f96b9 \
567+
--hash=sha256:170172b698fefb4b5507c6cb339ca0b75d56d12ba6a43d9569c61800c1eeb121 \
568+
--hash=sha256:2fdfa0118be0721c110bcbe7e464758f78d3e14ee8c30a911eb8f4465e6c2e81 \
569+
--hash=sha256:4a6825cdb6751663ca0bd9abd528ea354ad2199f549bf1f36feac79a6c06efe2 \
570+
--hash=sha256:599cc7b26b0c96373e89ff5bcf9b76e832802169229680bef985b10011f9bae7 \
571+
--hash=sha256:5d410c879dc4b34036ec38e20ff05c7e3b0ad5d1eb595412b27a9dbb5e435035 \
572+
--hash=sha256:5ed6fe937b0433b573c3d6805d0759d33ccc24aa2aba720e4b8ba689c2f9775f \
573+
--hash=sha256:66c0658689243af0825fff222fb56fdf05a8553bcb3b471dbf18830161302986 \
574+
--hash=sha256:721d599db0113d75ab6ba1365989bbaf2ab752d7a6268f975c8bfd3a8eb6084b \
575+
--hash=sha256:763d7f6898711783f199c8226a9c0b259546f5c6d9b4dc0ad3c9e39627060022 \
576+
--hash=sha256:7c9413f8318a4fa259ec5325f569c0759bccee936df44bd2f7bb35c8afdcdfc8 \
577+
--hash=sha256:9113d3fcf78c1366688aa90ee7efdc86b57962ea72276944cc57e916a6180749 \
578+
--hash=sha256:92fac5e2cbc90e5ca8fc72c5bf112816d981e266a3cf9fb1681ba8b3f59537ef \
579+
--hash=sha256:9c3a36f681ffcc104ba931d471447e8901e64e8cc6913b61792870ff59529961 \
580+
--hash=sha256:a41b4fe0603943d23472619a8ada70b8d2c9458747fad88b0ce7b29f1ccf4e74 \
581+
--hash=sha256:a7e7b02da26ca5c95b3c613efd0fe10c082dfa4dc3e9818fefc69e30fe70ea1e \
582+
--hash=sha256:b71134b85f540e17a1ae65da1fb906781b7470ef0ed71d98d29459325897f574 \
583+
--hash=sha256:c0f722218f494b1631dbec451b9863f579054e27da2f39aab418db4493694abe \
584+
--hash=sha256:d5dced3f367308e9fa8e7b72e9e57a4c491fa47c066e035ac33421e2b2408e3f \
585+
--hash=sha256:ed916b9aeca242a3f367679f65ba376149251ebb28b873becd76c73b688399b6
586+
# via -r build/test-requirements.txt
564587
typing-extensions==4.12.0rc1 \
565588
--hash=sha256:be199d06d8f09ca2c9425e3aa04a9afba33e892fe079dea959e72df7f8442343 \
566589
--hash=sha256:f933a7b288a919ca97adbff656e52ff81f7ff25d98a2aabb9355ca4090f772fe

build/requirements_lock_3_13.txt

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -634,6 +634,29 @@ sortedcontainers==2.4.0 \
634634
--hash=sha256:25caa5a06cc30b6b83d11423433f65d1f9d76c4c6a0c90e3379eaa43b9bfdb88 \
635635
--hash=sha256:a163dcaede0f1c021485e957a39245190e74249897e2ae4b2aa38595db237ee0
636636
# via hypothesis
637+
tensorstore==0.1.72 \
638+
--hash=sha256:08c5318535aac5e20e247c6e9b43f5887b2293f548de7279650bc73804ccf3ed \
639+
--hash=sha256:0cd951e593a17babbbde1410cfadb4a04e1cddfa5ace0de5ccb41029223f96b9 \
640+
--hash=sha256:170172b698fefb4b5507c6cb339ca0b75d56d12ba6a43d9569c61800c1eeb121 \
641+
--hash=sha256:2fdfa0118be0721c110bcbe7e464758f78d3e14ee8c30a911eb8f4465e6c2e81 \
642+
--hash=sha256:4a6825cdb6751663ca0bd9abd528ea354ad2199f549bf1f36feac79a6c06efe2 \
643+
--hash=sha256:599cc7b26b0c96373e89ff5bcf9b76e832802169229680bef985b10011f9bae7 \
644+
--hash=sha256:5d410c879dc4b34036ec38e20ff05c7e3b0ad5d1eb595412b27a9dbb5e435035 \
645+
--hash=sha256:5ed6fe937b0433b573c3d6805d0759d33ccc24aa2aba720e4b8ba689c2f9775f \
646+
--hash=sha256:66c0658689243af0825fff222fb56fdf05a8553bcb3b471dbf18830161302986 \
647+
--hash=sha256:721d599db0113d75ab6ba1365989bbaf2ab752d7a6268f975c8bfd3a8eb6084b \
648+
--hash=sha256:763d7f6898711783f199c8226a9c0b259546f5c6d9b4dc0ad3c9e39627060022 \
649+
--hash=sha256:7c9413f8318a4fa259ec5325f569c0759bccee936df44bd2f7bb35c8afdcdfc8 \
650+
--hash=sha256:9113d3fcf78c1366688aa90ee7efdc86b57962ea72276944cc57e916a6180749 \
651+
--hash=sha256:92fac5e2cbc90e5ca8fc72c5bf112816d981e266a3cf9fb1681ba8b3f59537ef \
652+
--hash=sha256:9c3a36f681ffcc104ba931d471447e8901e64e8cc6913b61792870ff59529961 \
653+
--hash=sha256:a41b4fe0603943d23472619a8ada70b8d2c9458747fad88b0ce7b29f1ccf4e74 \
654+
--hash=sha256:a7e7b02da26ca5c95b3c613efd0fe10c082dfa4dc3e9818fefc69e30fe70ea1e \
655+
--hash=sha256:b71134b85f540e17a1ae65da1fb906781b7470ef0ed71d98d29459325897f574 \
656+
--hash=sha256:c0f722218f494b1631dbec451b9863f579054e27da2f39aab418db4493694abe \
657+
--hash=sha256:d5dced3f367308e9fa8e7b72e9e57a4c491fa47c066e035ac33421e2b2408e3f \
658+
--hash=sha256:ed916b9aeca242a3f367679f65ba376149251ebb28b873becd76c73b688399b6
659+
# via -r build/test-requirements.txt
637660
typing-extensions==4.12.2 \
638661
--hash=sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d \
639662
--hash=sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8

build/requirements_lock_3_13_ft.txt

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -588,6 +588,29 @@ sortedcontainers==2.4.0 \
588588
--hash=sha256:25caa5a06cc30b6b83d11423433f65d1f9d76c4c6a0c90e3379eaa43b9bfdb88 \
589589
--hash=sha256:a163dcaede0f1c021485e957a39245190e74249897e2ae4b2aa38595db237ee0
590590
# via hypothesis
591+
tensorstore==0.1.72 \
592+
--hash=sha256:08c5318535aac5e20e247c6e9b43f5887b2293f548de7279650bc73804ccf3ed \
593+
--hash=sha256:0cd951e593a17babbbde1410cfadb4a04e1cddfa5ace0de5ccb41029223f96b9 \
594+
--hash=sha256:170172b698fefb4b5507c6cb339ca0b75d56d12ba6a43d9569c61800c1eeb121 \
595+
--hash=sha256:2fdfa0118be0721c110bcbe7e464758f78d3e14ee8c30a911eb8f4465e6c2e81 \
596+
--hash=sha256:4a6825cdb6751663ca0bd9abd528ea354ad2199f549bf1f36feac79a6c06efe2 \
597+
--hash=sha256:599cc7b26b0c96373e89ff5bcf9b76e832802169229680bef985b10011f9bae7 \
598+
--hash=sha256:5d410c879dc4b34036ec38e20ff05c7e3b0ad5d1eb595412b27a9dbb5e435035 \
599+
--hash=sha256:5ed6fe937b0433b573c3d6805d0759d33ccc24aa2aba720e4b8ba689c2f9775f \
600+
--hash=sha256:66c0658689243af0825fff222fb56fdf05a8553bcb3b471dbf18830161302986 \
601+
--hash=sha256:721d599db0113d75ab6ba1365989bbaf2ab752d7a6268f975c8bfd3a8eb6084b \
602+
--hash=sha256:763d7f6898711783f199c8226a9c0b259546f5c6d9b4dc0ad3c9e39627060022 \
603+
--hash=sha256:7c9413f8318a4fa259ec5325f569c0759bccee936df44bd2f7bb35c8afdcdfc8 \
604+
--hash=sha256:9113d3fcf78c1366688aa90ee7efdc86b57962ea72276944cc57e916a6180749 \
605+
--hash=sha256:92fac5e2cbc90e5ca8fc72c5bf112816d981e266a3cf9fb1681ba8b3f59537ef \
606+
--hash=sha256:9c3a36f681ffcc104ba931d471447e8901e64e8cc6913b61792870ff59529961 \
607+
--hash=sha256:a41b4fe0603943d23472619a8ada70b8d2c9458747fad88b0ce7b29f1ccf4e74 \
608+
--hash=sha256:a7e7b02da26ca5c95b3c613efd0fe10c082dfa4dc3e9818fefc69e30fe70ea1e \
609+
--hash=sha256:b71134b85f540e17a1ae65da1fb906781b7470ef0ed71d98d29459325897f574 \
610+
--hash=sha256:c0f722218f494b1631dbec451b9863f579054e27da2f39aab418db4493694abe \
611+
--hash=sha256:d5dced3f367308e9fa8e7b72e9e57a4c491fa47c066e035ac33421e2b2408e3f \
612+
--hash=sha256:ed916b9aeca242a3f367679f65ba376149251ebb28b873becd76c73b688399b6
613+
# via -r build/test-requirements.txt
591614
typing-extensions==4.12.2 \
592615
--hash=sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d \
593616
--hash=sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8

build/test-requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,4 @@ matplotlib~=3.8.4; python_version=="3.10"
2020
matplotlib; python_version>="3.11"
2121
opt-einsum
2222
auditwheel
23+
tensorstore==0.1.72

build_wheel.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
# Copyright 2025 The JAX Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# Script that builds a JAX wheel, intended to be run via bazel run as part
16+
# of the JAX build process.
17+
18+
import argparse
19+
import os
20+
import pathlib
21+
import shutil
22+
import tempfile
23+
24+
from jaxlib.tools import build_utils
25+
26+
parser = argparse.ArgumentParser(fromfile_prefix_chars="@")
27+
parser.add_argument(
28+
"--sources_path",
29+
default=None,
30+
help=(
31+
"Path in which the wheel's sources should be prepared. Optional. If "
32+
"omitted, a temporary directory will be used."
33+
),
34+
)
35+
parser.add_argument(
36+
"--output_path",
37+
default=None,
38+
required=True,
39+
help="Path to which the output wheel should be written. Required.",
40+
)
41+
parser.add_argument(
42+
"--jaxlib_git_hash",
43+
default="",
44+
required=True,
45+
help="Git hash. Empty if unknown. Optional.",
46+
)
47+
parser.add_argument(
48+
"--srcs", help="source files for the wheel", action="append"
49+
)
50+
args = parser.parse_args()
51+
52+
53+
def copy_file(
54+
src_file: str,
55+
dst_dir: str,
56+
) -> None:
57+
"""Copy a file to the destination directory.
58+
59+
Args:
60+
src_file: file to be copied
61+
dst_dir: destination directory
62+
"""
63+
64+
dest_dir_path = os.path.join(dst_dir, os.path.dirname(src_file))
65+
os.makedirs(dest_dir_path, exist_ok=True)
66+
shutil.copy(src_file, dest_dir_path)
67+
os.chmod(os.path.join(dst_dir, src_file), 0o644)
68+
69+
70+
def prepare_srcs(deps: list[str], srcs_dir: str) -> None:
71+
"""Filter the sources and copy them to the destination directory.
72+
73+
Args:
74+
deps: a list of paths to files.
75+
srcs_dir: target directory where files are copied to.
76+
"""
77+
78+
for file in deps:
79+
if not (file.startswith("bazel-out") or file.startswith("external")):
80+
copy_file(file, srcs_dir)
81+
82+
83+
tmpdir = None
84+
sources_path = args.sources_path
85+
if sources_path is None:
86+
tmpdir = tempfile.TemporaryDirectory(prefix="jax")
87+
sources_path = tmpdir.name
88+
89+
try:
90+
os.makedirs(args.output_path, exist_ok=True)
91+
prepare_srcs(args.srcs, pathlib.Path(sources_path))
92+
build_utils.build_wheel(
93+
sources_path,
94+
args.output_path,
95+
package_name="jax",
96+
git_hash=args.jaxlib_git_hash,
97+
)
98+
finally:
99+
if tmpdir:
100+
tmpdir.cleanup()

0 commit comments

Comments
 (0)