Skip to content

Commit 00c097e

Browse files
author
Alexander
committed
added unittest
1 parent adb2807 commit 00c097e

File tree

11 files changed

+497
-14
lines changed

11 files changed

+497
-14
lines changed

.vscode/settings.json

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
{
2+
"python.testing.pytestArgs": [
3+
"tests"
4+
],
5+
"python.testing.unittestEnabled": false,
6+
"python.testing.pytestEnabled": true
7+
}

mpfj/__init__.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,3 @@
1-
from . import utils
2-
from . import layers
3-
from . import optimizers
4-
51
"""
62
Mixed Precision for JAX (mpfj)
73
@@ -10,4 +6,4 @@
106

117
__version__ = "0.1.0"
128

13-
from .dtypes import set_half_precision_datatype
9+
from .dtypes import half_precision_datatype, set_half_precision_datatype

mpfj/cast.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434

3535
from jaxtyping import Array, Float, Int, PyTree, PRNGKeyArray
3636

37-
from .dtypes import HALF_PRECISION_DATATYPE
37+
from .dtypes import half_precision_datatype
3838

3939
def cast_tree(tree: PyTree, dtype):
4040
"""
@@ -125,7 +125,7 @@ def cast_to_half_precision(x: PyTree) -> PyTree:
125125
PyTree: A new PyTree with all elements cast to the half-precision datatype.
126126
"""
127127
"""Cast to half precision (float16/bfloat16, depending on with what we called set_half_precision_datatype)."""
128-
return cast_tree(x, HALF_PRECISION_DATATYPE)
128+
return cast_tree(x, half_precision_datatype())
129129

130130

131131
def force_full_precision(func, return_dtype=jnp.float16):

mpfj/dtypes.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,8 @@ def set_half_precision_datatype(datatype):
99
Args:
1010
datatype: The datatype to set as half precision (e.g., jnp.float16).
1111
"""
12-
HALF_PRECISION_DATATYPE = datatype
12+
global HALF_PRECISION_DATATYPE
13+
HALF_PRECISION_DATATYPE = datatype
14+
15+
def half_precision_datatype():
16+
return HALF_PRECISION_DATATYPE

mpfj/grad_tools.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,8 @@
3333

3434
import optax
3535

36-
import cast as cast
37-
import loss_scaling as loss_scaling
36+
import mpfj.cast as cast
37+
import mpfj.loss_scaling as loss_scaling
3838

3939
from jaxtyping import PyTree, Bool
4040

@@ -167,6 +167,8 @@ def optimizer_update(model: PyTree, optimizer: optax.GradientTransformation, opt
167167
updates, new_optimizer_state = optimizer.update(
168168
grads, optimizer_state, eqx.filter(model, eqx.is_array)
169169
)
170+
print(updates)
171+
print("dddddddd")
170172
new_model = eqx.apply_updates(model, updates)
171173

172174
# only apply updates to the model and optimizer state if gradients are finite

mpfj/loss_scaling.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,9 @@ def all_finite(tree: PyTree) -> Array:
5757
leaves = map(jnp.isfinite, leaves)
5858
leaves = map(jnp.all, leaves)
5959
return jnp.stack(list(leaves)).all()
60-
6160

62-
def scaled(func: callable, scaling: DynamicLossScaling):
61+
62+
def scaled(func: callable, scaling: 'DynamicLossScaling'):
6363
def wrapper(*_args, **_kwargs):
6464
value = func(*_args, **_kwargs)
6565
value = scaling.scale(value)
@@ -93,7 +93,7 @@ def unscale(self, tree):
9393
inv_loss_scaling = inv_loss_scaling.astype(jnp.float32) # cast to float32, so the result is float32 (otherwise the whole scaling point would be senseless)
9494
return jax.tree_util.tree_map(lambda x: x * inv_loss_scaling[0], tree)
9595

96-
def adjust(self, grads_finite: jnp.ndarray) -> DynamicLossScaling:
96+
def adjust(self, grads_finite: jnp.ndarray) -> 'DynamicLossScaling':
9797
"""Returns the next state dependent on whether grads are finite."""
9898
assert grads_finite.ndim == 0, "Expected boolean scalar"
9999

@@ -113,7 +113,7 @@ def adjust(self, grads_finite: jnp.ndarray) -> DynamicLossScaling:
113113
jnp.maximum(self.min_loss_scaling, self.loss_scaling / self.factor))
114114

115115
# clip to maximum float16 value.
116-
loss_scaling = jnp.clip(loss_scaling, a_min=self.min_loss_scaling, a_max=(2 - 2**(-10)) * 2**15)
116+
loss_scaling = jnp.clip(loss_scaling, min=self.min_loss_scaling, max=(2 - 2**(-10)) * 2**15)
117117

118118
counter = ((self.counter + 1) % self.period) * grads_finite
119119

tests/conftest.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
import os
2+
import sys
3+
4+
# Add the project root directory to the Python path
5+
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
6+
sys.path.insert(0, project_root)

tests/test_cast.py

Lines changed: 193 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,193 @@
1+
import unittest
2+
import jax
3+
import jax.numpy as jnp
4+
import equinox as eqx
5+
from jaxtyping import Array, Float, Int, PyTree
6+
import numpy as np
7+
8+
from mpfj.cast import (
9+
cast_tree,
10+
cast_to_float32,
11+
cast_to_float16,
12+
cast_to_bfloat16,
13+
cast_to_full_precision,
14+
cast_to_half_precision,
15+
force_full_precision,
16+
)
17+
from mpfj.dtypes import HALF_PRECISION_DATATYPE
18+
19+
20+
class EQXModuleBase(eqx.Module):
21+
a: Array
22+
b: Array
23+
24+
def __init__(self):
25+
self.a = jnp.ones(10, dtype=jnp.float32)
26+
self.b = jnp.ones(10, dtype=jnp.float32)
27+
28+
class LeafClass:
29+
"""If implemented correctly, this class should not be casted"""
30+
a: Array
31+
b: Array
32+
33+
def __init__(self):
34+
self.a = jnp.ones(10, dtype=jnp.float32)
35+
self.b = jnp.ones(10, dtype=jnp.float32)
36+
37+
class EQXModule1(eqx.Module):
38+
a: list[EQXModuleBase]
39+
b: Array
40+
c: LeafClass
41+
42+
def __init__(self):
43+
self.a = [EQXModuleBase() for _ in range(10)]
44+
self.b = jnp.ones(10, dtype=jnp.float32)
45+
self.c = LeafClass()
46+
47+
48+
class TestCastFunctions(unittest.TestCase):
49+
def setUp(self):
50+
# Create some test data
51+
self.array_float32 = jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32)
52+
self.array_float16 = jnp.array([1.0, 2.0, 3.0], dtype=jnp.float16)
53+
self.array_bfloat16 = jnp.array([1.0, 2.0, 3.0], dtype=jnp.bfloat16)
54+
self.nested_dict = {
55+
'a': self.array_float32,
56+
'b': {'c': self.array_float16, 'd': self.array_bfloat16}
57+
}
58+
self.mixed_tree = {
59+
'array': self.array_float32,
60+
'scalar': 42,
61+
'nested': {
62+
'array': self.array_float16,
63+
'none': None
64+
}
65+
}
66+
67+
def test_cast_eqx_module(self):
68+
# Create test module
69+
module = EQXModule1()
70+
71+
# Test casting to float16
72+
result = cast_tree(module, jnp.float16)
73+
# Check that arrays in nested EQXModuleBase instances are cast
74+
for base_module in result.a:
75+
self.assertEqual(base_module.a.dtype, jnp.float16)
76+
self.assertEqual(base_module.b.dtype, jnp.float16)
77+
# Check direct array is cast
78+
self.assertEqual(result.b.dtype, jnp.float16)
79+
# Check that LeafClass arrays are NOT cast since it's not an eqx.Module
80+
self.assertEqual(result.c.a.dtype, jnp.float32)
81+
self.assertEqual(result.c.b.dtype, jnp.float32)
82+
83+
# Test casting to bfloat16
84+
result = cast_tree(module, jnp.bfloat16)
85+
# Check nested modules
86+
for base_module in result.a:
87+
self.assertEqual(base_module.a.dtype, jnp.bfloat16)
88+
self.assertEqual(base_module.b.dtype, jnp.bfloat16)
89+
self.assertEqual(result.b.dtype, jnp.bfloat16)
90+
# LeafClass should remain unchanged
91+
self.assertEqual(result.c.a.dtype, jnp.float32)
92+
self.assertEqual(result.c.b.dtype, jnp.float32)
93+
94+
# Test casting back to float32
95+
result = cast_tree(module, jnp.float32)
96+
for base_module in result.a:
97+
self.assertEqual(base_module.a.dtype, jnp.float32)
98+
self.assertEqual(base_module.b.dtype, jnp.float32)
99+
self.assertEqual(result.b.dtype, jnp.float32)
100+
self.assertEqual(result.c.a.dtype, jnp.float32)
101+
self.assertEqual(result.c.b.dtype, jnp.float32)
102+
103+
def test_cast_tree(self):
104+
# Test casting to float32
105+
result = cast_tree(self.array_float16, jnp.float32)
106+
self.assertEqual(result.dtype, jnp.float32)
107+
108+
# Test casting nested structure
109+
result = cast_tree(self.nested_dict, jnp.float32)
110+
self.assertEqual(result['a'].dtype, jnp.float32)
111+
self.assertEqual(result['b']['c'].dtype, jnp.float32)
112+
self.assertEqual(result['b']['d'].dtype, jnp.float32)
113+
114+
def test_cast_to_float32(self):
115+
result = cast_to_float32(self.array_float16)
116+
self.assertEqual(result.dtype, jnp.float32)
117+
118+
result = cast_to_float32(self.nested_dict)
119+
self.assertEqual(result['a'].dtype, jnp.float32)
120+
self.assertEqual(result['b']['c'].dtype, jnp.float32)
121+
self.assertEqual(result['b']['d'].dtype, jnp.float32)
122+
123+
def test_cast_to_float16(self):
124+
result = cast_to_float16(self.array_float32)
125+
self.assertEqual(result.dtype, jnp.float16)
126+
127+
result = cast_to_float16(self.nested_dict)
128+
self.assertEqual(result['a'].dtype, jnp.float16)
129+
self.assertEqual(result['b']['c'].dtype, jnp.float16)
130+
self.assertEqual(result['b']['d'].dtype, jnp.float16)
131+
132+
def test_cast_to_bfloat16(self):
133+
result = cast_to_bfloat16(self.array_float32)
134+
self.assertEqual(result.dtype, jnp.bfloat16)
135+
136+
result = cast_to_bfloat16(self.nested_dict)
137+
self.assertEqual(result['a'].dtype, jnp.bfloat16)
138+
self.assertEqual(result['b']['c'].dtype, jnp.bfloat16)
139+
self.assertEqual(result['b']['d'].dtype, jnp.bfloat16)
140+
141+
def test_cast_to_full_precision(self):
142+
result = cast_to_full_precision(self.array_float16)
143+
self.assertEqual(result.dtype, jnp.float32)
144+
145+
result = cast_to_full_precision(self.nested_dict)
146+
self.assertEqual(result['a'].dtype, jnp.float32)
147+
self.assertEqual(result['b']['c'].dtype, jnp.float32)
148+
self.assertEqual(result['b']['d'].dtype, jnp.float32)
149+
150+
def test_cast_to_half_precision(self):
151+
result = cast_to_half_precision(self.array_float32)
152+
self.assertEqual(result.dtype, HALF_PRECISION_DATATYPE)
153+
154+
result = cast_to_half_precision(self.nested_dict)
155+
self.assertEqual(result['a'].dtype, HALF_PRECISION_DATATYPE)
156+
self.assertEqual(result['b']['c'].dtype, HALF_PRECISION_DATATYPE)
157+
self.assertEqual(result['b']['d'].dtype, HALF_PRECISION_DATATYPE)
158+
159+
def test_force_full_precision_decorator(self):
160+
@force_full_precision
161+
def test_func(x, y):
162+
return x + y, x * y
163+
164+
# Test with float16 inputs
165+
x = jnp.array([1.0, 2.0], dtype=jnp.float16)
166+
y = jnp.array([3.0, 4.0], dtype=jnp.float16)
167+
168+
result1, result2 = test_func(x, y)
169+
170+
# Check that inputs were converted to float32 during computation
171+
self.assertEqual(result1.dtype, jnp.float16) # Output is cast back to float16
172+
self.assertEqual(result2.dtype, jnp.float16) # Output is cast back to float16
173+
174+
def test_mixed_tree_handling(self):
175+
# Test that non-array elements are preserved
176+
result = cast_to_float32(self.mixed_tree)
177+
self.assertEqual(result['array'].dtype, jnp.float32)
178+
self.assertEqual(result['scalar'], 42)
179+
self.assertEqual(result['nested']['none'], None)
180+
self.assertEqual(result['nested']['array'].dtype, jnp.float32)
181+
182+
def test_empty_structures(self):
183+
# Test with empty structures
184+
empty_dict = {}
185+
result = cast_to_float32(empty_dict)
186+
self.assertEqual(result, {})
187+
188+
empty_list = []
189+
result = cast_to_float32(empty_list)
190+
self.assertEqual(result, [])
191+
192+
if __name__ == '__main__':
193+
unittest.main()

tests/test_dtypes.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import unittest
2+
import jax.numpy as jnp
3+
from mpfj.dtypes import half_precision_datatype, set_half_precision_datatype
4+
5+
class TestDtypes(unittest.TestCase):
6+
def test_default_half_precision(self):
7+
"""Test that the default half precision datatype is float16"""
8+
self.assertEqual(half_precision_datatype(), jnp.float16)
9+
10+
def test_set_half_precision_datatype(self):
11+
"""Test setting half precision datatype to bfloat16"""
12+
set_half_precision_datatype(jnp.bfloat16)
13+
self.assertEqual(half_precision_datatype(), jnp.bfloat16)
14+
15+
# Reset to default
16+
set_half_precision_datatype(jnp.float16)
17+
self.assertEqual(half_precision_datatype(), jnp.float16)
18+
19+
if __name__ == '__main__':
20+
unittest.main()

0 commit comments

Comments
 (0)