Skip to content

Commit 641a1d5

Browse files
justinjfuGoogle-ML-Automation
authored andcommitted
[Pallas] Add support for run_state to cost estimator.
PiperOrigin-RevId: 703543961
1 parent 72df8e0 commit 641a1d5

File tree

2 files changed

+34
-4
lines changed

2 files changed

+34
-4
lines changed

jax/_src/pallas/cost_estimate.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,17 @@
1313
# limitations under the License.
1414
"""Helper tool for automatic cost estimation."""
1515
import dataclasses
16+
import functools
1617
import math
1718
from typing import Any, Sequence
1819

1920
import jax
21+
from jax._src import api_util
2022
from jax._src import core as jax_core
2123
from jax._src import custom_derivatives
2224
from jax._src import linear_util as lu
2325
from jax._src import pjit
26+
from jax._src.state import discharge
2427
from jax._src.pallas import core as pallas_core
2528
from jax._src.interpreters import partial_eval as pe
2629
from jax._src.util import safe_map
@@ -87,10 +90,9 @@ def estimate_cost(fun, *args, **kwargs) -> pallas_core.CostEstimate:
8790
A pallas_core.CostEstimate object containing the cost estimate.
8891
"""
8992
flattened_args, treedef = jax.tree.flatten(args)
90-
def _partial_fun(*flat_args):
91-
return fun(*jax.tree.unflatten(treedef, flat_args), **kwargs)
92-
wrapped_fun = lu.wrap_init(
93-
lambda *args, **kwargs: (_partial_fun(*args, **kwargs),))
93+
partial_fun = functools.partial(fun, **kwargs)
94+
wrapped_fun, _ = api_util.flatten_fun_nokwargs(lu.wrap_init(partial_fun),
95+
treedef)
9496
avals = [jax_core.ShapedArray(a.shape, a.dtype) for a in flattened_args]
9597
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, avals)
9698
estimate = cost_estimate_jaxpr(jax_core.ClosedJaxpr(jaxpr, consts))
@@ -243,3 +245,12 @@ def _custom_vjp_rule(ctx, *, fun_jaxpr: jax_core.ClosedJaxpr, **_):
243245
bytes_accessed=inner_cost.bytes_accessed,
244246
)
245247
register_cost_rule(custom_derivatives.custom_vjp_call_jaxpr_p, _custom_vjp_rule)
248+
249+
def _run_state_rule(*_, jaxpr: jax_core.Jaxpr, **_2):
250+
inner_cost = cost_estimate_jaxpr(pe.close_jaxpr(jaxpr))
251+
return CostEstimate(
252+
flops=inner_cost.flops,
253+
transcendentals=inner_cost.transcendentals,
254+
bytes_accessed=inner_cost.bytes_accessed,
255+
)
256+
register_cost_rule(discharge.run_state_p, _run_state_rule)

tests/pallas/pallas_cost_estimate_test.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from jax._src import config
2020
from jax._src import test_util as jtu
2121
from jax._src.pallas import cost_estimate
22+
from jax._src.state import discharge
2223

2324

2425
config.parse_flags_with_absl()
@@ -91,5 +92,23 @@ def test_integer_pow(self, power, expected_flops_per_element):
9192
self.assertEqual(cost.transcendentals, 0)
9293
self.assertEqual(cost.bytes_accessed, 80)
9394

95+
def test_run_state(self):
96+
def add_refs(refs):
97+
x_ref, y_ref, z_ref = refs
98+
x = x_ref[:]
99+
y = y_ref[:]
100+
z = x + y
101+
z_ref[:] = z
102+
input_shape = jax.ShapeDtypeStruct((100,), jnp.float32)
103+
cost = cost_estimate.estimate_cost(
104+
discharge.run_state(add_refs),
105+
(input_shape, input_shape, input_shape))
106+
self.assertEqual(cost.flops, 100)
107+
self.assertEqual(cost.transcendentals, 0)
108+
# TODO(justinfu): This is off by a factor of 2 because run_state
109+
# has all inputs/outputs as both arguments and return values.
110+
self.assertEqual(cost.bytes_accessed / 2, 3 * 4 * 100)
111+
112+
94113
if __name__ == "__main__":
95114
absltest.main()

0 commit comments

Comments
 (0)