|
| 1 | +""" |
| 2 | +MIT License |
| 3 | +
|
| 4 | +Copyright (c) 2025 Alexander Gräfe |
| 5 | +
|
| 6 | +Permission is hereby granted, free of charge, to any person obtaining a copy |
| 7 | +of this software and associated documentation files (the "Software"), to deal |
| 8 | +in the Software without restriction, including without limitation the rights |
| 9 | +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell |
| 10 | +copies of the Software, and to permit persons to whom the Software is |
| 11 | +furnished to do so, subject to the following conditions: |
| 12 | +
|
| 13 | +The above copyright notice and this permission notice shall be included in all |
| 14 | +copies or substantial portions of the Software. |
| 15 | +
|
| 16 | +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR |
| 17 | +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, |
| 18 | +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE |
| 19 | +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER |
| 20 | +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, |
| 21 | +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE |
| 22 | +SOFTWARE. |
| 23 | +
|
| 24 | +Tools for mixer precision training. Methods and general code architecture are from jmp https://github.com/google-deepmind/jmp. This can be seen as a port and extension of JMP tot equinox. |
| 25 | +""" |
| 26 | + |
| 27 | +"""Filtering tools for mixer precision training.""" |
| 28 | + |
| 29 | + |
| 30 | +import jax |
| 31 | +import jax.numpy as jnp |
| 32 | +import equinox as eqx |
| 33 | + |
| 34 | +import optax |
| 35 | + |
| 36 | +import cast as cast |
| 37 | +import loss_scaling as loss_scaling |
| 38 | + |
| 39 | +from jaxtyping import PyTree, Bool |
| 40 | + |
| 41 | + |
| 42 | +def select_tree(pred: jnp.ndarray, a: PyTree, b: PyTree) -> PyTree: |
| 43 | + """ |
| 44 | + Selects elements from one of two pytrees based on a scalar boolean predicate. |
| 45 | +
|
| 46 | + This function traverses two input pytrees (`a` and `b`) and selects elements |
| 47 | + from either `a` or `b` based on the value of the scalar boolean `pred`. If |
| 48 | + `pred` is `True`, elements from `a` are selected; otherwise, elements from `b` |
| 49 | + are selected. Non-array elements in the pytrees are taken directly from `a`. |
| 50 | +
|
| 51 | + Args: |
| 52 | + pred (jnp.ndarray): A scalar boolean array (`jnp.bool_`) that determines |
| 53 | + which pytree to select elements from. |
| 54 | + a (PyTree): The first pytree to select elements from. |
| 55 | + b (PyTree): The second pytree to select elements from. |
| 56 | +
|
| 57 | + Returns: |
| 58 | + PyTree: A new pytree with elements selected from `a` or `b` based on `pred`. |
| 59 | +
|
| 60 | + Raises: |
| 61 | + AssertionError: If `pred` is not a scalar boolean array (`jnp.bool_`). |
| 62 | + """ |
| 63 | + """Selects a pytree based on the given predicate.""" |
| 64 | + assert pred.ndim == 0 and pred.dtype == jnp.bool_, "expected boolean scalar" |
| 65 | + def _select_leaf(x1, x2): |
| 66 | + if eqx.is_array(x1): |
| 67 | + return jax.lax.select(pred, x1, x2) |
| 68 | + else: |
| 69 | + return x1 |
| 70 | + |
| 71 | + return jax.tree_util.tree_map(_select_leaf, a, b) |
| 72 | + |
| 73 | + |
| 74 | +def filter_grad(func, scaling: loss_scaling.DynamicLossScaling, has_aux=False) -> PyTree: |
| 75 | + """ |
| 76 | + Filters the gradients of a function based on a predicate. |
| 77 | +
|
| 78 | + This function computes the gradients of the given function `func` with respect |
| 79 | + to its arguments (`args` and `kwargs`). It then filters the gradients based on |
| 80 | + a predicate function that checks whether the gradients are finite. The filtered |
| 81 | + gradients are returned as a new pytree. |
| 82 | +
|
| 83 | + Args: |
| 84 | + func (callable): The function to compute gradients for. This function must only use pytrees as parameters! |
| 85 | + has_aux (bool): If True, the function is expected to return auxiliary values along with the gradients. |
| 86 | + Returns: |
| 87 | + callable: A function that computes the filtered gradients of `func`. It returns the grad, the new loss scaling, and a boolean indicating whether the gradients are finite (and the aux-value if has_aux is true). |
| 88 | + """ |
| 89 | + def wrapper(*args, **kwargs): |
| 90 | + args_cast = tuple([cast.cast_to_half_precision(x) for x in args]) |
| 91 | + kwargs_cast = {k: cast.cast_to_half_precision(v) for k, v in kwargs.items()} |
| 92 | + |
| 93 | + func_scaled = loss_scaling.scaled(func, scaling) |
| 94 | + |
| 95 | + dfunc_scaled = eqx.filter_grad(func_scaled, has_aux=has_aux) |
| 96 | + |
| 97 | + if has_aux: |
| 98 | + aux, grad = dfunc_scaled(*args_cast, **kwargs_cast) |
| 99 | + grads_finite = loss_scaling.all_finite(grad) |
| 100 | + loss_scaling_new = scaling.adjust(grads_finite) |
| 101 | + grad = loss_scaling_new.unscale(grad) |
| 102 | + return aux, loss_scaling_new, grads_finite, grad |
| 103 | + else: |
| 104 | + grad = dfunc_scaled(*args_cast, **kwargs_cast) |
| 105 | + grads_finite = loss_scaling.all_finite(grad) |
| 106 | + loss_scaling_new = scaling.adjust(grads_finite) |
| 107 | + grad = loss_scaling_new.unscale(grad) |
| 108 | + return loss_scaling_new, grads_finite, grad |
| 109 | + |
| 110 | + return wrapper |
| 111 | + |
| 112 | + |
| 113 | +def filter_value_and_grad(func, scaling: loss_scaling.DynamicLossScaling, has_aux=False) -> PyTree: |
| 114 | + """ |
| 115 | + Wraps a function to compute its value and gradient with support for mixed precision |
| 116 | + and dynamic loss scaling. |
| 117 | + Args: |
| 118 | + func (Callable): The function for which the value and gradient are to be computed. |
| 119 | + scaling (loss_scaling.DynamicLossScaling): An instance of DynamicLossScaling to |
| 120 | + handle loss scaling and gradient unscaling. |
| 121 | + has_aux (bool, optional): Indicates whether the function `func` returns auxiliary |
| 122 | + outputs along with the main value. Defaults to False. |
| 123 | + Returns: |
| 124 | + Callable: A wrapped function that computes the value, gradient, and additional |
| 125 | + information: |
| 126 | + - If `has_aux` is True: |
| 127 | + ((value, aux), loss_scaling_new, grads_finite, grad) |
| 128 | + - If `has_aux` is False: |
| 129 | + (value, loss_scaling_new, grads_finite, grad) |
| 130 | + Where: |
| 131 | + - `value`: The computed value of the function. |
| 132 | + - `aux`: Auxiliary outputs returned by the function (if `has_aux` is True). |
| 133 | + - `loss_scaling_new`: The updated loss scaling object. |
| 134 | + - `grads_finite`: A boolean indicating whether all gradients are finite. |
| 135 | + - `grad`: The computed gradients, unscaled. |
| 136 | + """ |
| 137 | + |
| 138 | + def wrapper(*args, **kwargs): |
| 139 | + args_cast = tuple([cast.cast_to_half_precision(x) for x in args]) |
| 140 | + kwargs_cast = {k: cast.cast_to_half_precision(v) for k, v in kwargs.items()} |
| 141 | + |
| 142 | + func_scaled = loss_scaling.scaled(func, scaling) |
| 143 | + |
| 144 | + dfunc_scaled = eqx.filter_value_and_grad(func_scaled, has_aux=has_aux) |
| 145 | + |
| 146 | + if has_aux: |
| 147 | + (value, aux), grad = dfunc_scaled(*args_cast, **kwargs_cast) |
| 148 | + grads_finite = loss_scaling.all_finite(grad) |
| 149 | + loss_scaling_new = scaling.adjust(grads_finite) |
| 150 | + grad = loss_scaling_new.unscale(grad) |
| 151 | + value = loss_scaling_new.unscale(value) |
| 152 | + return (value, aux), loss_scaling_new, grads_finite, grad |
| 153 | + else: |
| 154 | + value, grad = dfunc_scaled(*args_cast, **kwargs_cast) |
| 155 | + grads_finite = loss_scaling.all_finite(grad) |
| 156 | + loss_scaling_new = scaling.adjust(grads_finite) |
| 157 | + grad = loss_scaling_new.unscale(grad) |
| 158 | + value = loss_scaling_new.unscale(value) |
| 159 | + return value, loss_scaling_new, grads_finite, grad |
| 160 | + |
| 161 | + return wrapper |
| 162 | + |
| 163 | + |
| 164 | +def optimizer_update(model: PyTree, optimizer: optax.GradientTransformation, optimizer_state: PyTree, grads: PyTree, grads_finite: Bool): |
| 165 | + |
| 166 | + # optimizer step |
| 167 | + updates, new_optimizer_state = optimizer.update( |
| 168 | + grads, optimizer_state, eqx.filter(model, eqx.is_array) |
| 169 | + ) |
| 170 | + new_model = eqx.apply_updates(model, updates) |
| 171 | + |
| 172 | + # only apply updates to the model and optimizer state if gradients are finite |
| 173 | + model = select_tree(grads_finite, new_model, model) |
| 174 | + optimizer_state = select_tree(grads_finite, new_optimizer_state, optimizer_state) |
| 175 | + |
| 176 | + return model, optimizer_state |
0 commit comments