@@ -64,7 +64,7 @@ def broadcasting_map(func, *args):
64
64
)
65
65
if not vector_args :
66
66
return func2 () # result is a scalar
67
- _flatten_together (* [arg . tree for arg in vector_args ]) # check shapes
67
+ _flatten_together (* [arg for arg in vector_args ]) # check shapes
68
68
return tree_util .tree_map (func2 , * vector_args )
69
69
70
70
@@ -118,37 +118,16 @@ def dot(left, right, *, precision="highest"):
118
118
def _vector_dot (a , b ):
119
119
return jnp .dot (jnp .ravel (a ), jnp .ravel (b ), precision = precision )
120
120
121
- (left_values , right_values ), _ = _flatten_together (left . tree , right . tree )
121
+ (left_values , right_values ), _ = _flatten_together (left , right )
122
122
parts = map (_vector_dot , left_values , right_values )
123
123
return functools .reduce (operator .add , parts )
124
124
125
-
126
- @tree_util .register_pytree_node_class
127
- class Vector :
125
+ class VectorBase :
128
126
"""A wrapper for treating an arbitrary pytree as a 1D vector."""
129
127
130
- def __init__ (self , tree ):
131
- self ._tree = tree
132
-
133
- @property
134
- def tree (self ):
135
- return self ._tree
136
-
137
- # TODO(shoyer): consider casting to a common dtype?
138
-
139
- def __repr__ (self ):
140
- return f"tree_math.Vector({ self ._tree !r} )"
141
-
142
- def tree_flatten (self ):
143
- return (self .tree ,), None
144
-
145
- @classmethod
146
- def tree_unflatten (cls , _ , args ):
147
- return cls (* args )
148
-
149
128
@property
150
129
def size (self ):
151
- values = tree_util .tree_leaves (self . tree )
130
+ values = tree_util .tree_leaves (self )
152
131
return sum (jnp .size (value ) for value in values )
153
132
154
133
def __len__ (self ):
@@ -164,7 +143,7 @@ def ndim(self):
164
143
165
144
@property
166
145
def dtype (self ):
167
- values = tree_util .tree_leaves (self . tree )
146
+ values = tree_util .tree_leaves (self )
168
147
return jnp .result_type (* values )
169
148
170
149
# comparison
@@ -225,3 +204,28 @@ def min(self):
225
204
def max (self ):
226
205
parts = map (jnp .max , tree_util .tree_leaves (self ))
227
206
return jnp .asarray (list (parts )).max ()
207
+
208
+ @tree_util .register_pytree_node_class
209
+ class Vector :
210
+ """A wrapper for treating an arbitrary pytree as a 1D vector."""
211
+
212
+ def __init__ (self , tree ):
213
+ self ._tree = tree
214
+
215
+ @property
216
+ def tree (self ):
217
+ return self ._tree
218
+
219
+ # TODO(shoyer): consider casting to a common dtype?
220
+
221
+ def __repr__ (self ):
222
+ return f"tree_math.Vector({ self ._tree !r} )"
223
+
224
+ def tree_flatten (self ):
225
+ return (self ._tree ,), None
226
+
227
+ @classmethod
228
+ def tree_unflatten (cls , _ , args ):
229
+ return cls (* args )
230
+
231
+
0 commit comments