Skip to content

Commit b8645b0

Browse files
committed
simplify custom vector test with assertTreeEqual
1 parent 0a1ccc8 commit b8645b0

File tree

1 file changed

+2
-6
lines changed

1 file changed

+2
-6
lines changed

tree_math/_src/vector_test.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -166,14 +166,10 @@ def tree_unflatten(cls, _, args):
166166

167167
v1 = CustomVector(1, 2.0)
168168
v2 = v1 + 3
169-
assert isinstance(v2, CustomVector)
170-
assert v2.a == 4
171-
assert np.isclose(v2.b, 5.0)
169+
self.assertTreeEqual(v2, CustomVector(4, 5.0), check_dtypes=True)
172170

173171
v3 = v2 + v1
174-
assert isinstance(v3, CustomVector)
175-
assert v3.a == 5
176-
assert np.isclose(v3.b, 7.0)
172+
self.assertTreeEqual(v3, CustomVector(5, 7.0), check_dtypes=True)
177173

178174

179175

0 commit comments

Comments
 (0)