Skip to content

Commit 5e7ad60

Browse files
superbobryjax authors
authored andcommitted
Removed the double re-exporting of Pallas GPU/TPU APIs
jax.experimental.pallas.{gpu,tpu} now import directly from the relevant jax._src.pallas.{triton,mosaic} submodules. PiperOrigin-RevId: 641875127
1 parent 3b4039c commit 5e7ad60

File tree

8 files changed

+58
-141
lines changed

8 files changed

+58
-141
lines changed

jax/BUILD

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -623,9 +623,14 @@ pytype_strict_library(
623623
":pallas_tpu_users",
624624
],
625625
deps = [
626-
":pallas", # buildcleaner: keep
626+
":pallas", # build_cleaner: keep
627627
":tpu_custom_call",
628-
"//jax/_src/pallas/mosaic",
628+
"//jax/_src/pallas/mosaic:core",
629+
"//jax/_src/pallas/mosaic:kernel_regeneration_util",
630+
"//jax/_src/pallas/mosaic:lowering",
631+
"//jax/_src/pallas/mosaic:pallas_call_registration", # build_cleaner: keep
632+
"//jax/_src/pallas/mosaic:pipeline",
633+
"//jax/_src/pallas/mosaic:primitives",
629634
],
630635
)
631636

@@ -663,8 +668,9 @@ pytype_strict_library(
663668
],
664669
deps = [
665670
":pallas",
666-
"//jax/_src/pallas/mosaic_gpu",
667-
"//jax/_src/pallas/triton",
671+
"//jax/_src/pallas/mosaic_gpu:pallas_call_registration", # build_cleaner: keep
672+
"//jax/_src/pallas/triton:pallas_call_registration", # build_cleaner: keep
673+
"//jax/_src/pallas/triton:primitives",
668674
],
669675
)
670676

jax/_src/pallas/BUILD

Lines changed: 7 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,13 @@ package(
2727

2828
py_library(
2929
name = "pallas",
30-
srcs = glob(
31-
include = ["**/*.py"],
32-
exclude = [
33-
"triton/*.py",
34-
"mosaic/*.py",
35-
],
36-
),
30+
srcs = [
31+
"__init__.py",
32+
"core.py",
33+
"pallas_call.py",
34+
"primitives.py",
35+
"utils.py",
36+
],
3737
deps = [
3838
"//jax",
3939
"//jax:ad_util",
@@ -46,21 +46,3 @@ py_library(
4646
"//jax/_src/lib",
4747
] + py_deps("numpy"),
4848
)
49-
50-
py_library(
51-
name = "gpu",
52-
visibility = [],
53-
deps = [
54-
":pallas",
55-
"//jax/_src/pallas/triton",
56-
],
57-
)
58-
59-
py_library(
60-
name = "tpu",
61-
visibility = [],
62-
deps = [
63-
":pallas",
64-
"//jax/_src/pallas/mosaic",
65-
],
66-
)

jax/_src/pallas/mosaic/BUILD

Lines changed: 1 addition & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,7 @@
1515
# Package for Mosaic-specific Pallas extensions
1616

1717
load("@rules_python//python:defs.bzl", "py_library")
18-
load(
19-
"//jaxlib:jax.bzl",
20-
"py_deps",
21-
"py_library_providing_imports_info",
22-
)
18+
load("//jaxlib:jax.bzl", "py_deps")
2319

2420
package(
2521
default_applicable_licenses = [],
@@ -28,20 +24,6 @@ package(
2824
],
2925
)
3026

31-
py_library_providing_imports_info(
32-
name = "mosaic",
33-
srcs = ["__init__.py"],
34-
lib_rule = py_library,
35-
deps = [
36-
":core",
37-
":kernel_regeneration_util",
38-
":lowering",
39-
":pallas_call_registration",
40-
":pipeline",
41-
":primitives",
42-
],
43-
)
44-
4527
py_library(
4628
name = "core",
4729
srcs = ["core.py"],

jax/_src/pallas/mosaic/__init__.py

Lines changed: 0 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -11,42 +11,3 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
15-
"""Module for Mosaic lowering of Pallas call."""
16-
17-
from jax._src.pallas.mosaic import core
18-
from jax._src.pallas.mosaic.core import dma_semaphore
19-
from jax._src.pallas.mosaic.core import PrefetchScalarGridSpec
20-
from jax._src.pallas.mosaic.core import semaphore
21-
from jax._src.pallas.mosaic.core import SemaphoreType
22-
from jax._src.pallas.mosaic.core import TPUMemorySpace
23-
from jax._src.pallas.mosaic.kernel_regeneration_util import encode_kernel_regeneration_metadata
24-
from jax._src.pallas.mosaic.kernel_regeneration_util import extract_kernel_regeneration_metadata
25-
from jax._src.pallas.mosaic.lowering import LoweringException
26-
from jax._src.pallas.mosaic.pipeline import BufferedRef
27-
from jax._src.pallas.mosaic.pipeline import emit_pipeline
28-
from jax._src.pallas.mosaic.pipeline import emit_pipeline_with_allocations
29-
from jax._src.pallas.mosaic.pipeline import get_pipeline_schedule
30-
from jax._src.pallas.mosaic.pipeline import make_pipeline_allocations
31-
from jax._src.pallas.mosaic.primitives import async_copy
32-
from jax._src.pallas.mosaic.primitives import async_remote_copy
33-
from jax._src.pallas.mosaic.primitives import bitcast
34-
from jax._src.pallas.mosaic.primitives import delay
35-
from jax._src.pallas.mosaic.primitives import device_id
36-
from jax._src.pallas.mosaic.primitives import DeviceIdType
37-
from jax._src.pallas.mosaic.primitives import get_barrier_semaphore
38-
from jax._src.pallas.mosaic.primitives import make_async_copy
39-
from jax._src.pallas.mosaic.primitives import make_async_remote_copy
40-
from jax._src.pallas.mosaic.primitives import repeat
41-
from jax._src.pallas.mosaic.primitives import roll
42-
from jax._src.pallas.mosaic.primitives import run_scoped
43-
from jax._src.pallas.mosaic.primitives import semaphore_read
44-
from jax._src.pallas.mosaic.primitives import semaphore_signal
45-
from jax._src.pallas.mosaic.primitives import semaphore_wait
46-
from jax._src.pallas.mosaic.primitives import prng_seed
47-
from jax._src.pallas.mosaic.primitives import prng_random_bits
48-
49-
ANY = TPUMemorySpace.ANY
50-
CMEM = TPUMemorySpace.CMEM
51-
SMEM = TPUMemorySpace.SMEM
52-
VMEM = TPUMemorySpace.VMEM

jax/_src/pallas/triton/BUILD

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
load(
1818
"//jaxlib:jax.bzl",
1919
"py_deps",
20-
"py_library_providing_imports_info",
2120
"pytype_strict_library",
2221
)
2322

@@ -28,18 +27,6 @@ package(
2827
],
2928
)
3029

31-
py_library_providing_imports_info(
32-
name = "triton",
33-
srcs = ["__init__.py"],
34-
lib_rule = pytype_strict_library,
35-
deps = [
36-
":lowering",
37-
":pallas_call_registration",
38-
":primitives",
39-
"//jax/_src/lib",
40-
],
41-
)
42-
4330
pytype_strict_library(
4431
name = "primitives",
4532
srcs = ["primitives.py"],

jax/_src/pallas/triton/__init__.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,3 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
15-
"""Triton-specific Pallas APIs."""
16-
17-
from jax._src.pallas.triton.primitives import approx_tanh
18-
from jax._src.pallas.triton.primitives import elementwise_inline_asm

jax/experimental/pallas/gpu.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,5 +14,5 @@
1414

1515
"""Triton-specific Pallas APIs."""
1616

17-
from jax._src.pallas.triton import approx_tanh
18-
from jax._src.pallas.triton import elementwise_inline_asm
17+
from jax._src.pallas.triton.primitives import approx_tanh
18+
from jax._src.pallas.triton.primitives import elementwise_inline_asm

jax/experimental/pallas/tpu.py

Lines changed: 38 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -12,38 +12,42 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
"""Contains Mosaic specific Pallas functions."""
16-
from jax._src.pallas.mosaic import ANY
17-
from jax._src.pallas.mosaic import CMEM
18-
from jax._src.pallas.mosaic import PrefetchScalarGridSpec
19-
from jax._src.pallas.mosaic import SMEM
20-
from jax._src.pallas.mosaic import SemaphoreType
21-
from jax._src.pallas.mosaic import TPUMemorySpace
22-
from jax._src.pallas.mosaic import VMEM
23-
from jax._src.pallas.mosaic import DeviceIdType
24-
from jax._src.pallas.mosaic import async_copy
25-
from jax._src.pallas.mosaic import async_remote_copy
26-
from jax._src.pallas.mosaic import bitcast
27-
from jax._src.pallas.mosaic import dma_semaphore
28-
from jax._src.pallas.mosaic import delay
29-
from jax._src.pallas.mosaic import device_id
30-
from jax._src.pallas.mosaic import emit_pipeline_with_allocations
31-
from jax._src.pallas.mosaic import emit_pipeline
32-
from jax._src.pallas.mosaic import get_pipeline_schedule
33-
from jax._src.pallas.mosaic import make_pipeline_allocations
34-
from jax._src.pallas.mosaic import BufferedRef
35-
from jax._src.pallas.mosaic import encode_kernel_regeneration_metadata
36-
from jax._src.pallas.mosaic import extract_kernel_regeneration_metadata
37-
from jax._src.pallas.mosaic import get_barrier_semaphore
38-
from jax._src.pallas.mosaic import make_async_copy
39-
from jax._src.pallas.mosaic import make_async_remote_copy
40-
from jax._src.pallas.mosaic import repeat
41-
from jax._src.pallas.mosaic import roll
42-
from jax._src.pallas.mosaic import run_scoped
43-
from jax._src.pallas.mosaic import semaphore
44-
from jax._src.pallas.mosaic import semaphore_read
45-
from jax._src.pallas.mosaic import semaphore_signal
46-
from jax._src.pallas.mosaic import semaphore_wait
15+
"""Mosaic-specific Pallas APIs."""
16+
17+
from jax._src.pallas.mosaic import core
18+
from jax._src.pallas.mosaic.core import dma_semaphore
19+
from jax._src.pallas.mosaic.core import PrefetchScalarGridSpec
20+
from jax._src.pallas.mosaic.core import semaphore
21+
from jax._src.pallas.mosaic.core import SemaphoreType
22+
from jax._src.pallas.mosaic.core import TPUMemorySpace
23+
from jax._src.pallas.mosaic.kernel_regeneration_util import encode_kernel_regeneration_metadata
24+
from jax._src.pallas.mosaic.kernel_regeneration_util import extract_kernel_regeneration_metadata
25+
from jax._src.pallas.mosaic.lowering import LoweringException
26+
from jax._src.pallas.mosaic.pipeline import BufferedRef
27+
from jax._src.pallas.mosaic.pipeline import emit_pipeline
28+
from jax._src.pallas.mosaic.pipeline import emit_pipeline_with_allocations
29+
from jax._src.pallas.mosaic.pipeline import get_pipeline_schedule
30+
from jax._src.pallas.mosaic.pipeline import make_pipeline_allocations
31+
from jax._src.pallas.mosaic.primitives import async_copy
32+
from jax._src.pallas.mosaic.primitives import async_remote_copy
33+
from jax._src.pallas.mosaic.primitives import bitcast
34+
from jax._src.pallas.mosaic.primitives import delay
35+
from jax._src.pallas.mosaic.primitives import device_id
36+
from jax._src.pallas.mosaic.primitives import DeviceIdType
37+
from jax._src.pallas.mosaic.primitives import get_barrier_semaphore
38+
from jax._src.pallas.mosaic.primitives import make_async_copy
39+
from jax._src.pallas.mosaic.primitives import make_async_remote_copy
40+
from jax._src.pallas.mosaic.primitives import repeat
41+
from jax._src.pallas.mosaic.primitives import roll
42+
from jax._src.pallas.mosaic.primitives import run_scoped
43+
from jax._src.pallas.mosaic.primitives import semaphore_read
44+
from jax._src.pallas.mosaic.primitives import semaphore_signal
45+
from jax._src.pallas.mosaic.primitives import semaphore_wait
46+
from jax._src.pallas.mosaic.primitives import prng_seed
47+
from jax._src.pallas.mosaic.primitives import prng_random_bits
4748
from jax._src.tpu_custom_call import CostEstimate
48-
from jax._src.pallas.mosaic import prng_seed
49-
from jax._src.pallas.mosaic import prng_random_bits
49+
50+
ANY = TPUMemorySpace.ANY
51+
CMEM = TPUMemorySpace.CMEM
52+
SMEM = TPUMemorySpace.SMEM
53+
VMEM = TPUMemorySpace.VMEM

0 commit comments

Comments
 (0)