@@ -53,14 +53,14 @@ def g(*args3):
53
53
54
54
def broadcasting_map (func , * args ):
55
55
"""Like tree_map, but scalar arguments are broadcast to all leaves."""
56
- static_argnums = [i for i , x in enumerate (args ) if not isinstance (x , Vector )]
56
+ static_argnums = [i for i , x in enumerate (args ) if not isinstance (x , VectorBase )]
57
57
func2 , vector_args = _argnums_partial (func , args , static_argnums )
58
58
for arg in args :
59
- if not isinstance (arg , Vector ):
59
+ if not isinstance (arg , VectorBase ):
60
60
shape = jnp .shape (arg )
61
61
if shape :
62
62
raise TypeError (
63
- f"non-tree_math.Vector argument is not a scalar: { arg !r} "
63
+ f"non-tree_math.VectorBase argument is not a scalar: { arg !r} "
64
64
)
65
65
if not vector_args :
66
66
return func2 () # result is a scalar
@@ -112,8 +112,8 @@ def dot(left, right, *, precision="highest"):
112
112
Returns:
113
113
Resulting dot product (scalar).
114
114
"""
115
- if not isinstance (left , Vector ) or not isinstance (right , Vector ):
116
- raise TypeError ("matmul arguments must both be tree_math.Vector objects" )
115
+ if not isinstance (left , VectorBase ) or not isinstance (right , VectorBase ):
116
+ raise TypeError ("matmul arguments must both be tree_math.VectorBase objects" )
117
117
118
118
def _vector_dot (a , b ):
119
119
return jnp .dot (jnp .ravel (a ), jnp .ravel (b ), precision = precision )
@@ -206,7 +206,7 @@ def max(self):
206
206
return jnp .asarray (list (parts )).max ()
207
207
208
208
@tree_util .register_pytree_node_class
209
- class Vector :
209
+ class Vector ( VectorBase ) :
210
210
"""A wrapper for treating an arbitrary pytree as a 1D vector."""
211
211
212
212
def __init__ (self , tree ):
0 commit comments