Skip to content

Commit bdee4c3

Browse files
Merge pull request #25153 from epiqueras:feature/typechecker
PiperOrigin-RevId: 700893189
2 parents b62ca8b + 8c52154 commit bdee4c3

File tree

6 files changed

+1074
-0
lines changed

6 files changed

+1074
-0
lines changed

jax/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,7 @@ py_library_providing_imports_info(
227227
"_src/state/**/*.py",
228228
"_src/third_party/**/*.py",
229229
"experimental/key_reuse/**/*.py",
230+
"experimental/roofline/**/*.py",
230231
"image/**/*.py",
231232
"interpreters/**/*.py",
232233
"lax/**/*.py",
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# Copyright 2024 The JAX Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from jax.experimental.roofline.roofline import (
15+
RooflineRuleContext as RooflineRuleContext,
16+
)
17+
from jax.experimental.roofline.roofline import RooflineShape as RooflineShape
18+
from jax.experimental.roofline.roofline import RooflineResult as RooflineResult
19+
from jax.experimental.roofline.roofline import roofline as roofline
20+
from jax.experimental.roofline.roofline import register_roofline as register_roofline
21+
from jax.experimental.roofline.roofline import (
22+
register_standard_roofline as register_standard_roofline,
23+
)
24+
from jax.experimental.roofline.roofline import roofline_and_grad as roofline_and_grad
25+
26+
27+
import jax.experimental.roofline.rooflines as rooflines
28+
29+
del rooflines
Lines changed: 342 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,342 @@
1+
# Copyright 2024 The JAX Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from __future__ import annotations
15+
16+
from dataclasses import dataclass, field
17+
from typing import Any, Callable, Protocol, Sequence
18+
import numpy as np
19+
20+
import jax.numpy as jnp
21+
from jax.sharding import NamedSharding
22+
from jax._src import api
23+
from jax._src import core
24+
from jax._src import source_info_util
25+
from jax._src import traceback_util
26+
from jax._src import util
27+
from jax._src.api import make_jaxpr
28+
from jax._src.interpreters.partial_eval import dce_jaxpr
29+
from jax._src.interpreters.xla import abstractify
30+
from jax._src.mesh import AbstractMesh, Mesh
31+
from jax._src.tree_util import broadcast_prefix, tree_flatten, tree_unflatten, tree_map
32+
from jax.experimental import shard_map
33+
34+
35+
ShapeDtypeStructTree = Any
36+
37+
38+
map = util.safe_map
39+
40+
41+
@dataclass(frozen=True, slots=True, kw_only=True)
42+
class RooflineRuleContext:
43+
name_stack: source_info_util.NameStack
44+
primitive: core.Primitive
45+
avals_in: Sequence[core.AbstractValue]
46+
avals_out: Sequence[core.AbstractValue]
47+
jaxpr_eqn_ctx: core.JaxprEqnContext
48+
mesh: Mesh | AbstractMesh
49+
pin_lhs_in_vmem: bool
50+
pin_rhs_in_vmem: bool
51+
52+
53+
@dataclass(frozen=True, slots=True, kw_only=True)
54+
class RooflineShape:
55+
shape: tuple[int, ...]
56+
dtype: np.dtype
57+
58+
@classmethod
59+
def from_aval(cls, aval: core.AbstractValue) -> "RooflineShape":
60+
if not isinstance(aval, core.ShapedArray):
61+
raise TypeError(f"Expected ShapedArray, got {type(aval)}.")
62+
if not isinstance(aval.dtype, np.dtype):
63+
raise TypeError(f"Expected numpy dtype, got {type(aval.dtype)}.")
64+
return cls(shape=aval.shape, dtype=aval.dtype)
65+
66+
@property
67+
def size(self) -> int:
68+
return int(np.prod(self.shape))
69+
70+
@property
71+
def bytes(self) -> int:
72+
return int(self.size * self.dtype.itemsize)
73+
74+
@classmethod
75+
def total_bytes(cls, avals: Sequence[core.AbstractValue]) -> int:
76+
return sum(cls.from_aval(aval).bytes for aval in avals)
77+
78+
79+
@dataclass(frozen=True, slots=True, kw_only=True)
80+
class RooflineResult:
81+
flops: int = 0
82+
ici_bytes: dict[str, int] = field(default_factory=dict)
83+
ici_latency: dict[str, int] = field(default_factory=dict)
84+
hbm_bytes: int = 0
85+
peak_hbm_bytes: int = 0
86+
87+
@classmethod
88+
def zeros(cls) -> "RooflineResult":
89+
return cls()
90+
91+
def __add__(self, other: "RooflineResult") -> "RooflineResult":
92+
def merge_ici_dicts(d1: dict[str, int], d2: dict[str, int]) -> dict[str, int]:
93+
return {k: d1.get(k, 0) + d2.get(k, 0) for k in set(d1) | set(d2)}
94+
95+
return RooflineResult(
96+
flops=self.flops + other.flops,
97+
ici_bytes=merge_ici_dicts(self.ici_bytes, other.ici_bytes),
98+
ici_latency=merge_ici_dicts(self.ici_latency, other.ici_latency),
99+
hbm_bytes=self.hbm_bytes + other.hbm_bytes,
100+
peak_hbm_bytes=max(self.peak_hbm_bytes, other.peak_hbm_bytes),
101+
)
102+
103+
def __mul__(self, constant: int | float) -> "RooflineResult":
104+
return RooflineResult(
105+
flops=int(self.flops * constant),
106+
ici_bytes={k: int(v * constant) for k, v in self.ici_bytes.items()},
107+
ici_latency={k: int(v * constant) for k, v in self.ici_latency.items()},
108+
hbm_bytes=int(self.hbm_bytes * constant),
109+
peak_hbm_bytes=int(self.peak_hbm_bytes * constant),
110+
)
111+
112+
def __rmul__(self, constant: int | float) -> "RooflineResult":
113+
return self.__mul__(constant)
114+
115+
116+
class _RooflineRule(Protocol):
117+
def __call__(
118+
self, ctx: RooflineRuleContext, *args: RooflineShape, **kw
119+
) -> RooflineResult: ...
120+
121+
122+
_rooflines: dict[core.Primitive, _RooflineRule] = {}
123+
124+
125+
def _roofline_interpreter(
126+
f_name: str,
127+
jaxpr: core.Jaxpr,
128+
mesh: Mesh | AbstractMesh,
129+
*,
130+
pin_lhs_in_vmem: bool = False,
131+
pin_rhs_in_vmem: bool = False,
132+
) -> RooflineResult:
133+
name_stack = source_info_util.new_name_stack(util.wrap_name(f_name, "roofline"))
134+
135+
result = RooflineResult.zeros()
136+
137+
env: dict[core.Var, RooflineShape] = {}
138+
139+
def write(v: core.Var, node: RooflineShape):
140+
assert node is not None
141+
env[v] = node
142+
143+
def read(v: core.Atom) -> RooflineShape:
144+
if type(v) is core.Literal:
145+
return RooflineShape.from_aval(abstractify(v.val))
146+
else:
147+
assert isinstance(v, core.Var)
148+
return env[v]
149+
150+
def aval(v: core.Atom) -> core.AbstractValue:
151+
if type(v) is core.Literal:
152+
return abstractify(v.val)
153+
else:
154+
return v.aval
155+
156+
def calculate_peak_hbm_bytes() -> int:
157+
return int(
158+
sum(np.prod(shape.shape) * shape.dtype.itemsize for shape in env.values())
159+
)
160+
161+
make_roofline_shape = lambda x: RooflineShape.from_aval(aval(x))
162+
map(
163+
write,
164+
jaxpr.constvars,
165+
map(make_roofline_shape, jaxpr.constvars),
166+
)
167+
map(write, jaxpr.invars, map(make_roofline_shape, jaxpr.invars))
168+
last_used = core.last_used(jaxpr)
169+
for eqn in jaxpr.eqns:
170+
source_info = eqn.source_info.replace(
171+
name_stack=name_stack + eqn.source_info.name_stack
172+
)
173+
with source_info_util.user_context(
174+
eqn.source_info.traceback, name_stack=source_info.name_stack
175+
):
176+
if "jaxpr" in eqn.params:
177+
result += _roofline_interpreter(
178+
util.wrap_name(f_name, eqn.primitive.name),
179+
eqn.params["jaxpr"],
180+
mesh,
181+
pin_lhs_in_vmem=pin_lhs_in_vmem,
182+
pin_rhs_in_vmem=pin_rhs_in_vmem,
183+
)
184+
else:
185+
if eqn.primitive not in _rooflines:
186+
msg = f"No roofline rule for {eqn.primitive}."
187+
for attr in dir(eqn):
188+
if not attr.startswith("_"):
189+
msg += f"\n{attr}: {getattr(eqn, attr)}"
190+
raise NotImplementedError(msg)
191+
rule = _rooflines[eqn.primitive]
192+
result += rule(
193+
RooflineRuleContext(
194+
name_stack=source_info.name_stack,
195+
primitive=eqn.primitive,
196+
avals_in=map(aval, eqn.invars),
197+
avals_out=map(aval, eqn.outvars),
198+
jaxpr_eqn_ctx=eqn.ctx,
199+
mesh=mesh,
200+
pin_lhs_in_vmem=pin_lhs_in_vmem,
201+
pin_rhs_in_vmem=pin_rhs_in_vmem,
202+
),
203+
*map(read, eqn.invars),
204+
**eqn.params,
205+
)
206+
207+
map(write, eqn.outvars, map(make_roofline_shape, eqn.outvars))
208+
core.clean_up_dead_vars(eqn, env, last_used)
209+
result += RooflineResult(peak_hbm_bytes=calculate_peak_hbm_bytes())
210+
211+
return result
212+
213+
214+
def _f_with_vjp(f: Callable):
215+
@util.wraps(f)
216+
def wrapped(*args):
217+
primals, f_vjp = api.vjp(f, *args)
218+
return f_vjp(tree_map(jnp.bfloat16, primals))
219+
220+
return wrapped
221+
222+
223+
def roofline(
224+
f: Callable,
225+
mesh: Mesh | AbstractMesh,
226+
in_specs: shard_map.Specs,
227+
out_specs: shard_map.Specs,
228+
*,
229+
pin_lhs_in_vmem: bool = False,
230+
pin_rhs_in_vmem: bool = False,
231+
vjp: bool = False,
232+
print_jaxpr: bool = False,
233+
) -> Callable[..., tuple[ShapeDtypeStructTree, RooflineResult]]:
234+
@util.wraps(f)
235+
@traceback_util.api_boundary
236+
def wrapped(*args):
237+
wrapped_f = shard_map.shard_map(f, mesh, in_specs, out_specs)
238+
if vjp:
239+
wrapped_f = _f_with_vjp(wrapped_f)
240+
241+
jaxpr, out_shapes = make_jaxpr(wrapped_f, return_shape=True)(*args)
242+
243+
def make_sharded_shape_dtype_struct(
244+
shape: api.ShapeDtypeStruct, out_spec: shard_map.Specs
245+
) -> api.ShapeDtypeStruct:
246+
return api.ShapeDtypeStruct(
247+
shape.shape, shape.dtype, sharding=NamedSharding(mesh, out_spec)
248+
)
249+
250+
out_specs_flat = broadcast_prefix(out_specs, out_shapes)
251+
flat_out_shapes, treedef = tree_flatten(out_shapes)
252+
flat_out_shapes = map(
253+
make_sharded_shape_dtype_struct, flat_out_shapes, out_specs_flat
254+
)
255+
out_shapes = tree_unflatten(treedef, flat_out_shapes)
256+
257+
used_outputs = (True,) * len(jaxpr.jaxpr.outvars)
258+
jaxpr, _ = dce_jaxpr(jaxpr.jaxpr, used_outputs)
259+
try:
260+
jaxpr = [e for e in jaxpr.eqns if e.primitive == shard_map.shard_map_p][
261+
-1
262+
].params["jaxpr"]
263+
except KeyError:
264+
raise ValueError(f"Missing shard_map jaxpr in {jaxpr}.")
265+
266+
if print_jaxpr:
267+
print(jaxpr)
268+
269+
return out_shapes, _roofline_interpreter(
270+
util.fun_qual_name(f),
271+
jaxpr,
272+
mesh,
273+
pin_lhs_in_vmem=pin_lhs_in_vmem,
274+
pin_rhs_in_vmem=pin_rhs_in_vmem,
275+
)
276+
277+
return wrapped
278+
279+
280+
def register_roofline(prim: core.Primitive):
281+
def register(rule: _RooflineRule):
282+
_rooflines[prim] = rule
283+
return rule
284+
285+
return register
286+
287+
288+
def register_standard_roofline(prim: core.Primitive):
289+
def standard_rule(ctx: RooflineRuleContext, *args, **kwargs):
290+
return RooflineResult.zeros()
291+
292+
_rooflines[prim] = standard_rule
293+
294+
295+
def roofline_and_grad(
296+
f: Callable,
297+
mesh: Mesh | AbstractMesh,
298+
in_specs: shard_map.Specs,
299+
out_specs: shard_map.Specs,
300+
*,
301+
pin_lhs_in_vmem: bool = False,
302+
pin_rhs_in_vmem: bool = False,
303+
print_jaxpr: bool = False,
304+
) -> Callable[..., tuple[ShapeDtypeStructTree, RooflineResult, RooflineResult]]:
305+
@util.wraps(f)
306+
@traceback_util.api_boundary
307+
def wrapped(*args):
308+
primal_shapes, fwd_result = roofline(
309+
f,
310+
mesh,
311+
in_specs,
312+
out_specs,
313+
pin_lhs_in_vmem=pin_lhs_in_vmem,
314+
pin_rhs_in_vmem=pin_rhs_in_vmem,
315+
print_jaxpr=print_jaxpr,
316+
)(*args)
317+
318+
return (
319+
primal_shapes,
320+
fwd_result,
321+
roofline(
322+
f,
323+
mesh,
324+
in_specs,
325+
out_specs,
326+
pin_lhs_in_vmem=pin_lhs_in_vmem,
327+
pin_rhs_in_vmem=pin_rhs_in_vmem,
328+
vjp=True,
329+
print_jaxpr=print_jaxpr,
330+
)(
331+
*tree_map(
332+
lambda x: api.ShapeDtypeStruct(
333+
x.shape,
334+
jnp.int32 if x.dtype == jnp.int32 else jnp.bfloat16,
335+
sharding=x.sharding,
336+
),
337+
args,
338+
)
339+
)[1],
340+
)
341+
342+
return wrapped

0 commit comments

Comments
 (0)