Skip to content

Commit f192af0

Browse files
committed
rename VectorBase to VectorMixin
1 parent b8645b0 commit f192af0

File tree

4 files changed

+12
-12
lines changed

4 files changed

+12
-12
lines changed

tree_math/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
wrap,
2121
unwrap,
2222
)
23-
from tree_math._src.vector import Vector, VectorBase
23+
from tree_math._src.vector import Vector, VectorMixin
2424
import tree_math.numpy
2525

2626
__version__ = '0.1.0'

tree_math/_src/numpy_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def test_where_all_scalars(self):
4949
actual = tnp.where(True, 1, 2)
5050
self.assertTreeEqual(actual, expected, check_dtypes=False)
5151
with self.assertRaisesRegex(
52-
TypeError, "non-tree_math.VectorBase argument is not a scalar",
52+
TypeError, "non-tree_math.VectorMixin argument is not a scalar",
5353
):
5454
tnp.where(True, jnp.array([1, 2]), 3)
5555

tree_math/_src/vector.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -53,14 +53,14 @@ def g(*args3):
5353

5454
def broadcasting_map(func, *args):
5555
"""Like tree_map, but scalar arguments are broadcast to all leaves."""
56-
static_argnums = [i for i, x in enumerate(args) if not isinstance(x, VectorBase)]
56+
static_argnums = [i for i, x in enumerate(args) if not isinstance(x, VectorMixin)]
5757
func2, vector_args = _argnums_partial(func, args, static_argnums)
5858
for arg in args:
59-
if not isinstance(arg, VectorBase):
59+
if not isinstance(arg, VectorMixin):
6060
shape = jnp.shape(arg)
6161
if shape:
6262
raise TypeError(
63-
f"non-tree_math.VectorBase argument is not a scalar: {arg!r}"
63+
f"non-tree_math.VectorMixin argument is not a scalar: {arg!r}"
6464
)
6565
if not vector_args:
6666
return func2() # result is a scalar
@@ -112,8 +112,8 @@ def dot(left, right, *, precision="highest"):
112112
Returns:
113113
Resulting dot product (scalar).
114114
"""
115-
if not isinstance(left, VectorBase) or not isinstance(right, VectorBase):
116-
raise TypeError("matmul arguments must both be tree_math.VectorBase objects")
115+
if not isinstance(left, VectorMixin) or not isinstance(right, VectorMixin):
116+
raise TypeError("matmul arguments must both be tree_math.VectorMixin objects")
117117

118118
def _vector_dot(a, b):
119119
return jnp.dot(jnp.ravel(a), jnp.ravel(b), precision=precision)
@@ -122,7 +122,7 @@ def _vector_dot(a, b):
122122
parts = map(_vector_dot, left_values, right_values)
123123
return functools.reduce(operator.add, parts)
124124

125-
class VectorBase:
125+
class VectorMixin:
126126
"""A mixin class that adds a 1D vector-like behaviour to any custom pytree class."""
127127

128128
@property
@@ -206,7 +206,7 @@ def max(self):
206206
return jnp.asarray(list(parts)).max()
207207

208208
@tree_util.register_pytree_node_class
209-
class Vector(VectorBase):
209+
class Vector(VectorMixin):
210210
"""A wrapper for treating an arbitrary pytree as a 1D vector."""
211211

212212
def __init__(self, tree):

tree_math/_src/vector_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def test_arithmetic_with_scalar(self):
5858
self.assertTreeEqual(vector + 1, expected, check_dtypes=True)
5959
self.assertTreeEqual(1 + vector, expected, check_dtypes=True)
6060
with self.assertRaisesRegex(
61-
TypeError, "non-tree_math.VectorBase argument is not a scalar",
61+
TypeError, "non-tree_math.VectorMixin argument is not a scalar",
6262
):
6363
vector + jnp.ones((3,)) # pylint: disable=expression-not-assigned
6464

@@ -123,7 +123,7 @@ def test_matmul(self):
123123
self.assertAllClose(actual, expected)
124124

125125
with self.assertRaisesRegex(
126-
TypeError, "matmul arguments must both be tree_math.VectorBase objects",
126+
TypeError, "matmul arguments must both be tree_math.VectorMixin objects",
127127
):
128128
vector1 @ jnp.ones((7,)) # pylint: disable=expression-not-assigned
129129

@@ -151,7 +151,7 @@ def test_sum_mean_min_max(self):
151151
def test_custom_class(self):
152152

153153
@tree_util.register_pytree_node_class
154-
class CustomVector(tm.VectorBase):
154+
class CustomVector(tm.VectorMixin):
155155

156156
def __init__(self, a: int, b: float):
157157
self.a = a

0 commit comments

Comments
 (0)