|
13 | 13 | # limitations under the License. |
14 | 14 | """Helper tool for automatic cost estimation.""" |
15 | 15 | import dataclasses |
| 16 | +import functools |
16 | 17 | import math |
17 | 18 | from typing import Any, Sequence |
18 | 19 |
|
19 | 20 | import jax |
| 21 | +from jax._src import api_util |
20 | 22 | from jax._src import core as jax_core |
21 | 23 | from jax._src import custom_derivatives |
22 | 24 | from jax._src import linear_util as lu |
23 | 25 | from jax._src import pjit |
| 26 | +from jax._src.state import discharge |
24 | 27 | from jax._src.pallas import core as pallas_core |
25 | 28 | from jax._src.interpreters import partial_eval as pe |
26 | 29 | from jax._src.util import safe_map |
@@ -87,10 +90,9 @@ def estimate_cost(fun, *args, **kwargs) -> pallas_core.CostEstimate: |
87 | 90 | A pallas_core.CostEstimate object containing the cost estimate. |
88 | 91 | """ |
89 | 92 | flattened_args, treedef = jax.tree.flatten(args) |
90 | | - def _partial_fun(*flat_args): |
91 | | - return fun(*jax.tree.unflatten(treedef, flat_args), **kwargs) |
92 | | - wrapped_fun = lu.wrap_init( |
93 | | - lambda *args, **kwargs: (_partial_fun(*args, **kwargs),)) |
| 93 | + partial_fun = functools.partial(fun, **kwargs) |
| 94 | + wrapped_fun, _ = api_util.flatten_fun_nokwargs(lu.wrap_init(partial_fun), |
| 95 | + treedef) |
94 | 96 | avals = [jax_core.ShapedArray(a.shape, a.dtype) for a in flattened_args] |
95 | 97 | jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, avals) |
96 | 98 | estimate = cost_estimate_jaxpr(jax_core.ClosedJaxpr(jaxpr, consts)) |
@@ -243,3 +245,12 @@ def _custom_vjp_rule(ctx, *, fun_jaxpr: jax_core.ClosedJaxpr, **_): |
243 | 245 | bytes_accessed=inner_cost.bytes_accessed, |
244 | 246 | ) |
245 | 247 | register_cost_rule(custom_derivatives.custom_vjp_call_jaxpr_p, _custom_vjp_rule) |
| 248 | + |
def _run_state_rule(*_, jaxpr: jax_core.Jaxpr, **_2):
  """Cost rule for `run_state_p`: the cost of the jaxpr it wraps.

  Any positional arguments, and any keyword parameters other than `jaxpr`,
  are accepted and ignored.

  Args:
    jaxpr: The `jax_core.Jaxpr` to cost. It is closed with `pe.close_jaxpr`
      before being handed to `cost_estimate_jaxpr`.

  Returns:
    A `CostEstimate` whose `flops`, `transcendentals` and `bytes_accessed`
    are copied from the estimate computed for `jaxpr`.
  """
  closed = pe.close_jaxpr(jaxpr)
  body_cost = cost_estimate_jaxpr(closed)
  # Re-pack field by field so the result is always this module's CostEstimate.
  copied = {
      name: getattr(body_cost, name)
      for name in ("flops", "transcendentals", "bytes_accessed")
  }
  return CostEstimate(**copied)
register_cost_rule(discharge.run_state_p, _run_state_rule)
0 commit comments