|
| 1 | +# Copyright 2024 The JAX Authors. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# https://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | +"""Helper tool for automatic cost estimation.""" |
| 15 | +import dataclasses |
| 16 | +import math |
| 17 | +from typing import Any, Sequence |
| 18 | + |
| 19 | +from jax._src import core as jax_core |
| 20 | +from jax._src.pallas import core as pallas_core |
| 21 | +from jax._src import linear_util as lu |
| 22 | +from jax._src.interpreters import partial_eval as pe |
| 23 | +from jax._src.util import safe_map |
| 24 | +from jax._src.util import safe_zip |
| 25 | +from jax._src.lax import lax |
| 26 | + |
| 27 | +map, unsafe_map = safe_map, map # pylint: disable=redefined-builtin |
| 28 | +zip, unsafe_zip = safe_zip, zip # pylint: disable=redefined-builtin |
| 29 | + |
| 30 | +_cost_rules = {} |
| 31 | + |
| 32 | +@dataclasses.dataclass(frozen=True) |
| 33 | +class CostEstimate: |
| 34 | + flops: int |
| 35 | + transcendentals: int |
| 36 | + bytes_accessed: int |
| 37 | + |
| 38 | + def __add__(self, other: 'CostEstimate') -> 'CostEstimate': |
| 39 | + return CostEstimate( |
| 40 | + flops=self.flops + other.flops, |
| 41 | + transcendentals=self.transcendentals + other.transcendentals, |
| 42 | + bytes_accessed=self.bytes_accessed + other.bytes_accessed, |
| 43 | + ) |
| 44 | + |
| 45 | +def register_cost_rule(primitive: jax_core.Primitive, rule): |
| 46 | + _cost_rules[primitive] = rule |
| 47 | + |
| 48 | +@dataclasses.dataclass(frozen=True) |
| 49 | +class Context: |
| 50 | + avals_in: Sequence[Any] |
| 51 | + avals_out: Sequence[Any] |
| 52 | + |
| 53 | +def cost_estimate_jaxpr( |
| 54 | + jaxpr: jax_core.ClosedJaxpr, |
| 55 | +) -> pallas_core.CostEstimate: |
| 56 | + """Returns the cost estimate for the given Jaxpr.""" |
| 57 | + jaxpr, _ = jaxpr.jaxpr, jaxpr.consts |
| 58 | + total_cost = CostEstimate(flops=0, transcendentals=0, bytes_accessed=0) |
| 59 | + |
| 60 | + for eqn in jaxpr.eqns: |
| 61 | + _, bind_params = eqn.primitive.get_bind_params(eqn.params) |
| 62 | + rule = _cost_rules.get(eqn.primitive, None) |
| 63 | + if rule is not None: |
| 64 | + context = Context(avals_in=[v.aval for v in eqn.invars], |
| 65 | + avals_out=[v.aval for v in eqn.outvars]) |
| 66 | + op_cost = rule(context, **bind_params) |
| 67 | + total_cost = total_cost + op_cost |
| 68 | + return pallas_core.CostEstimate( |
| 69 | + flops=total_cost.flops, |
| 70 | + transcendentals=total_cost.transcendentals, |
| 71 | + bytes_accessed=total_cost.bytes_accessed, |
| 72 | + ) |
| 73 | + |
| 74 | +def cost_estimate(fun, *args) -> pallas_core.CostEstimate: |
| 75 | + """Computes a cost estimate for the given function. |
| 76 | +
|
| 77 | + Args: |
| 78 | + fun: The function to compute the cost estimate for. |
| 79 | + *args: The arguments to the function. Can be jax.ShapeDtypeStruct or |
| 80 | + jax.Array. |
| 81 | +
|
| 82 | + Returns: |
| 83 | + A pallas_core.CostEstimate object containing the cost estimate. |
| 84 | + """ |
| 85 | + wrapped_fun = lu.wrap_init(lambda *args, **kwargs: (fun(*args, **kwargs),)) |
| 86 | + avals = [jax_core.ShapedArray(a.shape, a.dtype) for a in args] |
| 87 | + jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, avals) |
| 88 | + estimate = cost_estimate_jaxpr(jax_core.ClosedJaxpr(jaxpr, consts)) |
| 89 | + input_bytes = sum(math.prod(a.shape) * a.dtype.itemsize for a in args) |
| 90 | + output_bytes = sum( |
| 91 | + math.prod(a.aval.shape) * a.aval.dtype.itemsize for a in jaxpr.outvars) |
| 92 | + return pallas_core.CostEstimate( |
| 93 | + flops=estimate.flops, |
| 94 | + transcendentals=estimate.transcendentals, |
| 95 | + bytes_accessed=estimate.bytes_accessed + input_bytes + output_bytes, |
| 96 | + ) |
| 97 | + |
| 98 | +def binary_cost_rule(ctx: Context, **_) -> CostEstimate: |
| 99 | + aval_out, = ctx.avals_out |
| 100 | + out_flops = math.prod(aval_out.shape) |
| 101 | + return CostEstimate( |
| 102 | + flops=out_flops, |
| 103 | + transcendentals=0, |
| 104 | + bytes_accessed=0, |
| 105 | + ) |
| 106 | +BINARY_OPS = [ |
| 107 | + lax.add_p, |
| 108 | + lax.mul_p, |
| 109 | + lax.sub_p, |
| 110 | + lax.div_p, |
| 111 | + lax.min_p, |
| 112 | + lax.max_p, |
| 113 | + lax.or_p, |
| 114 | + lax.and_p, |
| 115 | + lax.xor_p, |
| 116 | +] |
| 117 | +for op in BINARY_OPS: |
| 118 | + register_cost_rule(op, binary_cost_rule) |
| 119 | + |
| 120 | + |
| 121 | +def unary_cost_rule(transcendental: bool): |
| 122 | + def cost_rule(ctx: Context, **_) -> CostEstimate: |
| 123 | + x_aval, = ctx.avals_in |
| 124 | + new_flops = 0 |
| 125 | + new_transcendentals = 0 |
| 126 | + if transcendental: |
| 127 | + new_transcendentals += math.prod(x_aval.shape) |
| 128 | + else: |
| 129 | + new_flops += math.prod(x_aval.shape) |
| 130 | + return CostEstimate( |
| 131 | + flops=new_flops, |
| 132 | + transcendentals=new_transcendentals, |
| 133 | + bytes_accessed=0, |
| 134 | + ) |
| 135 | + return cost_rule |
| 136 | + |
| 137 | +UN_OPS = [ |
| 138 | + lax.neg_p, |
| 139 | + lax.floor_p, |
| 140 | + lax.ceil_p, |
| 141 | + lax.round_p, |
| 142 | + lax.not_p, |
| 143 | +] |
| 144 | +for op in UN_OPS: |
| 145 | + register_cost_rule(op, unary_cost_rule(transcendental=False)) |
| 146 | + |
| 147 | +TRANSCENDENTAL_OPS = [ |
| 148 | + lax.cos_p, |
| 149 | + lax.sin_p, |
| 150 | + lax.tan_p, |
| 151 | + lax.sinh_p, |
| 152 | + lax.cosh_p, |
| 153 | + lax.tanh_p, |
| 154 | + lax.acos_p, |
| 155 | + lax.asin_p, |
| 156 | + lax.atan_p, |
| 157 | + lax.exp_p, |
| 158 | + lax.log_p, |
| 159 | + lax.logistic_p, |
| 160 | + lax.sqrt_p, |
| 161 | +] |
| 162 | +for op in TRANSCENDENTAL_OPS: |
| 163 | + register_cost_rule(op, unary_cost_rule(transcendental=True)) |
| 164 | + |
| 165 | +def _integer_pow_cost_rule(ctx: Context, *, y: int) -> CostEstimate: |
| 166 | + x_aval, = ctx.avals_in |
| 167 | + num_elements = math.prod(x_aval.shape) |
| 168 | + if y == 0 or y == 1: |
| 169 | + # No flops, the result is 0 or a copy of the input. |
| 170 | + cost_per_element = 0 |
| 171 | + else: |
| 172 | + # We assume integer pow is implemented using repeated squaring. |
| 173 | + # The cost is log(y) squarings, plus one multiply per non-zero bit. |
| 174 | + highest_bit = math.floor(math.log(y, 2)) |
| 175 | + cost_per_element = highest_bit + y.bit_count() |
| 176 | + return CostEstimate( |
| 177 | + flops=num_elements * cost_per_element, |
| 178 | + transcendentals=0, |
| 179 | + bytes_accessed=0, |
| 180 | + ) |
| 181 | +register_cost_rule(lax.integer_pow_p, _integer_pow_cost_rule) |
| 182 | + |
| 183 | +def dot_general_cost_rule(ctx: Context, |
| 184 | + dimension_numbers: lax.DotDimensionNumbers, |
| 185 | + **_) -> CostEstimate: |
| 186 | + x_aval, y_aval = ctx.avals_in |
| 187 | + x_shape, y_shape = x_aval.shape, y_aval.shape |
| 188 | + (lhs_contracting_dims, rhs_contracting_dims), ( |
| 189 | + lhs_batch_dims, rhs_batch_dims) = dimension_numbers |
| 190 | + assert len(lhs_contracting_dims) == len(rhs_contracting_dims) |
| 191 | + assert len(lhs_batch_dims) == len(rhs_batch_dims) |
| 192 | + flops = 1 |
| 193 | + # Flops along a contracting dim is 2*dim (addition and multiplication) |
| 194 | + for i in range(len(lhs_contracting_dims)): |
| 195 | + lhs_dim, rhs_dim = lhs_contracting_dims[i], rhs_contracting_dims[i] |
| 196 | + assert x_shape[lhs_dim] == y_shape[rhs_dim] |
| 197 | + flops *= 2 * x_shape[lhs_dim] |
| 198 | + # Now we handle all other dimensions. |
| 199 | + for i, lhs_dim in enumerate(x_shape): |
| 200 | + if i in lhs_contracting_dims: |
| 201 | + continue |
| 202 | + flops *= lhs_dim |
| 203 | + for i, rhs_dim in enumerate(y_shape): |
| 204 | + if i in rhs_contracting_dims: |
| 205 | + continue |
| 206 | + # Don't double-count batch dims (we already counted for LHS) |
| 207 | + if i in rhs_batch_dims: |
| 208 | + continue |
| 209 | + flops *= rhs_dim |
| 210 | + return CostEstimate( |
| 211 | + flops=flops, |
| 212 | + transcendentals=0, |
| 213 | + bytes_accessed=0, |
| 214 | + ) |
| 215 | +register_cost_rule(lax.dot_general_p, dot_general_cost_rule) |
0 commit comments