Skip to content

Commit 0a1ccc8

Browse files
committed
add test for a custom vector class
1 parent a321e1f commit 0a1ccc8

File tree

1 file changed

+28
-0
lines changed

1 file changed

+28
-0
lines changed

tree_math/_src/vector_test.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,34 @@ def test_sum_mean_min_max(self):
148148
self.assertTreeEqual(vector.min(), 1, check_dtypes=False)
149149
self.assertTreeEqual(vector.max(), 4, check_dtypes=False)
150150

151+
def test_custom_class(self):
152+
153+
@tree_util.register_pytree_node_class
154+
class CustomVector(tm.VectorBase):
155+
156+
def __init__(self, a: int, b: float):
157+
self.a = a
158+
self.b = b
159+
160+
def tree_flatten(self):
161+
return (self.a, self.b), None
162+
163+
@classmethod
164+
def tree_unflatten(cls, _, args):
165+
return cls(*args)
166+
167+
v1 = CustomVector(1, 2.0)
168+
v2 = v1 + 3
169+
assert isinstance(v2, CustomVector)
170+
assert v2.a == 4
171+
assert np.isclose(v2.b, 5.0)
172+
173+
v3 = v2 + v1
174+
assert isinstance(v3, CustomVector)
175+
assert v3.a == 5
176+
assert np.isclose(v3.b, 7.0)
177+
178+
151179

152180
if __name__ == "__main__":
153181
absltest.main()

0 commit comments

Comments
 (0)