File tree Expand file tree Collapse file tree 1 file changed +28
-0
lines changed Expand file tree Collapse file tree 1 file changed +28
-0
lines changed Original file line number Diff line number Diff line change @@ -148,6 +148,34 @@ def test_sum_mean_min_max(self):
148
148
self .assertTreeEqual (vector .min (), 1 , check_dtypes = False )
149
149
self .assertTreeEqual (vector .max (), 4 , check_dtypes = False )
150
150
151
+ def test_custom_class (self ):
152
+
153
+ @tree_util .register_pytree_node_class
154
+ class CustomVector (tm .VectorBase ):
155
+
156
+ def __init__ (self , a : int , b : float ):
157
+ self .a = a
158
+ self .b = b
159
+
160
+ def tree_flatten (self ):
161
+ return (self .a , self .b ), None
162
+
163
+ @classmethod
164
+ def tree_unflatten (cls , _ , args ):
165
+ return cls (* args )
166
+
167
+ v1 = CustomVector (1 , 2.0 )
168
+ v2 = v1 + 3
169
+ assert isinstance (v2 , CustomVector )
170
+ assert v2 .a == 4
171
+ assert np .isclose (v2 .b , 5.0 )
172
+
173
+ v3 = v2 + v1
174
+ assert isinstance (v3 , CustomVector )
175
+ assert v3 .a == 5
176
+ assert np .isclose (v3 .b , 7.0 )
177
+
178
+
151
179
152
180
if __name__ == "__main__" :
153
181
absltest .main ()
You can’t perform that action at this time.
0 commit comments