File tree Expand file tree Collapse file tree 10 files changed +4281
-0
lines changed Expand file tree Collapse file tree 10 files changed +4281
-0
lines changed Original file line number Diff line number Diff line change 1+ # Copyright 2025 The JAX Authors.
2+ #
3+ # Licensed under the Apache License, Version 2.0 (the "License");
4+ # you may not use this file except in compliance with the License.
5+ # You may obtain a copy of the License at
6+ #
7+ # https://www.apache.org/licenses/LICENSE-2.0
8+ #
9+ # Unless required by applicable law or agreed to in writing, software
10+ # distributed under the License is distributed on an "AS IS" BASIS,
11+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+ # See the License for the specific language governing permissions and
13+ # limitations under the License.
14+
15+ load (
16+ "//jaxlib:jax.bzl" ,
17+ "py_deps" ,
18+ "pytype_strict_library" ,
19+ )
20+
21+ package (
22+ default_applicable_licenses = [],
23+ default_visibility = [
24+ "//jax:internal" ,
25+ ],
26+ )
27+
28+ pytype_strict_library (
29+ name = "fuser" ,
30+ srcs = [
31+ "__init__.py" ,
32+ ],
33+ deps = [
34+ ":block_spec" ,
35+ ":fusable" ,
36+ ":fusion" ,
37+ ":jaxpr_fusion" ,
38+ ],
39+ )
40+
41+ pytype_strict_library (
42+ name = "block_spec" ,
43+ srcs = [
44+ "block_spec.py" ,
45+ ],
46+ deps = [
47+ "//jax" ,
48+ "//jax:ad_util" ,
49+ "//jax:api_util" ,
50+ "//jax:core" ,
51+ "//jax:partial_eval" ,
52+ "//jax:tree_util" ,
53+ "//jax:util" ,
54+ "//jax/_src/pallas" ,
55+ ] + py_deps ("numpy" ),
56+ )
57+
58+ pytype_strict_library (
59+ name = "fusable" ,
60+ srcs = [
61+ "fusable.py" ,
62+ ],
63+ deps = [
64+ ":fusion" ,
65+ "//jax" ,
66+ "//jax:api_util" ,
67+ "//jax:core" ,
68+ "//jax:mlir" ,
69+ "//jax:partial_eval" ,
70+ "//jax:tree_util" ,
71+ "//jax:util" ,
72+ ],
73+ )
74+
75+ pytype_strict_library (
76+ name = "fusion" ,
77+ srcs = [
78+ "fusion.py" ,
79+ ],
80+ deps = [
81+ "//jax" ,
82+ "//jax:util" ,
83+ ],
84+ )
85+
86+ pytype_strict_library (
87+ name = "jaxpr_fusion" ,
88+ srcs = [
89+ "jaxpr_fusion.py" ,
90+ ],
91+ deps = [
92+ ":fusable" ,
93+ ":fusable_dtype" ,
94+ ":fusion" ,
95+ "//jax" ,
96+ "//jax:api_util" ,
97+ "//jax:core" ,
98+ "//jax:partial_eval" ,
99+ "//jax:tree_util" ,
100+ ],
101+ )
102+
103+ pytype_strict_library (
104+ name = "fusable_dtype" ,
105+ srcs = [
106+ "fusable_dtype.py" ,
107+ ],
108+ deps = [
109+ ":block_spec" ,
110+ ":fusable" ,
111+ "//jax" ,
112+ "//jax:api_util" ,
113+ "//jax:core" ,
114+ "//jax:dtypes" ,
115+ "//jax:partial_eval" ,
116+ "//jax:source_info_util" ,
117+ "//jax:tree_util" ,
118+ "//jax:util" ,
119+ "//jax/_src/pallas" ,
120+ ],
121+ )
Original file line number Diff line number Diff line change 1+ # Copyright 2025 The JAX Authors.
2+ #
3+ # Licensed under the Apache License, Version 2.0 (the "License");
4+ # you may not use this file except in compliance with the License.
5+ # You may obtain a copy of the License at
6+ #
7+ # https://www.apache.org/licenses/LICENSE-2.0
8+ #
9+ # Unless required by applicable law or agreed to in writing, software
10+ # distributed under the License is distributed on an "AS IS" BASIS,
11+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+ # See the License for the specific language governing permissions and
13+ # limitations under the License.
14+
15+ from jax ._src .pallas .fuser .block_spec import get_fusion_values as get_fusion_values
16+ from jax ._src .pallas .fuser .block_spec import make_scalar_prefetch_handler as make_scalar_prefetch_handler
17+ from jax ._src .pallas .fuser .block_spec import pull_block_spec as pull_block_spec
18+ from jax ._src .pallas .fuser .block_spec import push_block_spec as push_block_spec
19+ from jax ._src .pallas .fuser .fusable import fusable as fusable
20+ from jax ._src .pallas .fuser .fusion import Fusion as Fusion
21+ from jax ._src .pallas .fuser .jaxpr_fusion import fuse as fuse
You can’t perform that action at this time.
0 commit comments