Skip to content

Commit adb2807

Browse files
author
Alexander
committed
added code
1 parent 0fe7f44 commit adb2807

File tree

5 files changed

+506
-0
lines changed

5 files changed

+506
-0
lines changed

mpfj/__init__.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
from . import utils
2+
from . import layers
3+
from . import optimizers
4+
5+
"""
6+
Mixed Precision for JAX (mpfj)
7+
8+
This package provides utilities for mixed precision training in JAX.
9+
"""
10+
11+
__version__ = "0.1.0"
12+
13+
from .dtypes import set_half_precision_datatype

mpfj/cast.py

Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
"""
2+
MIT License
3+
4+
Copyright (c) 2025 Alexander Gräfe
5+
6+
Permission is hereby granted, free of charge, to any person obtaining a copy
7+
of this software and associated documentation files (the "Software"), to deal
8+
in the Software without restriction, including without limitation the rights
9+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10+
copies of the Software, and to permit persons to whom the Software is
11+
furnished to do so, subject to the following conditions:
12+
13+
The above copyright notice and this permission notice shall be included in all
14+
copies or substantial portions of the Software.
15+
16+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22+
SOFTWARE.
23+
24+
Tools for mixed precision training. Methods and general code architecture are from jmp https://github.com/google-deepmind/jmp. This can be seen as a port and extension of JMP to equinox.
25+
"""
26+
27+
"""
28+
Functions for casting of Pytrees.
29+
"""
30+
31+
import jax
32+
import jax.numpy as jnp
33+
import equinox as eqx
34+
35+
from jaxtyping import Array, Float, Int, PyTree, PRNGKeyArray
36+
37+
from .dtypes import HALF_PRECISION_DATATYPE
38+
39+
def cast_tree(tree: PyTree, dtype):
    """
    Cast every array leaf of a PyTree to the given dtype.

    Traverses the tree and applies ``astype(dtype)`` to each array leaf;
    non-array leaves are passed through untouched.

    Args:
        tree (PyTree): Input pytree, possibly mixing arrays and other objects.
        dtype (numpy.dtype or str): Target data type for the array leaves.

    Returns:
        PyTree: A pytree with the same structure, all array leaves cast to ``dtype``.
    """
    return jax.tree_util.tree_map(
        lambda leaf: leaf.astype(dtype) if eqx.is_array(leaf) else leaf,
        tree,
    )
56+
57+
58+
def cast_to_float32(x: PyTree) -> PyTree:
    """
    Cast the input PyTree to `float32` data type.

    This function takes a PyTree and casts all its array elements to the
    `float32` data type.

    Args:
        x (PyTree): The input PyTree containing elements to be cast.

    Returns:
        PyTree: A new PyTree with all array elements cast to `float32`.
    """
    # (Removed a dead, duplicate docstring string literal that followed the
    # real docstring in the original.)
    return cast_tree(x, jnp.float32)
72+
73+
74+
def cast_to_float16(x: PyTree) -> PyTree:
    """
    Return a copy of *x* whose array leaves are cast to ``float16``.

    Args:
        x (PyTree): Pytree whose array leaves should be cast.

    Returns:
        PyTree: Same structure as *x*, with array leaves in ``float16``.
    """
    target = jnp.float16
    return cast_tree(x, target)
85+
86+
87+
def cast_to_bfloat16(x: PyTree) -> PyTree:
    """
    Return a copy of *x* whose array leaves are cast to ``bfloat16``.

    Args:
        x (PyTree): Pytree whose array leaves should be cast.

    Returns:
        PyTree: Same structure as *x*, with array leaves in ``bfloat16``.
    """
    target = jnp.bfloat16
    return cast_tree(x, target)
98+
99+
100+
def cast_to_full_precision(x: PyTree) -> PyTree:
    """
    Cast all array elements of a PyTree to full precision (float32).

    Args:
        x (PyTree): The input PyTree containing elements to be cast.

    Returns:
        PyTree: A new PyTree with all array elements cast to float32.
    """
    # (Removed a dead, duplicate docstring string literal that followed the
    # real docstring in the original.)
    return cast_tree(x, jnp.float32)
112+
113+
def cast_to_half_precision(x: PyTree) -> PyTree:
    """
    Cast the input PyTree to half precision.

    Converts all array elements of the input PyTree to the half-precision
    datatype (`float16` or `bfloat16`), depending on the configuration set
    by `set_half_precision_datatype`.

    Args:
        x (PyTree): The input PyTree containing elements to be cast.

    Returns:
        PyTree: A new PyTree with all array elements cast to the half-precision
        datatype.
    """
    # (Removed a dead, duplicate docstring string literal from the original.)
    # NOTE(review): `HALF_PRECISION_DATATYPE` was imported with
    # `from .dtypes import ...`, which snapshots the value at import time —
    # later calls to `set_half_precision_datatype` will not be seen here.
    # Consider `from . import dtypes` and reading `dtypes.HALF_PRECISION_DATATYPE`.
    return cast_tree(x, HALF_PRECISION_DATATYPE)
129+
130+
131+
def force_full_precision(func, return_dtype=jnp.float16):
    """
    Decorator that runs ``func`` in full precision (float32).

    All array arguments (positional and keyword) are cast to float32 before
    ``func`` executes. Array outputs are then cast to ``return_dtype`` — by
    default ``jnp.float16`` — so the surrounding mixed-precision computation
    can continue in half precision. Non-array values pass through unchanged.

    Args:
        func (callable): The function to be decorated.
        return_dtype (dtype): Dtype applied to array outputs
            (default: ``jnp.float16``).

    Returns:
        callable: The wrapped function with float32 inputs and
        ``return_dtype`` outputs.

    Example:
        @force_full_precision
        def my_function(x, y):
            return x + y
    """
    # Local import keeps this file's top-level import block unchanged.
    import functools

    def _upcast(x):
        # Non-arrays (static configuration, None, etc.) pass through untouched.
        return x.astype(jnp.float32) if eqx.is_array(x) else x

    def _downcast(x):
        return x.astype(return_dtype) if eqx.is_array(x) else x

    @functools.wraps(func)  # preserve the wrapped function's name/docstring
    def wrapper(*args, **kwargs):
        args_fp32 = tuple(_upcast(a) for a in args)
        kwargs_fp32 = {key: _upcast(value) for key, value in kwargs.items()}

        results = func(*args_fp32, **kwargs_fp32)

        # isinstance (rather than the original `type(results) == tuple`) also
        # covers tuple subclasses; note namedtuple results come back as plain
        # tuples on this path.
        if isinstance(results, tuple):
            return tuple(_downcast(r) for r in results)
        return _downcast(results)

    return wrapper

mpfj/dtypes.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
import jax.numpy as jnp
2+
3+
# Module-wide half-precision dtype used by mpfj's casting helpers.
HALF_PRECISION_DATATYPE = jnp.float16


def set_half_precision_datatype(datatype):
    """
    Set the half precision datatype for the module.

    Args:
        datatype: The datatype to set as half precision
            (e.g., ``jnp.float16`` or ``jnp.bfloat16``).
    """
    # Bug fix: without the `global` declaration, the original assignment only
    # created a function-local variable and the module-level constant was
    # never updated.
    global HALF_PRECISION_DATATYPE
    HALF_PRECISION_DATATYPE = datatype

mpfj/grad_tools.py

Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
"""
2+
MIT License
3+
4+
Copyright (c) 2025 Alexander Gräfe
5+
6+
Permission is hereby granted, free of charge, to any person obtaining a copy
7+
of this software and associated documentation files (the "Software"), to deal
8+
in the Software without restriction, including without limitation the rights
9+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10+
copies of the Software, and to permit persons to whom the Software is
11+
furnished to do so, subject to the following conditions:
12+
13+
The above copyright notice and this permission notice shall be included in all
14+
copies or substantial portions of the Software.
15+
16+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22+
SOFTWARE.
23+
24+
Tools for mixed precision training. Methods and general code architecture are from jmp https://github.com/google-deepmind/jmp. This can be seen as a port and extension of JMP to equinox.
25+
"""
26+
27+
"""Filtering tools for mixer precision training."""
28+
29+
30+
import jax
31+
import jax.numpy as jnp
32+
import equinox as eqx
33+
34+
import optax
35+
36+
import cast as cast
37+
import loss_scaling as loss_scaling
38+
39+
from jaxtyping import PyTree, Bool
40+
41+
42+
def select_tree(pred: jnp.ndarray, a: PyTree, b: PyTree) -> PyTree:
    """
    Select elements from one of two pytrees based on a scalar boolean predicate.

    Array leaves are taken from `a` where `pred` is True and from `b`
    otherwise; non-array leaves are always taken from `a`.

    Args:
        pred (jnp.ndarray): A scalar boolean array (`jnp.bool_`) that determines
            which pytree to select elements from.
        a (PyTree): The first pytree to select elements from.
        b (PyTree): The second pytree to select elements from (must share
            `a`'s structure).

    Returns:
        PyTree: A new pytree with elements selected from `a` or `b` based on `pred`.

    Raises:
        AssertionError: If `pred` is not a scalar boolean array (`jnp.bool_`).
    """
    # (Removed a dead, duplicate docstring string literal from the original.)
    assert pred.ndim == 0 and pred.dtype == jnp.bool_, "expected boolean scalar"

    def _select_leaf(x1, x2):
        # jax.lax.select keeps the choice inside the traced computation,
        # so this works under jit.
        if eqx.is_array(x1):
            return jax.lax.select(pred, x1, x2)
        else:
            return x1

    return jax.tree_util.tree_map(_select_leaf, a, b)
72+
73+
74+
def filter_grad(func, scaling: loss_scaling.DynamicLossScaling, has_aux=False) -> PyTree:
    """
    Build a mixed-precision gradient function for `func` with dynamic loss scaling.

    Inputs are cast to half precision, the loss is scaled before
    differentiation, and the resulting gradients are unscaled afterwards.

    Args:
        func (callable): The function to compute gradients for. This function
            must only use pytrees as parameters!
        scaling (loss_scaling.DynamicLossScaling): Loss scaling used to scale
            the loss and unscale the gradients.
        has_aux (bool): If True, `func` returns auxiliary values alongside the loss.

    Returns:
        callable: Computes the filtered gradients of `func`; returns the
        updated loss scaling, a boolean indicating whether the gradients are
        finite, and the gradients (preceded by the aux value if `has_aux`).
    """
    def wrapper(*args, **kwargs):
        args_cast = tuple(cast.cast_to_half_precision(x) for x in args)
        kwargs_cast = {k: cast.cast_to_half_precision(v) for k, v in kwargs.items()}

        func_scaled = loss_scaling.scaled(func, scaling)
        dfunc_scaled = eqx.filter_grad(func_scaled, has_aux=has_aux)

        if has_aux:
            # Bug fix: eqx.filter_grad with has_aux=True returns (grad, aux)
            # in that order; the original unpacked it as (aux, grad).
            grad, aux = dfunc_scaled(*args_cast, **kwargs_cast)
        else:
            grad = dfunc_scaled(*args_cast, **kwargs_cast)
            aux = None

        # Bug fix: unscale with the scaling that was actually applied to the
        # loss. adjust() may change the scale, so unscaling with the adjusted
        # scaling (as the original did) is wrong on steps where it changes.
        grad = scaling.unscale(grad)
        grads_finite = loss_scaling.all_finite(grad)
        loss_scaling_new = scaling.adjust(grads_finite)

        if has_aux:
            return aux, loss_scaling_new, grads_finite, grad
        return loss_scaling_new, grads_finite, grad

    return wrapper
111+
112+
113+
def filter_value_and_grad(func, scaling: loss_scaling.DynamicLossScaling, has_aux=False) -> PyTree:
    """
    Wrap a function to compute its value and gradient with support for mixed
    precision and dynamic loss scaling.

    Args:
        func (Callable): The function for which the value and gradient are to
            be computed.
        scaling (loss_scaling.DynamicLossScaling): An instance of
            DynamicLossScaling to handle loss scaling and gradient unscaling.
        has_aux (bool, optional): Whether `func` returns auxiliary outputs
            along with the main value. Defaults to False.

    Returns:
        Callable: A wrapped function returning:
            - If `has_aux` is True:  ((value, aux), loss_scaling_new, grads_finite, grad)
            - If `has_aux` is False: (value, loss_scaling_new, grads_finite, grad)
        Where `value` is the (unscaled) function value, `aux` the auxiliary
        outputs, `loss_scaling_new` the updated loss scaling,
        `grads_finite` a boolean indicating all gradients are finite, and
        `grad` the unscaled gradients.
    """
    def wrapper(*args, **kwargs):
        args_cast = tuple(cast.cast_to_half_precision(x) for x in args)
        kwargs_cast = {k: cast.cast_to_half_precision(v) for k, v in kwargs.items()}

        func_scaled = loss_scaling.scaled(func, scaling)
        dfunc_scaled = eqx.filter_value_and_grad(func_scaled, has_aux=has_aux)

        if has_aux:
            (value, aux), grad = dfunc_scaled(*args_cast, **kwargs_cast)
        else:
            value, grad = dfunc_scaled(*args_cast, **kwargs_cast)
            aux = None

        # Bug fix: unscale value and gradients with the scaling that was
        # actually applied to the loss; the original adjusted first and
        # unscaled with the new scaling, which is wrong on steps where the
        # scale changes.
        grad = scaling.unscale(grad)
        value = scaling.unscale(value)
        grads_finite = loss_scaling.all_finite(grad)
        loss_scaling_new = scaling.adjust(grads_finite)

        if has_aux:
            return (value, aux), loss_scaling_new, grads_finite, grad
        return value, loss_scaling_new, grads_finite, grad

    return wrapper
162+
163+
164+
def optimizer_update(model: PyTree, optimizer: optax.GradientTransformation, optimizer_state: PyTree, grads: PyTree, grads_finite: Bool):
    """
    Apply one optax optimizer step, committing it only if the gradients are finite.

    Args:
        model (PyTree): Current model (pytree of parameters).
        optimizer (optax.GradientTransformation): The optax optimizer.
        optimizer_state (PyTree): Current optimizer state.
        grads (PyTree): Gradients for the model's array leaves.
        grads_finite (Bool): Scalar boolean; True iff all gradients are finite.

    Returns:
        tuple: `(model, optimizer_state)` — updated when `grads_finite` is
        True, otherwise the inputs unchanged (the step is skipped, as usual
        with dynamic loss scaling).
    """
    # optimizer step — computed unconditionally; select_tree below decides
    # whether to keep it, which keeps the function jit-compatible.
    updates, new_optimizer_state = optimizer.update(
        grads, optimizer_state, eqx.filter(model, eqx.is_array)
    )
    new_model = eqx.apply_updates(model, updates)

    # only apply updates to the model and optimizer state if gradients are finite
    model = select_tree(grads_finite, new_model, model)
    optimizer_state = select_tree(grads_finite, new_optimizer_state, optimizer_state)

    return model, optimizer_state

0 commit comments

Comments
 (0)