Skip to content

Commit c6b164d

Browse files
sharadmvGoogle-ML-Automation
authored andcommitted
[Pallas/Fuser] Add custom evaluate to allow/disallow transposes
PiperOrigin-RevId: 735931978
1 parent f45cbf3 commit c6b164d

File tree

7 files changed

+152
-17
lines changed

7 files changed

+152
-17
lines changed

jax/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -690,6 +690,7 @@ pytype_strict_library(
690690
deps = [
691691
":pallas", # build_cleaner: keep
692692
"//jax/_src/pallas/fuser:block_spec",
693+
"//jax/_src/pallas/fuser:custom_evaluate",
693694
"//jax/_src/pallas/fuser:fusable",
694695
"//jax/_src/pallas/fuser:fusion",
695696
"//jax/_src/pallas/fuser:jaxpr_fusion",

jax/_src/pallas/fuser/BUILD

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ pytype_strict_library(
3232
],
3333
deps = [
3434
":block_spec",
35+
":custom_evaluate",
3536
":fusable",
3637
":fusion",
3738
":jaxpr_fusion",
@@ -44,6 +45,7 @@ pytype_strict_library(
4445
"block_spec.py",
4546
],
4647
deps = [
48+
":fuser_utils",
4749
"//jax",
4850
"//jax:ad_util",
4951
"//jax:api_util",
@@ -119,3 +121,27 @@ pytype_strict_library(
119121
"//jax/_src/pallas",
120122
],
121123
)
124+
125+
pytype_strict_library(
126+
name = "custom_evaluate",
127+
srcs = ["custom_evaluate.py"],
128+
deps = [
129+
":fuser_utils",
130+
"//jax",
131+
"//jax:core",
132+
"//jax:source_info_util",
133+
"//jax:tree_util",
134+
"//jax:util",
135+
],
136+
)
137+
138+
pytype_strict_library(
139+
name = "fuser_utils",
140+
srcs = ["fuser_utils.py"],
141+
deps = [
142+
"//jax:api_util",
143+
"//jax:core",
144+
"//jax:partial_eval",
145+
"//jax:tree_util",
146+
],
147+
)

jax/_src/pallas/fuser/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from jax._src.pallas.fuser.block_spec import make_scalar_prefetch_handler as make_scalar_prefetch_handler
1717
from jax._src.pallas.fuser.block_spec import pull_block_spec as pull_block_spec
1818
from jax._src.pallas.fuser.block_spec import push_block_spec as push_block_spec
19+
from jax._src.pallas.fuser.custom_evaluate import evaluate as evaluate
1920
from jax._src.pallas.fuser.fusable import fusable as fusable
2021
from jax._src.pallas.fuser.fusion import Fusion as Fusion
2122
from jax._src.pallas.fuser.jaxpr_fusion import fuse as fuse

jax/_src/pallas/fuser/block_spec.py

Lines changed: 8 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -26,15 +26,14 @@
2626
import jax
2727
from jax import lax
2828
from jax._src import ad_util
29-
from jax._src import api_util
3029
from jax._src import core
3130
from jax._src import custom_derivatives
32-
from jax._src import linear_util as lu
3331
from jax._src import pjit
3432
from jax._src import tree_util
3533
from jax._src import util
3634
from jax._src.interpreters import partial_eval as pe
3735
from jax._src.pallas import core as pallas_core
36+
from jax._src.pallas.fuser import fuser_utils
3837
import jax.numpy as jnp
3938
import numpy as np
4039

@@ -226,18 +225,6 @@ def new_index_map(*args):
226225
return out_block_spec
227226

228227

229-
def _make_jaxpr(f, *args, **kwargs):
230-
flat_args, in_tree = tree_util.tree_flatten((args, kwargs))
231-
flat_avals = [core.get_aval(x) for x in flat_args]
232-
debug_info = api_util.debug_info('make_jaxpr', f, args, kwargs)
233-
flat_fun, out_tree_thunk = api_util.flatten_fun(
234-
lu.wrap_init(f, debug_info=debug_info), in_tree
235-
)
236-
jaxpr, _, consts, _ = pe.trace_to_jaxpr_dynamic(flat_fun, flat_avals)
237-
out_tree = out_tree_thunk()
238-
return jaxpr, consts, in_tree, out_tree
239-
240-
241228
def pull_block_spec(
242229
f: Callable,
243230
out_block_specs: pallas_core.BlockSpec | tuple[pallas_core.BlockSpec, ...],
@@ -246,7 +233,9 @@ def pull_block_spec(
246233
grid: tuple[int | jax.Array, ...] | None = None,
247234
):
248235
def wrapped(*args, **kwargs):
249-
jaxpr, consts, in_tree, out_tree_ = _make_jaxpr(f, *args, **kwargs)
236+
jaxpr, consts, in_tree, out_tree_ = fuser_utils.make_jaxpr(
237+
f, *args, **kwargs
238+
)
250239
# TODO(sharadmv): handle these consts better, they should correspond to
251240
# scalar prefetch.
252241
del consts, out_tree_
@@ -563,7 +552,9 @@ def write_env(var, val):
563552
def get_fusion_values(
564553
fusion: Callable, *args, **kwargs
565554
) -> tuple[Callable, tuple[jax.Array, ...], tuple[jax.Array, ...]]:
566-
jaxpr, values, in_tree, out_tree = _make_jaxpr(fusion, *args, **kwargs)
555+
jaxpr, values, in_tree, out_tree = fuser_utils.make_jaxpr(
556+
fusion, *args, **kwargs
557+
)
567558
assert len(values) == len(jaxpr.constvars), (jaxpr, values)
568559
out_usages = tuple({Usage.REGULAR} for _ in jaxpr.outvars)
569560
read_usage_env = compute_usage(jaxpr, out_usages)
@@ -1325,7 +1316,7 @@ def wrapper(*args, **kwargs):
13251316
flat_block_specs, in_tree_ = tree_util.tree_flatten(
13261317
(in_spec_args, in_spec_kwargs)
13271318
)
1328-
jaxpr, _, in_tree, out_tree = _make_jaxpr(f, *args, **kwargs)
1319+
jaxpr, _, in_tree, out_tree = fuser_utils.make_jaxpr(f, *args, **kwargs)
13291320
if in_tree != in_tree_:
13301321
raise ValueError(f'Expected {in_tree} PyTree, got {in_tree_}')
13311322
out_bs = _push_block_spec_jaxpr(jaxpr, *flat_block_specs)
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
# Copyright 2025 The JAX Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Helpers for evaluating functions under certain constraints."""
16+
import dataclasses
17+
from typing import Any
18+
19+
from jax import lax
20+
from jax._src import core
21+
from jax._src import source_info_util
22+
from jax._src import tree_util
23+
from jax._src import util
24+
from jax._src.pallas.fuser import fuser_utils
25+
26+
27+
@dataclasses.dataclass
28+
class CustomEvaluateSettings:
29+
allow_transpose: bool = True
30+
31+
32+
def evaluate(f, *, allow_transpose: bool = True):
33+
def wrapped(*args, **kwargs):
34+
jaxpr, consts, _, out_tree = fuser_utils.make_jaxpr(f, *args, **kwargs)
35+
settings = CustomEvaluateSettings(allow_transpose=allow_transpose)
36+
flat_args = tree_util.tree_leaves(args)
37+
out_flat = _custom_evaluate_jaxpr(settings, jaxpr, consts, *flat_args)
38+
return tree_util.tree_unflatten(out_tree, out_flat)
39+
40+
return wrapped
41+
42+
43+
# Disallow most higher-order primitives for now.
44+
disallowed_primitives = {lax.scan_p, lax.while_p, lax.cond_p}
45+
46+
47+
def _custom_evaluate_jaxpr(
48+
settings: CustomEvaluateSettings, jaxpr: core.Jaxpr, consts, *args
49+
):
50+
def read(v: core.Atom) -> Any:
51+
return v.val if isinstance(v, core.Literal) else env[v]
52+
53+
def write(v: core.Var, val: Any) -> None:
54+
env[v] = val
55+
56+
env: dict[core.Var, Any] = {}
57+
util.safe_map(write, jaxpr.constvars, consts)
58+
util.safe_map(write, jaxpr.invars, args)
59+
lu = core.last_used(jaxpr)
60+
for eqn in jaxpr.eqns:
61+
subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params)
62+
63+
if eqn.primitive in disallowed_primitives:
64+
raise NotImplementedError(f'Primitive {eqn.primitive} not supported.')
65+
if not settings.allow_transpose and eqn.primitive is lax.transpose_p:
66+
raise ValueError('Transpose not allowed.')
67+
name_stack = (
68+
source_info_util.current_name_stack() + eqn.source_info.name_stack
69+
)
70+
traceback = eqn.source_info.traceback
71+
with source_info_util.user_context(
72+
traceback, name_stack=name_stack
73+
), eqn.ctx.manager:
74+
ans = eqn.primitive.bind(
75+
*subfuns, *util.safe_map(read, eqn.invars), **bind_params
76+
)
77+
if eqn.primitive.multiple_results:
78+
util.safe_map(write, eqn.outvars, ans)
79+
else:
80+
write(eqn.outvars[0], ans)
81+
core.clean_up_dead_vars(eqn, env, lu)
82+
return util.safe_map(read, jaxpr.outvars)
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# Copyright 2025 The JAX Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Basic utils for fuser internals."""
16+
from jax._src import api_util
17+
from jax._src import core
18+
from jax._src import linear_util as lu
19+
from jax._src import tree_util
20+
from jax._src.interpreters import partial_eval as pe
21+
22+
23+
24+
def make_jaxpr(f, *args, **kwargs):
25+
flat_args, in_tree = tree_util.tree_flatten((args, kwargs))
26+
flat_avals = [core.get_aval(x) for x in flat_args]
27+
debug_info = api_util.debug_info('make_jaxpr', f, args, kwargs)
28+
flat_fun, out_tree_thunk = api_util.flatten_fun(
29+
lu.wrap_init(f, debug_info=debug_info), in_tree
30+
)
31+
jaxpr, _, consts, _ = pe.trace_to_jaxpr_dynamic(flat_fun, flat_avals)
32+
out_tree = out_tree_thunk()
33+
return jaxpr, consts, in_tree, out_tree

jax/experimental/pallas/fuser.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from jax._src.pallas.fuser.block_spec import make_scalar_prefetch_handler as make_scalar_prefetch_handler
1919
from jax._src.pallas.fuser.block_spec import pull_block_spec as pull_block_spec
2020
from jax._src.pallas.fuser.block_spec import push_block_spec as push_block_spec
21+
from jax._src.pallas.fuser.custom_evaluate import evaluate as evaluate
2122
from jax._src.pallas.fuser.fusable import fusable as fusable
2223
from jax._src.pallas.fuser.fusion import Fusion as Fusion
2324
from jax._src.pallas.fuser.jaxpr_fusion import fuse as fuse

0 commit comments

Comments
 (0)