Skip to content

Commit 16ab724

Browse files
committed
initial refactor
1 parent 83f4360 commit 16ab724

File tree

1 file changed

+30
-26
lines changed

1 file changed

+30
-26
lines changed

tree_math/_src/vector.py

Lines changed: 30 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def broadcasting_map(func, *args):
6464
)
6565
if not vector_args:
6666
return func2() # result is a scalar
67-
_flatten_together(*[arg.tree for arg in vector_args]) # check shapes
67+
_flatten_together(*[arg for arg in vector_args]) # check shapes
6868
return tree_util.tree_map(func2, *vector_args)
6969

7070

@@ -118,37 +118,16 @@ def dot(left, right, *, precision="highest"):
118118
def _vector_dot(a, b):
119119
return jnp.dot(jnp.ravel(a), jnp.ravel(b), precision=precision)
120120

121-
(left_values, right_values), _ = _flatten_together(left.tree, right.tree)
121+
(left_values, right_values), _ = _flatten_together(left, right)
122122
parts = map(_vector_dot, left_values, right_values)
123123
return functools.reduce(operator.add, parts)
124124

125-
126-
@tree_util.register_pytree_node_class
127-
class Vector:
125+
class VectorBase:
128126
"""A wrapper for treating an arbitrary pytree as a 1D vector."""
129127

130-
def __init__(self, tree):
131-
self._tree = tree
132-
133-
@property
134-
def tree(self):
135-
return self._tree
136-
137-
# TODO(shoyer): consider casting to a common dtype?
138-
139-
def __repr__(self):
140-
return f"tree_math.Vector({self._tree!r})"
141-
142-
def tree_flatten(self):
143-
return (self.tree,), None
144-
145-
@classmethod
146-
def tree_unflatten(cls, _, args):
147-
return cls(*args)
148-
149128
@property
150129
def size(self):
151-
values = tree_util.tree_leaves(self.tree)
130+
values = tree_util.tree_leaves(self)
152131
return sum(jnp.size(value) for value in values)
153132

154133
def __len__(self):
@@ -164,7 +143,7 @@ def ndim(self):
164143

165144
@property
166145
def dtype(self):
167-
values = tree_util.tree_leaves(self.tree)
146+
values = tree_util.tree_leaves(self)
168147
return jnp.result_type(*values)
169148

170149
# comparison
@@ -225,3 +204,28 @@ def min(self):
225204
def max(self):
226205
parts = map(jnp.max, tree_util.tree_leaves(self))
227206
return jnp.asarray(list(parts)).max()
207+
208+
@tree_util.register_pytree_node_class
209+
class Vector:
210+
"""A wrapper for treating an arbitrary pytree as a 1D vector."""
211+
212+
def __init__(self, tree):
213+
self._tree = tree
214+
215+
@property
216+
def tree(self):
217+
return self._tree
218+
219+
# TODO(shoyer): consider casting to a common dtype?
220+
221+
def __repr__(self):
222+
return f"tree_math.Vector({self._tree!r})"
223+
224+
def tree_flatten(self):
225+
return (self._tree,), None
226+
227+
@classmethod
228+
def tree_unflatten(cls, _, args):
229+
return cls(*args)
230+
231+

0 commit comments

Comments
 (0)