Skip to content

Commit 0e611e5

Browse files
justinjfu authored and Google-ML-Automation committed
[Pallas] Add a cost estimator for Pallas/JAX functions.
Helps resolve the following issue, where invoking HLO's cost analysis triggers compilation which can OOM for larger inputs: jax-ml#24539. This cost estimator uses only abstract evaluation which should work for all input sizes. PiperOrigin-RevId: 695415760
1 parent 0995bc2 commit 0e611e5

File tree

4 files changed

+325
-0
lines changed

4 files changed

+325
-0
lines changed

jax/_src/pallas/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ py_library(
3030
srcs = [
3131
"__init__.py",
3232
"core.py",
33+
"cost_estimate.py",
3334
"pallas_call.py",
3435
"primitives.py",
3536
"utils.py",

jax/_src/pallas/cost_estimate.py

Lines changed: 215 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,215 @@
1+
# Copyright 2024 The JAX Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Helper tool for automatic cost estimation."""
15+
import dataclasses
16+
import math
17+
from typing import Any, Sequence
18+
19+
from jax._src import core as jax_core
20+
from jax._src.pallas import core as pallas_core
21+
from jax._src import linear_util as lu
22+
from jax._src.interpreters import partial_eval as pe
23+
from jax._src.util import safe_map
24+
from jax._src.util import safe_zip
25+
from jax._src.lax import lax
26+
27+
# Rebind `map`/`zip` to the imported `safe_map`/`safe_zip` helpers while keeping
# the Python builtins reachable under the `unsafe_*` names.
unsafe_map, map = map, safe_map  # pylint: disable=redefined-builtin
unsafe_zip, zip = zip, safe_zip  # pylint: disable=redefined-builtin

# Registry from a JAX primitive to the callable that estimates its cost.
# Filled in by `register_cost_rule` and read by `cost_estimate_jaxpr`.
_cost_rules: dict[jax_core.Primitive, Any] = {}
31+
32+
@dataclasses.dataclass(frozen=True)
class CostEstimate:
  """Immutable set of cost counters that can be summed with `+`."""
  # Counter fed by the non-transcendental cost rules (element-wise ops, dots).
  flops: int
  # Counter fed by the transcendental unary cost rule.
  transcendentals: int
  # Counter for bytes read/written.
  bytes_accessed: int

  def __add__(self, other: 'CostEstimate') -> 'CostEstimate':
    """Returns a new CostEstimate holding the field-wise sum of both operands."""
    summed = {
        field.name: getattr(self, field.name) + getattr(other, field.name)
        for field in dataclasses.fields(CostEstimate)
    }
    return CostEstimate(**summed)
44+
45+
def register_cost_rule(primitive: jax_core.Primitive, rule) -> None:
  """Registers `rule` as the cost rule for `primitive`.

  A rule already registered for the same primitive is overwritten.

  Args:
    primitive: The JAX primitive the rule applies to.
    rule: Callable invoked by `cost_estimate_jaxpr` as
      `rule(context, **bind_params)`; its result is added to a running
      `CostEstimate` total.
  """
  _cost_rules[primitive] = rule
47+
48+
@dataclasses.dataclass(frozen=True)
class Context:
  """Per-equation information passed to a cost rule."""
  # Abstract values of the equation's input variables, in order.
  avals_in: Sequence[Any]
  # Abstract values of the equation's output variables, in order.
  avals_out: Sequence[Any]
52+
53+
def cost_estimate_jaxpr(
    jaxpr: jax_core.ClosedJaxpr,
) -> pallas_core.CostEstimate:
  """Returns the cost estimate for the given Jaxpr.

  Iterates over the equations of the jaxpr and sums the cost reported by the
  rule registered for each equation's primitive. Equations whose primitive has
  no registered rule contribute nothing. This function does not itself descend
  into jaxprs nested inside equation params.

  Args:
    jaxpr: The closed jaxpr to analyze. Its constants are not inspected.

  Returns:
    A pallas_core.CostEstimate holding the summed flops, transcendentals and
    bytes accessed reported by the rules.
  """
  total_cost = CostEstimate(flops=0, transcendentals=0, bytes_accessed=0)

  for eqn in jaxpr.jaxpr.eqns:
    rule = _cost_rules.get(eqn.primitive)
    if rule is None:
      # Nothing to charge; skip before doing any per-equation work.
      continue
    # Only the bind params are needed; the first element is discarded.
    _, bind_params = eqn.primitive.get_bind_params(eqn.params)
    context = Context(
        avals_in=[v.aval for v in eqn.invars],
        avals_out=[v.aval for v in eqn.outvars],
    )
    total_cost = total_cost + rule(context, **bind_params)

  # Convert the module-local accumulator into the Pallas-facing type.
  return pallas_core.CostEstimate(
      flops=total_cost.flops,
      transcendentals=total_cost.transcendentals,
      bytes_accessed=total_cost.bytes_accessed,
  )
73+
74+
def cost_estimate(fun, *args) -> pallas_core.CostEstimate:
  """Computes a cost estimate for the given function.

  The function is traced to a jaxpr from the shapes and dtypes of `args`, the
  per-primitive cost rules are summed over that jaxpr, and the bytes of all
  inputs and outputs are added to `bytes_accessed`.

  Args:
    fun: The function to compute the cost estimate for. It may return a single
      array or a pytree (tuple, list, dict, ...) of arrays.
    *args: The arguments to the function. Can be jax.ShapeDtypeStruct or
      jax.Array.

  Returns:
    A pallas_core.CostEstimate object containing the cost estimate.
  """
  # Local import: this is the only place in the module that needs pytrees.
  from jax._src import tree_util  # pylint: disable=g-import-not-at-top

  def flat_fun(*fun_args, **fun_kwargs):
    # Tracing expects a flat sequence of outputs. Flattening the result keeps
    # a lone array as a one-element tuple and also supports functions that
    # return several arrays, which a plain `(result,)` wrapper would nest.
    return tuple(tree_util.tree_leaves(fun(*fun_args, **fun_kwargs)))

  wrapped_fun = lu.wrap_init(flat_fun)
  # Only shape and dtype are read from the arguments.
  avals = [jax_core.ShapedArray(a.shape, a.dtype) for a in args]
  jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, avals)
  estimate = cost_estimate_jaxpr(jax_core.ClosedJaxpr(jaxpr, consts))
  # Charge one full read of every input and one full write of every output.
  input_bytes = sum(math.prod(a.shape) * a.dtype.itemsize for a in args)
  output_bytes = sum(
      math.prod(a.aval.shape) * a.aval.dtype.itemsize for a in jaxpr.outvars)
  return pallas_core.CostEstimate(
      flops=estimate.flops,
      transcendentals=estimate.transcendentals,
      bytes_accessed=estimate.bytes_accessed + input_bytes + output_bytes,
  )
97+
98+
def binary_cost_rule(ctx: Context, **_) -> CostEstimate:
  """Cost rule for binary ops: one flop per element of the single output."""
  (result_aval,) = ctx.avals_out
  return CostEstimate(
      flops=math.prod(result_aval.shape),
      transcendentals=0,
      bytes_accessed=0,
  )


# Primitives whose cost is estimated with `binary_cost_rule`.
BINARY_OPS = [
    lax.add_p,
    lax.mul_p,
    lax.sub_p,
    lax.div_p,
    lax.min_p,
    lax.max_p,
    lax.or_p,
    lax.and_p,
    lax.xor_p,
]
for op in BINARY_OPS:
  register_cost_rule(op, binary_cost_rule)
119+
120+
121+
def unary_cost_rule(transcendental: bool):
  """Builds a cost rule for a single-input op.

  Args:
    transcendental: If True the returned rule charges one transcendental per
      input element; otherwise it charges one flop per input element.

  Returns:
    A rule with signature `(ctx, **params) -> CostEstimate`.
  """
  def cost_rule(ctx: Context, **_) -> CostEstimate:
    (operand_aval,) = ctx.avals_in
    num_elements = math.prod(operand_aval.shape)
    if transcendental:
      return CostEstimate(
          flops=0,
          transcendentals=num_elements,
          bytes_accessed=0,
      )
    return CostEstimate(
        flops=num_elements,
        transcendentals=0,
        bytes_accessed=0,
    )
  return cost_rule


# Single-input primitives charged as ordinary flops.
UN_OPS = [
    lax.neg_p,
    lax.floor_p,
    lax.ceil_p,
    lax.round_p,
    lax.not_p,
]
for op in UN_OPS:
  register_cost_rule(op, unary_cost_rule(transcendental=False))

# Single-input primitives charged as transcendentals.
TRANSCENDENTAL_OPS = [
    lax.cos_p,
    lax.sin_p,
    lax.tan_p,
    lax.sinh_p,
    lax.cosh_p,
    lax.tanh_p,
    lax.acos_p,
    lax.asin_p,
    lax.atan_p,
    lax.exp_p,
    lax.log_p,
    lax.logistic_p,
    lax.sqrt_p,
]
for op in TRANSCENDENTAL_OPS:
  register_cost_rule(op, unary_cost_rule(transcendental=True))
164+
165+
def _integer_pow_cost_rule(ctx: Context, *, y: int) -> CostEstimate:
  """Cost rule for raising every element of the input to the integer power `y`.

  Args:
    ctx: Context whose single input aval is the base array.
    y: The integer exponent. Negative exponents are supported and are charged
      as the positive power followed by one reciprocal.

  Returns:
    A CostEstimate that carries only flops.
  """
  x_aval, = ctx.avals_in
  num_elements = math.prod(x_aval.shape)
  # Normalise to a Python int so the exact integer bit helpers are available.
  exponent = int(y)
  magnitude = abs(exponent)
  if magnitude <= 1:
    # No multiplies: the result is all ones (power 0) or the input (power 1).
    cost_per_element = 0
  else:
    # We assume integer pow is implemented using repeated squaring.
    # The cost is floor(log2(|y|)) squarings, plus one multiply per non-zero
    # bit. `bit_length` gives the exact floor(log2) without going through
    # floating point, which can mis-round and rejects negative values.
    highest_bit = magnitude.bit_length() - 1
    cost_per_element = highest_bit + magnitude.bit_count()
  if exponent < 0:
    # x**(-n) == 1 / x**n: one extra division per element.
    cost_per_element += 1
  return CostEstimate(
      flops=num_elements * cost_per_element,
      transcendentals=0,
      bytes_accessed=0,
  )
register_cost_rule(lax.integer_pow_p, _integer_pow_cost_rule)
182+
183+
def dot_general_cost_rule(ctx: Context,
                          dimension_numbers: lax.DotDimensionNumbers,
                          **_) -> CostEstimate:
  """Cost rule for `dot_general`.

  The flop count is the product of every LHS dimension and every RHS dimension
  that is neither contracting nor batch (so shared contracting and batch sizes
  are counted once), times two when a contraction takes place.

  Args:
    ctx: Context whose two input avals are the LHS and RHS operands.
    dimension_numbers: `((lhs_contracting, rhs_contracting),
      (lhs_batch, rhs_batch))` dimension indices.

  Returns:
    A CostEstimate that carries only flops.
  """
  x_aval, y_aval = ctx.avals_in
  x_shape, y_shape = x_aval.shape, y_aval.shape
  (lhs_contracting_dims, rhs_contracting_dims), (
      lhs_batch_dims, rhs_batch_dims) = dimension_numbers
  assert len(lhs_contracting_dims) == len(rhs_contracting_dims)
  assert len(lhs_batch_dims) == len(rhs_batch_dims)
  flops = 1
  for lhs_dim, rhs_dim in zip(lhs_contracting_dims, rhs_contracting_dims):
    assert x_shape[lhs_dim] == y_shape[rhs_dim]
    flops *= x_shape[lhs_dim]
  if len(lhs_contracting_dims) > 0:
    # Every contracted element costs one multiply and one add. The factor of
    # two belongs to the multiply-add, so it is applied once no matter how
    # many contracting dimensions there are.
    flops *= 2
  # Now we handle all other dimensions.
  flops *= math.prod(
      dim for i, dim in enumerate(x_shape) if i not in lhs_contracting_dims)
  # Don't double-count batch dims (we already counted them for the LHS).
  flops *= math.prod(
      dim for i, dim in enumerate(y_shape)
      if i not in rhs_contracting_dims and i not in rhs_batch_dims)
  return CostEstimate(
      flops=flops,
      transcendentals=0,
      bytes_accessed=0,
  )
register_cost_rule(lax.dot_general_p, dot_general_cost_rule)

tests/pallas/BUILD

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,20 @@ jax_multiplatform_test(
5656
] + py_deps("absl/testing") + py_deps("numpy"),
5757
)
5858

59+
# Test target for the Pallas cost estimator (pallas_cost_estimate_test.py).
jax_multiplatform_test(
    name = "pallas_cost_estimate_test",
    srcs = [
        "pallas_cost_estimate_test.py",
    ],
    deps = [
        "//jax:pallas",
        "//jax:pallas_gpu",
        "//jax:pallas_gpu_ops",
        "//jax:pallas_tpu",
        "//jax:pallas_tpu_ops",
    ] + py_deps("absl/testing") + py_deps("numpy"),
)
72+
5973
jax_multiplatform_test(
6074
name = "pallas_jumble_test",
6175
srcs = [
tests/pallas/pallas_cost_estimate_test.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
# Copyright 2024 The JAX Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from absl.testing import absltest
15+
from absl.testing import parameterized
16+
import jax
17+
from jax import lax
18+
from jax import numpy as jnp
19+
from jax._src import config
20+
from jax._src import test_util as jtu
21+
from jax._src.pallas import cost_estimate
22+
23+
24+
config.parse_flags_with_absl()


class PallasCostEstimateTest(jtu.JaxTestCase):
  """Checks the counters returned by `cost_estimate.cost_estimate`."""

  def test_exp_add(self):
    """exp(x + y) on ten elements: ten flops and ten transcendentals."""
    lhs = jnp.ones(10, dtype=jnp.float32)
    rhs = jnp.ones(10, dtype=jnp.float32)
    cost = cost_estimate.cost_estimate(lambda x, y: jnp.exp(x + y), lhs, rhs)
    self.assertEqual(cost.flops, 10)
    self.assertEqual(cost.transcendentals, 10)
    # Two inputs and one output of ten float32 values each.
    self.assertEqual(cost.bytes_accessed, 4 * 30)

  def test_very_large_matmul(self):
    """Operands are given only as shapes/dtypes via ShapeDtypeStruct."""
    m, k, n = 400_000, 800_000, 900_000
    lhs = jax.ShapeDtypeStruct((m, k), jnp.bfloat16)
    rhs = jax.ShapeDtypeStruct((k, n), jnp.bfloat16)
    cost = cost_estimate.cost_estimate(lambda a, b: a @ b, lhs, rhs)
    self.assertEqual(cost.flops, 2*m*k*n)
    self.assertEqual(cost.transcendentals, 0)
    self.assertEqual(cost.bytes_accessed, 2*(m*k + n*k + m*n))

  def test_batched_matmul(self):
    """A leading batch dimension is counted once in the flop total."""
    b, m, k, n = 7, 37, 91, 23
    lhs = jax.ShapeDtypeStruct((b, m, k), jnp.float32)
    rhs = jax.ShapeDtypeStruct((b, k, n), jnp.float32)
    cost = cost_estimate.cost_estimate(
        lambda x, y: jnp.matmul(x, y), lhs, rhs)
    self.assertEqual(cost.flops, 2*b*m*k*n)
    self.assertEqual(cost.transcendentals, 0)
    self.assertEqual(cost.bytes_accessed, 4*(b*m*k + b*n*k + b*m*n))

  def test_attention(self):
    """softmax(q @ k.T) @ v: two matmuls plus the softmax in between."""
    qk_dim = 16
    v_dim = 4
    kv_len = 128
    q_len = 64

    def attention(q, k, v):
      return jax.nn.softmax(q @ k.T, axis=-1) @ v

    q = jnp.zeros((q_len, qk_dim), dtype=jnp.float32)
    k = jnp.zeros((kv_len, qk_dim), dtype=jnp.float32)
    v = jnp.zeros((kv_len, v_dim), dtype=jnp.float32)
    cost = cost_estimate.cost_estimate(attention, q, k, v)

    qk_cost = 2 * q_len * kv_len * qk_dim
    v_cost = 2 * q_len * kv_len * v_dim
    softmax_flops = kv_len * q_len
    self.assertEqual(cost.flops, qk_cost + v_cost + 2 * softmax_flops + q_len)
    self.assertEqual(cost.transcendentals, softmax_flops)
    input_bytes = q_len * qk_dim + kv_len * qk_dim + kv_len * v_dim
    output_bytes = q_len * v_dim
    self.assertEqual(cost.bytes_accessed, 4 * (input_bytes + output_bytes))

  @parameterized.parameters(
      (1, 0), (7, 5), (8, 4), (9, 5)
  )
  def test_integer_pow(self, power, expected_flops_per_element):
    """Per-element flops of integer_pow for a few exponents."""
    operand = jnp.ones(10, dtype=jnp.float32)
    cost = cost_estimate.cost_estimate(
        lambda x: lax.integer_pow(x, power), operand)
    self.assertEqual(cost.flops, 10 * expected_flops_per_element)
    self.assertEqual(cost.transcendentals, 0)
    # Ten float32 values in, ten out.
    self.assertEqual(cost.bytes_accessed, 80)


if __name__ == "__main__":
  absltest.main()

0 commit comments

Comments
 (0)