|
| 1 | +# Copyright 2024 The JAX Authors. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# https://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | +from __future__ import annotations |
| 15 | + |
| 16 | +from dataclasses import dataclass, field |
| 17 | +from typing import Any, Callable, Protocol, Sequence |
| 18 | +import numpy as np |
| 19 | + |
| 20 | +import jax.numpy as jnp |
| 21 | +from jax.sharding import NamedSharding |
| 22 | +from jax._src import api |
| 23 | +from jax._src import core |
| 24 | +from jax._src import source_info_util |
| 25 | +from jax._src import traceback_util |
| 26 | +from jax._src import util |
| 27 | +from jax._src.api import make_jaxpr |
| 28 | +from jax._src.interpreters.partial_eval import dce_jaxpr |
| 29 | +from jax._src.interpreters.xla import abstractify |
| 30 | +from jax._src.mesh import AbstractMesh, Mesh |
| 31 | +from jax._src.tree_util import broadcast_prefix, tree_flatten, tree_unflatten, tree_map |
| 32 | +from jax.experimental import shard_map |
| 33 | + |
| 34 | + |
| 35 | +ShapeDtypeStructTree = Any |
| 36 | + |
| 37 | + |
| 38 | +map = util.safe_map |
| 39 | + |
| 40 | + |
| 41 | +@dataclass(frozen=True, slots=True, kw_only=True) |
| 42 | +class RooflineRuleContext: |
| 43 | + name_stack: source_info_util.NameStack |
| 44 | + primitive: core.Primitive |
| 45 | + avals_in: Sequence[core.AbstractValue] |
| 46 | + avals_out: Sequence[core.AbstractValue] |
| 47 | + jaxpr_eqn_ctx: core.JaxprEqnContext |
| 48 | + mesh: Mesh | AbstractMesh |
| 49 | + pin_lhs_in_vmem: bool |
| 50 | + pin_rhs_in_vmem: bool |
| 51 | + |
| 52 | + |
| 53 | +@dataclass(frozen=True, slots=True, kw_only=True) |
| 54 | +class RooflineShape: |
| 55 | + shape: tuple[int, ...] |
| 56 | + dtype: np.dtype |
| 57 | + |
| 58 | + @classmethod |
| 59 | + def from_aval(cls, aval: core.AbstractValue) -> "RooflineShape": |
| 60 | + if not isinstance(aval, core.ShapedArray): |
| 61 | + raise TypeError(f"Expected ShapedArray, got {type(aval)}.") |
| 62 | + if not isinstance(aval.dtype, np.dtype): |
| 63 | + raise TypeError(f"Expected numpy dtype, got {type(aval.dtype)}.") |
| 64 | + return cls(shape=aval.shape, dtype=aval.dtype) |
| 65 | + |
| 66 | + @property |
| 67 | + def size(self) -> int: |
| 68 | + return int(np.prod(self.shape)) |
| 69 | + |
| 70 | + @property |
| 71 | + def bytes(self) -> int: |
| 72 | + return int(self.size * self.dtype.itemsize) |
| 73 | + |
| 74 | + @classmethod |
| 75 | + def total_bytes(cls, avals: Sequence[core.AbstractValue]) -> int: |
| 76 | + return sum(cls.from_aval(aval).bytes for aval in avals) |
| 77 | + |
| 78 | + |
| 79 | +@dataclass(frozen=True, slots=True, kw_only=True) |
| 80 | +class RooflineResult: |
| 81 | + flops: int = 0 |
| 82 | + ici_bytes: dict[str, int] = field(default_factory=dict) |
| 83 | + ici_latency: dict[str, int] = field(default_factory=dict) |
| 84 | + hbm_bytes: int = 0 |
| 85 | + peak_hbm_bytes: int = 0 |
| 86 | + |
| 87 | + @classmethod |
| 88 | + def zeros(cls) -> "RooflineResult": |
| 89 | + return cls() |
| 90 | + |
| 91 | + def __add__(self, other: "RooflineResult") -> "RooflineResult": |
| 92 | + def merge_ici_dicts(d1: dict[str, int], d2: dict[str, int]) -> dict[str, int]: |
| 93 | + return {k: d1.get(k, 0) + d2.get(k, 0) for k in set(d1) | set(d2)} |
| 94 | + |
| 95 | + return RooflineResult( |
| 96 | + flops=self.flops + other.flops, |
| 97 | + ici_bytes=merge_ici_dicts(self.ici_bytes, other.ici_bytes), |
| 98 | + ici_latency=merge_ici_dicts(self.ici_latency, other.ici_latency), |
| 99 | + hbm_bytes=self.hbm_bytes + other.hbm_bytes, |
| 100 | + peak_hbm_bytes=max(self.peak_hbm_bytes, other.peak_hbm_bytes), |
| 101 | + ) |
| 102 | + |
| 103 | + def __mul__(self, constant: int | float) -> "RooflineResult": |
| 104 | + return RooflineResult( |
| 105 | + flops=int(self.flops * constant), |
| 106 | + ici_bytes={k: int(v * constant) for k, v in self.ici_bytes.items()}, |
| 107 | + ici_latency={k: int(v * constant) for k, v in self.ici_latency.items()}, |
| 108 | + hbm_bytes=int(self.hbm_bytes * constant), |
| 109 | + peak_hbm_bytes=int(self.peak_hbm_bytes * constant), |
| 110 | + ) |
| 111 | + |
| 112 | + def __rmul__(self, constant: int | float) -> "RooflineResult": |
| 113 | + return self.__mul__(constant) |
| 114 | + |
| 115 | + |
| 116 | +class _RooflineRule(Protocol): |
| 117 | + def __call__( |
| 118 | + self, ctx: RooflineRuleContext, *args: RooflineShape, **kw |
| 119 | + ) -> RooflineResult: ... |
| 120 | + |
| 121 | + |
| 122 | +_rooflines: dict[core.Primitive, _RooflineRule] = {} |
| 123 | + |
| 124 | + |
| 125 | +def _roofline_interpreter( |
| 126 | + f_name: str, |
| 127 | + jaxpr: core.Jaxpr, |
| 128 | + mesh: Mesh | AbstractMesh, |
| 129 | + *, |
| 130 | + pin_lhs_in_vmem: bool = False, |
| 131 | + pin_rhs_in_vmem: bool = False, |
| 132 | +) -> RooflineResult: |
| 133 | + name_stack = source_info_util.new_name_stack(util.wrap_name(f_name, "roofline")) |
| 134 | + |
| 135 | + result = RooflineResult.zeros() |
| 136 | + |
| 137 | + env: dict[core.Var, RooflineShape] = {} |
| 138 | + |
| 139 | + def write(v: core.Var, node: RooflineShape): |
| 140 | + assert node is not None |
| 141 | + env[v] = node |
| 142 | + |
| 143 | + def read(v: core.Atom) -> RooflineShape: |
| 144 | + if type(v) is core.Literal: |
| 145 | + return RooflineShape.from_aval(abstractify(v.val)) |
| 146 | + else: |
| 147 | + assert isinstance(v, core.Var) |
| 148 | + return env[v] |
| 149 | + |
| 150 | + def aval(v: core.Atom) -> core.AbstractValue: |
| 151 | + if type(v) is core.Literal: |
| 152 | + return abstractify(v.val) |
| 153 | + else: |
| 154 | + return v.aval |
| 155 | + |
| 156 | + def calculate_peak_hbm_bytes() -> int: |
| 157 | + return int( |
| 158 | + sum(np.prod(shape.shape) * shape.dtype.itemsize for shape in env.values()) |
| 159 | + ) |
| 160 | + |
| 161 | + make_roofline_shape = lambda x: RooflineShape.from_aval(aval(x)) |
| 162 | + map( |
| 163 | + write, |
| 164 | + jaxpr.constvars, |
| 165 | + map(make_roofline_shape, jaxpr.constvars), |
| 166 | + ) |
| 167 | + map(write, jaxpr.invars, map(make_roofline_shape, jaxpr.invars)) |
| 168 | + last_used = core.last_used(jaxpr) |
| 169 | + for eqn in jaxpr.eqns: |
| 170 | + source_info = eqn.source_info.replace( |
| 171 | + name_stack=name_stack + eqn.source_info.name_stack |
| 172 | + ) |
| 173 | + with source_info_util.user_context( |
| 174 | + eqn.source_info.traceback, name_stack=source_info.name_stack |
| 175 | + ): |
| 176 | + if "jaxpr" in eqn.params: |
| 177 | + result += _roofline_interpreter( |
| 178 | + util.wrap_name(f_name, eqn.primitive.name), |
| 179 | + eqn.params["jaxpr"], |
| 180 | + mesh, |
| 181 | + pin_lhs_in_vmem=pin_lhs_in_vmem, |
| 182 | + pin_rhs_in_vmem=pin_rhs_in_vmem, |
| 183 | + ) |
| 184 | + else: |
| 185 | + if eqn.primitive not in _rooflines: |
| 186 | + msg = f"No roofline rule for {eqn.primitive}." |
| 187 | + for attr in dir(eqn): |
| 188 | + if not attr.startswith("_"): |
| 189 | + msg += f"\n{attr}: {getattr(eqn, attr)}" |
| 190 | + raise NotImplementedError(msg) |
| 191 | + rule = _rooflines[eqn.primitive] |
| 192 | + result += rule( |
| 193 | + RooflineRuleContext( |
| 194 | + name_stack=source_info.name_stack, |
| 195 | + primitive=eqn.primitive, |
| 196 | + avals_in=map(aval, eqn.invars), |
| 197 | + avals_out=map(aval, eqn.outvars), |
| 198 | + jaxpr_eqn_ctx=eqn.ctx, |
| 199 | + mesh=mesh, |
| 200 | + pin_lhs_in_vmem=pin_lhs_in_vmem, |
| 201 | + pin_rhs_in_vmem=pin_rhs_in_vmem, |
| 202 | + ), |
| 203 | + *map(read, eqn.invars), |
| 204 | + **eqn.params, |
| 205 | + ) |
| 206 | + |
| 207 | + map(write, eqn.outvars, map(make_roofline_shape, eqn.outvars)) |
| 208 | + core.clean_up_dead_vars(eqn, env, last_used) |
| 209 | + result += RooflineResult(peak_hbm_bytes=calculate_peak_hbm_bytes()) |
| 210 | + |
| 211 | + return result |
| 212 | + |
| 213 | + |
| 214 | +def _f_with_vjp(f: Callable): |
| 215 | + @util.wraps(f) |
| 216 | + def wrapped(*args): |
| 217 | + primals, f_vjp = api.vjp(f, *args) |
| 218 | + return f_vjp(tree_map(jnp.bfloat16, primals)) |
| 219 | + |
| 220 | + return wrapped |
| 221 | + |
| 222 | + |
| 223 | +def roofline( |
| 224 | + f: Callable, |
| 225 | + mesh: Mesh | AbstractMesh, |
| 226 | + in_specs: shard_map.Specs, |
| 227 | + out_specs: shard_map.Specs, |
| 228 | + *, |
| 229 | + pin_lhs_in_vmem: bool = False, |
| 230 | + pin_rhs_in_vmem: bool = False, |
| 231 | + vjp: bool = False, |
| 232 | + print_jaxpr: bool = False, |
| 233 | +) -> Callable[..., tuple[ShapeDtypeStructTree, RooflineResult]]: |
| 234 | + @util.wraps(f) |
| 235 | + @traceback_util.api_boundary |
| 236 | + def wrapped(*args): |
| 237 | + wrapped_f = shard_map.shard_map(f, mesh, in_specs, out_specs) |
| 238 | + if vjp: |
| 239 | + wrapped_f = _f_with_vjp(wrapped_f) |
| 240 | + |
| 241 | + jaxpr, out_shapes = make_jaxpr(wrapped_f, return_shape=True)(*args) |
| 242 | + |
| 243 | + def make_sharded_shape_dtype_struct( |
| 244 | + shape: api.ShapeDtypeStruct, out_spec: shard_map.Specs |
| 245 | + ) -> api.ShapeDtypeStruct: |
| 246 | + return api.ShapeDtypeStruct( |
| 247 | + shape.shape, shape.dtype, sharding=NamedSharding(mesh, out_spec) |
| 248 | + ) |
| 249 | + |
| 250 | + out_specs_flat = broadcast_prefix(out_specs, out_shapes) |
| 251 | + flat_out_shapes, treedef = tree_flatten(out_shapes) |
| 252 | + flat_out_shapes = map( |
| 253 | + make_sharded_shape_dtype_struct, flat_out_shapes, out_specs_flat |
| 254 | + ) |
| 255 | + out_shapes = tree_unflatten(treedef, flat_out_shapes) |
| 256 | + |
| 257 | + used_outputs = (True,) * len(jaxpr.jaxpr.outvars) |
| 258 | + jaxpr, _ = dce_jaxpr(jaxpr.jaxpr, used_outputs) |
| 259 | + try: |
| 260 | + jaxpr = [e for e in jaxpr.eqns if e.primitive == shard_map.shard_map_p][ |
| 261 | + -1 |
| 262 | + ].params["jaxpr"] |
| 263 | + except KeyError: |
| 264 | + raise ValueError(f"Missing shard_map jaxpr in {jaxpr}.") |
| 265 | + |
| 266 | + if print_jaxpr: |
| 267 | + print(jaxpr) |
| 268 | + |
| 269 | + return out_shapes, _roofline_interpreter( |
| 270 | + util.fun_qual_name(f), |
| 271 | + jaxpr, |
| 272 | + mesh, |
| 273 | + pin_lhs_in_vmem=pin_lhs_in_vmem, |
| 274 | + pin_rhs_in_vmem=pin_rhs_in_vmem, |
| 275 | + ) |
| 276 | + |
| 277 | + return wrapped |
| 278 | + |
| 279 | + |
| 280 | +def register_roofline(prim: core.Primitive): |
| 281 | + def register(rule: _RooflineRule): |
| 282 | + _rooflines[prim] = rule |
| 283 | + return rule |
| 284 | + |
| 285 | + return register |
| 286 | + |
| 287 | + |
| 288 | +def register_standard_roofline(prim: core.Primitive): |
| 289 | + def standard_rule(ctx: RooflineRuleContext, *args, **kwargs): |
| 290 | + return RooflineResult.zeros() |
| 291 | + |
| 292 | + _rooflines[prim] = standard_rule |
| 293 | + |
| 294 | + |
| 295 | +def roofline_and_grad( |
| 296 | + f: Callable, |
| 297 | + mesh: Mesh | AbstractMesh, |
| 298 | + in_specs: shard_map.Specs, |
| 299 | + out_specs: shard_map.Specs, |
| 300 | + *, |
| 301 | + pin_lhs_in_vmem: bool = False, |
| 302 | + pin_rhs_in_vmem: bool = False, |
| 303 | + print_jaxpr: bool = False, |
| 304 | +) -> Callable[..., tuple[ShapeDtypeStructTree, RooflineResult, RooflineResult]]: |
| 305 | + @util.wraps(f) |
| 306 | + @traceback_util.api_boundary |
| 307 | + def wrapped(*args): |
| 308 | + primal_shapes, fwd_result = roofline( |
| 309 | + f, |
| 310 | + mesh, |
| 311 | + in_specs, |
| 312 | + out_specs, |
| 313 | + pin_lhs_in_vmem=pin_lhs_in_vmem, |
| 314 | + pin_rhs_in_vmem=pin_rhs_in_vmem, |
| 315 | + print_jaxpr=print_jaxpr, |
| 316 | + )(*args) |
| 317 | + |
| 318 | + return ( |
| 319 | + primal_shapes, |
| 320 | + fwd_result, |
| 321 | + roofline( |
| 322 | + f, |
| 323 | + mesh, |
| 324 | + in_specs, |
| 325 | + out_specs, |
| 326 | + pin_lhs_in_vmem=pin_lhs_in_vmem, |
| 327 | + pin_rhs_in_vmem=pin_rhs_in_vmem, |
| 328 | + vjp=True, |
| 329 | + print_jaxpr=print_jaxpr, |
| 330 | + )( |
| 331 | + *tree_map( |
| 332 | + lambda x: api.ShapeDtypeStruct( |
| 333 | + x.shape, |
| 334 | + jnp.int32 if x.dtype == jnp.int32 else jnp.bfloat16, |
| 335 | + sharding=x.sharding, |
| 336 | + ), |
| 337 | + args, |
| 338 | + ) |
| 339 | + )[1], |
| 340 | + ) |
| 341 | + |
| 342 | + return wrapped |
0 commit comments