Skip to content

Commit 0b6c355

Browse files
sharadmvGoogle-ML-Automation
authored andcommitted
[Pallas] Add experimental (private for now) API for manual fusion into Pallas kernels
PiperOrigin-RevId: 733112191
1 parent 2c7043f commit 0b6c355

File tree

10 files changed

+4281
-0
lines changed

10 files changed

+4281
-0
lines changed

jax/_src/pallas/fuser/BUILD

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
# Copyright 2025 The JAX Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
load(
16+
"//jaxlib:jax.bzl",
17+
"py_deps",
18+
"pytype_strict_library",
19+
)
20+
21+
package(
22+
default_applicable_licenses = [],
23+
default_visibility = [
24+
"//jax:internal",
25+
],
26+
)
27+
28+
pytype_strict_library(
29+
name = "fuser",
30+
srcs = [
31+
"__init__.py",
32+
],
33+
deps = [
34+
":block_spec",
35+
":fusable",
36+
":fusion",
37+
":jaxpr_fusion",
38+
],
39+
)
40+
41+
pytype_strict_library(
42+
name = "block_spec",
43+
srcs = [
44+
"block_spec.py",
45+
],
46+
deps = [
47+
"//jax",
48+
"//jax:ad_util",
49+
"//jax:api_util",
50+
"//jax:core",
51+
"//jax:partial_eval",
52+
"//jax:tree_util",
53+
"//jax:util",
54+
"//jax/_src/pallas",
55+
] + py_deps("numpy"),
56+
)
57+
58+
pytype_strict_library(
59+
name = "fusable",
60+
srcs = [
61+
"fusable.py",
62+
],
63+
deps = [
64+
":fusion",
65+
"//jax",
66+
"//jax:api_util",
67+
"//jax:core",
68+
"//jax:mlir",
69+
"//jax:partial_eval",
70+
"//jax:tree_util",
71+
"//jax:util",
72+
],
73+
)
74+
75+
pytype_strict_library(
76+
name = "fusion",
77+
srcs = [
78+
"fusion.py",
79+
],
80+
deps = [
81+
"//jax",
82+
"//jax:util",
83+
],
84+
)
85+
86+
pytype_strict_library(
87+
name = "jaxpr_fusion",
88+
srcs = [
89+
"jaxpr_fusion.py",
90+
],
91+
deps = [
92+
":fusable",
93+
":fusable_dtype",
94+
":fusion",
95+
"//jax",
96+
"//jax:api_util",
97+
"//jax:core",
98+
"//jax:partial_eval",
99+
"//jax:tree_util",
100+
],
101+
)
102+
103+
pytype_strict_library(
104+
name = "fusable_dtype",
105+
srcs = [
106+
"fusable_dtype.py",
107+
],
108+
deps = [
109+
":block_spec",
110+
":fusable",
111+
"//jax",
112+
"//jax:api_util",
113+
"//jax:core",
114+
"//jax:dtypes",
115+
"//jax:partial_eval",
116+
"//jax:source_info_util",
117+
"//jax:tree_util",
118+
"//jax:util",
119+
"//jax/_src/pallas",
120+
],
121+
)

jax/_src/pallas/fuser/__init__.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# Copyright 2025 The JAX Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from jax._src.pallas.fuser.block_spec import get_fusion_values as get_fusion_values
16+
from jax._src.pallas.fuser.block_spec import make_scalar_prefetch_handler as make_scalar_prefetch_handler
17+
from jax._src.pallas.fuser.block_spec import pull_block_spec as pull_block_spec
18+
from jax._src.pallas.fuser.block_spec import push_block_spec as push_block_spec
19+
from jax._src.pallas.fuser.fusable import fusable as fusable
20+
from jax._src.pallas.fuser.fusion import Fusion as Fusion
21+
from jax._src.pallas.fuser.jaxpr_fusion import fuse as fuse

0 commit comments

Comments
 (0)