Skip to content

Commit 3283ca8

Browse files
committed
refactor into VectorBase
1 parent 16ab724 commit 3283ca8

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

tree_math/_src/vector.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -53,14 +53,14 @@ def g(*args3):
5353

5454
def broadcasting_map(func, *args):
5555
"""Like tree_map, but scalar arguments are broadcast to all leaves."""
56-
static_argnums = [i for i, x in enumerate(args) if not isinstance(x, Vector)]
56+
static_argnums = [i for i, x in enumerate(args) if not isinstance(x, VectorBase)]
5757
func2, vector_args = _argnums_partial(func, args, static_argnums)
5858
for arg in args:
59-
if not isinstance(arg, Vector):
59+
if not isinstance(arg, VectorBase):
6060
shape = jnp.shape(arg)
6161
if shape:
6262
raise TypeError(
63-
f"non-tree_math.Vector argument is not a scalar: {arg!r}"
63+
f"non-tree_math.VectorBase argument is not a scalar: {arg!r}"
6464
)
6565
if not vector_args:
6666
return func2() # result is a scalar
@@ -112,8 +112,8 @@ def dot(left, right, *, precision="highest"):
112112
Returns:
113113
Resulting dot product (scalar).
114114
"""
115-
if not isinstance(left, Vector) or not isinstance(right, Vector):
116-
raise TypeError("matmul arguments must both be tree_math.Vector objects")
115+
if not isinstance(left, VectorBase) or not isinstance(right, VectorBase):
116+
raise TypeError("matmul arguments must both be tree_math.VectorBase objects")
117117

118118
def _vector_dot(a, b):
119119
return jnp.dot(jnp.ravel(a), jnp.ravel(b), precision=precision)
@@ -206,7 +206,7 @@ def max(self):
206206
return jnp.asarray(list(parts)).max()
207207

208208
@tree_util.register_pytree_node_class
209-
class Vector:
209+
class Vector(VectorBase):
210210
"""A wrapper for treating an arbitrary pytree as a 1D vector."""
211211

212212
def __init__(self, tree):

0 commit comments

Comments
 (0)