Skip to content

Commit c3f1b70

Browse files
author
Alexander
committed
added fp8 code and tests
1 parent 4ea6711 commit c3f1b70

File tree

4 files changed

+111
-72
lines changed

4 files changed

+111
-72
lines changed

mpx/_dtypes.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,11 @@ def set_forward_backward_precision(forward_datatype, backward_datatype):
4242
global EXPERIMENTAL_ACTIVATED
4343
global FORWARD_PRECISION_DATATYPE
4444
global BACKWARD_PRECISION_DATATYPE
45-
logging.warning("Setting forward precision is an experimental feature and may lead to unexpected behavior.")
45+
logging.warning("Setting forward backward precision is an experimental feature and may lead to unexpected behavior.")
4646
EXPERIMENTAL_ACTIVATED = True
4747
FORWARD_PRECISION_DATATYPE = forward_datatype
4848
BACKWARD_PRECISION_DATATYPE = backward_datatype
49+
assert backward_datatype == jnp.float32, "Currently only float32 is supported as backward datatype."
4950

5051
def forward_datatype():
5152
assert EXPERIMENTAL_ACTIVATED, "Experimental features not activated. Call set_forward_backward_precision first."

mpx/experimental/__init__.py

Lines changed: 2 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,6 @@
1-
"""
2-
Mixed Precision for JAX - A library for mixed precision training in JAX
3-
"""
4-
5-
__version__ = "0.1.7"
6-
7-
from ._cast import (
8-
cast_tree,
9-
cast_to_float32,
10-
cast_to_float16,
11-
cast_to_bfloat16,
12-
cast_to_full_precision,
13-
cast_to_half_precision,
14-
force_full_precision,
15-
cast_function,
16-
)
17-
from ._dtypes import half_precision_datatype, set_half_precision_datatype, HALF_PRECISION_DATATYPE, set_forward_backward_precision, forward_datatype, backward_datatype # , FLOAT16_MAX, BFLOAT16_MAX
18-
from ._loss_scaling import DynamicLossScaling, all_finite, scaled
19-
from ._grad_tools import select_tree, filter_grad, filter_value_and_grad, optimizer_update, calculate_scaled_grad
20-
1+
from ._cast import cast_function_fwd_bwd
212

223
__all__ = [
234
# Cast functions
24-
'cast_tree',
25-
'cast_to_float32',
26-
'cast_to_float16',
27-
'cast_to_bfloat16',
28-
'cast_to_full_precision',
29-
'cast_to_half_precision',
30-
'force_full_precision',
31-
'cast_function',
5+
'cast_function_fwd_bwd',
326
]

mpx/experimental/_cast.py

Lines changed: 27 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -13,43 +13,44 @@
1313
from jaxtyping import Array, Float, Int, PyTree, PRNGKeyArray, ArrayLike
1414

1515
from .._dtypes import forward_datatype, backward_datatype
16-
from .._cast import cast_tree
16+
from .._cast import cast_function
1717

1818

1919
def max_val(dtype):
2020
return (jnp.finfo(dtype).max).astype(jnp.float32)
2121

22-
@partial(jax.custom_vjp, nondiff_argnames=("dtype8", 'dimension_numbers', 'precision', 'preferred_element_type', 'out_sharding'))
23-
def quantized_multiplication(a: ArrayLike, b: ArrayLike, dtype8, dimension_numbers, precision, preferred_element_type, out_sharding):
22+
@partial(jax.custom_vjp, nondiff_argnames=('dimension_numbers', 'precision', 'preferred_element_type', 'out_sharding'))
23+
def quantized_multiplication(a: ArrayLike, b: ArrayLike, dimension_numbers, precision, preferred_element_type, out_sharding):
2424
a_max = jnp.max(jnp.abs(a))
2525
b_max = jnp.max(jnp.abs(b))
26-
max_dtype = max_val(dtype8)
26+
fwd_dtype = forward_datatype()
27+
max_dtype = max_val(fwd_dtype)
2728
scaling_a = max_dtype / (a_max + 1e-8)
2829
scaling_b = max_dtype / (b_max + 1e-8)
2930

30-
a_q = (a * scaling_a).astype(dtype8)
31-
b_q = (b * scaling_b).astype(dtype8)
31+
a_q = (a * scaling_a).astype(fwd_dtype)
32+
b_q = (b * scaling_b).astype(fwd_dtype)
3233

3334
result_q = jax.lax.dot_general_p.bind(a_q, b_q, dimension_numbers=dimension_numbers, precision=precision, preferred_element_type=preferred_element_type, out_sharding=out_sharding)
34-
35-
result = (result_q.astype(jnp.float32)) / (scaling_a * scaling_b)
35+
result = (result_q.astype(backward_datatype())) / (scaling_a * scaling_b)
3636
return result
3737

3838

39-
def quantized_multiplication_fwd(a: ArrayLike, b: ArrayLike, dtype8, dimension_numbers, precision, preferred_element_type, out_sharding):
39+
def quantized_multiplication_fwd(a: ArrayLike, b: ArrayLike, dimension_numbers, precision, preferred_element_type, out_sharding):
4040
a_max = jnp.max(jnp.abs(a))
4141
b_max = jnp.max(jnp.abs(b))
42-
max_dtype = max_val(dtype8)
42+
fwd_dtype = forward_datatype()
43+
max_dtype = max_val(fwd_dtype)
4344
scaling_a = max_dtype / (a_max + 1e-8)
4445
scaling_b = max_dtype / (b_max + 1e-8)
4546

46-
a_q = (a * scaling_a).astype(dtype8)
47-
b_q = (b * scaling_b).astype(dtype8)
47+
a_q = (a * scaling_a).astype(fwd_dtype)
48+
b_q = (b * scaling_b).astype(fwd_dtype)
4849
# we want to save the quantized versions for the backward pass to save memory
49-
return quantized_multiplication(a, b, dtype8, dimension_numbers, precision, preferred_element_type, out_sharding), (a_q, b_q, scaling_a, scaling_b)
50+
return quantized_multiplication(a, b, dimension_numbers, precision, preferred_element_type, out_sharding), (a_q, b_q, scaling_a, scaling_b)
5051

5152
# f_bwd :: (c, CT b) -> CT a
52-
def quantized_multiplication_bwd(dtype8, dimension_numbers, precision, preferred_element_type, out_sharding, c, dy_dc):
53+
def quantized_multiplication_bwd(dimension_numbers, precision, preferred_element_type, out_sharding, c, dy_dc):
5354
a_q, b_q, scaling_a, scaling_b = c
5455
backward_dtype = backward_datatype()
5556
    # backward is performed in fp32. TODO: allow changing this.
@@ -62,41 +63,24 @@ def quantized_multiplication_bwd(dtype8, dimension_numbers, precision, preferred
6263

6364
quantized_multiplication.defvjp(quantized_multiplication_fwd, quantized_multiplication_bwd)
6465

66+
6567
@quax.register(jax.lax.dot_general_p)
6668
def _(lhs: ArrayLike, rhs: ArrayLike, **params):
67-
return quantized_multiplication(lhs, rhs, jnp.float8_e4m3, **params)
69+
return quantized_multiplication(lhs, rhs, **params)
6870

6971

70-
71-
def cast_function(func, dtype, return_dtype=None):
72+
def cast_function_fwd_bwd(f: callable) -> callable:
7273
"""
73-
Casts the function to the specified data type.
74+
Casts a function to use the specified forward and backward data types.
75+
Args:
76+
f (callable): The function to be cast.
77+
Returns:
78+
callable: A new function that uses the specified data types for forward and backward passes.
7479
"""
7580

76-
if return_dtype is None:
77-
return_dtype = dtype
78-
79-
def wrapper(*args, **kwargs):
80-
args_cast = []
81-
for arg in args:
82-
args_cast.append(cast_tree(arg, dtype))
83-
args_cast = tuple(args_cast)
84-
85-
kwargs_cast = {}
86-
for key, value in kwargs.items():
87-
kwargs_cast[key] = cast_tree(value, dtype)
88-
89-
results = func(*args_cast, **kwargs_cast)
90-
91-
if type(results) == tuple:
92-
results_converted = []
93-
for r in results:
94-
results_converted.append(cast_tree(r, return_dtype))
95-
return tuple(results_converted)
96-
elif eqx.is_array(results):
97-
return cast_tree(results, return_dtype)
98-
return results
99-
100-
return wrapper
81+
    # cast inputs to bwd_dtype. This ensures all non-multiply operations run in bwd_dtype
82+
f = cast_function(f, backward_datatype())
10183

84+
f = quax.quaxify(f)
10285

86+
return f

tests/test_fp8.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
import unittest
2+
import jax
3+
import jax.numpy as jnp
4+
import equinox as eqx
5+
from jaxtyping import Array, Float, Int, PyTree
6+
import numpy as np
7+
8+
from mpx import set_forward_backward_precision
9+
10+
from mpx.experimental import cast_function_fwd_bwd
11+
12+
13+
class MLP(eqx.Module):
14+
a: Array
15+
b: Array
16+
17+
def __init__(self):
18+
self.a = jnp.ones((10, 10), dtype=jnp.float32)
19+
self.b = jnp.ones(10, dtype=jnp.float32)
20+
21+
def __call__(self, x):
22+
return jax.nn.relu(self.a @ x + self.b)
23+
24+
25+
class TestFP8(unittest.TestCase):
26+
def setUp(self):
27+
# Create some test data
28+
self.array_float32 = jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32)
29+
30+
def test_cast_function_fwd_bwd(self):
31+
# Create test module
32+
module = MLP()
33+
for bwd_dtype in [jnp.float32]:
34+
set_forward_backward_precision(jnp.float8_e5m2, bwd_dtype)
35+
36+
def loss_fn(mdl, inp):
37+
out = mdl(inp)
38+
return jnp.sum(out)
39+
loss_fn_fp8 = cast_function_fwd_bwd(loss_fn)
40+
41+
x = jnp.ones((10,1), dtype=jnp.float32)
42+
43+
# test forward pass
44+
    # the output should be in the backward datatype, since only multiplications are cast to the forward datatype
45+
output = loss_fn_fp8(module, x)
46+
output_original = loss_fn(module, x)
47+
print(output)
48+
print(output_original)
49+
self.assertTrue(np.allclose(output, output_original, atol=1e-4))
50+
self.assertEqual(output.dtype, bwd_dtype)
51+
52+
# test backward pass
53+
grad_fn_fp8 = jax.grad(loss_fn_fp8)
54+
grad_fn = jax.grad(loss_fn)
55+
grads_fp8 = grad_fn_fp8(module, x)
56+
grads = grad_fn(module, x)
57+
58+
# as MLP and x all have the same values, the gradients should be the same
59+
# (for other values, the gradients will differ slightly due to quantization errors)
60+
self.assertTrue(np.allclose(grads_fp8.a, grads.a, atol=1e-4))
61+
self.assertTrue(np.allclose(grads_fp8.b, grads.b, atol=1e-4))
62+
63+
self.assertEqual(grads_fp8.a.dtype, bwd_dtype)
64+
self.assertEqual(grads_fp8.b.dtype, bwd_dtype)
65+
66+
# test now with values where quantization errors are larger
67+
x = jnp.arange(10, dtype=bwd_dtype).reshape((10,1)) + 1.0
68+
output = loss_fn_fp8(module, x)
69+
output_original = loss_fn(module, x)
70+
grads_fp8 = grad_fn_fp8(module, x)
71+
grads = grad_fn(module, x)
72+
73+
self.assertFalse(np.allclose(output, output_original, atol=1e-4))
74+
self.assertFalse(np.allclose(grads_fp8.a, grads.a, atol=1e-4))
75+
# bias is in fp32, so it should be close
76+
self.assertTrue(np.allclose(grads_fp8.b, grads.b, atol=1e-4))
77+
78+
79+
if __name__ == '__main__':
80+
unittest.main()

0 commit comments

Comments
 (0)