Skip to content

Commit e9b4cbd

Browse files
shoyertree-math authors
authored andcommitted
Update documentation for tree_math.struct
Also tag the release 0.2.0 PiperOrigin-RevId: 530328675
1 parent 0ae059e commit e9b4cbd

File tree

3 files changed

+44
-18
lines changed

3 files changed

+44
-18
lines changed

README.md

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -65,26 +65,29 @@ from the `.tree` property:
6565
```
6666

6767
As an alternative to manipulating `Vector` objects directly, you can also use
68-
the functional transformations `wrap` and `unwrap` (see the "Example usage"
69-
below).
68+
the functional transformations `wrap` and `unwrap` (see the "Writing an
69+
algorithm" below). Or you can create your own `Vector`-like objects from a pytree with
70+
`VectorMixin` or `tree_math.struct` (see "Custom vector classes" below).
7071

7172
One important difference between `tree_math` and `jax.numpy` is that dot
7273
products in `tree_math` default to full precision on all platforms, rather
7374
than defaulting to bfloat16 precision on TPUs. This is useful for writing most
7475
numerical algorithms, and will likely be JAX's default behavior
7576
[in the future](https://github.com/google/jax/pull/7859).
7677

77-
In the near-term, we also plan to add a `Matrix` class that will make it
78-
possible to use tree-math for numerical algorithms such as
78+
It would be nice to have a `Matrix` class to make it possible to use tree-math
79+
for numerical algorithms such as
7980
[L-BFGS](https://en.wikipedia.org/wiki/Limited-memory_BFGS) which use matrices
80-
to represent stacks of vectors.
81+
to represent stacks of vectors. If you're interesting in contributing this
82+
feature, please comment on [this GitHub issue](https://github.com/google/tree-math/issues/6).
8183

82-
## Example usage
84+
### Writing an algorithm
8385

8486
Here is how we could write the preconditioned conjugate gradient
8587
method. Notice how similar the implementation is to the [pseudocode from
8688
Wikipedia](https://en.wikipedia.org/wiki/Conjugate_gradient_method#The_preconditioned_conjugate_gradient_method),
87-
unlike the [implementation in JAX](https://github.com/google/jax/blob/b5aea7bc2da4fb5ef96c87a59bfd1486d8958dd7/jax/_src/scipy/sparse/linalg.py#L111-L121):
89+
unlike the [implementation in JAX](https://github.com/google/jax/blob/b5aea7bc2da4fb5ef96c87a59bfd1486d8958dd7/jax/_src/scipy/sparse/linalg.py#L111-L121).
90+
Both versions support arbitrary pytrees as input:
8891

8992
```python
9093
import functools
@@ -124,3 +127,28 @@ def cg(A, b, x0, M=lambda x: x, maxiter=5, tol=1e-5, atol=0.0):
124127
x_final, *_ = lax.while_loop(cond_fun, body_fun, initial_value)
125128
return x_final
126129
```
130+
131+
### Custom vector classes
132+
133+
You can also make your own classes directly support math like `Vector`. To do
134+
so, either inherit from `tree_math.VectorMixin` on your pytree class, or use
135+
`tree_math.struct` (similar to
136+
[`flax.struct`](https://flax.readthedocs.io/en/latest/api_reference/flax.struct.html))
137+
to create pytree and tree math supporting
138+
[dataclass](https://docs.python.org/3/library/dataclasses.html):
139+
140+
```python
141+
import jax
142+
import tree_math
143+
144+
@tree_math.struct
145+
class Point:
146+
x: float | jax.Array
147+
y: float | jax.Array
148+
149+
a = Point(0.0, 1.0)
150+
b = Point(2.0, 3.0)
151+
152+
a + 3 * b # Point(6.0, 10.0)
153+
jax.grad(lambda x, y: x @ y)(a, b) # Point(2.0, 3.0)
154+
```

tree_math/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,4 +24,4 @@
2424
from tree_math._src.vector import Vector, VectorMixin
2525
import tree_math.numpy
2626

27-
__version__ = '0.1.0'
27+
__version__ = '0.2.0'

tree_math/_src/structs.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -27,22 +27,20 @@ class is also a valid pytree, making it compatible with JAX function
2727
transformations such as `jit` and `grad`.
2828
2929
Example usage:
30-
3130
```
32-
@struct
31+
import jax
32+
import tree_math
33+
34+
@tree_math.struct
3335
class Point:
3436
x: float
3537
y: float
3638
37-
a = Point(0., 1.)
38-
b = Point(1., 1.)
39-
40-
a + 3 * b # Point(3., 4.)
41-
42-
def norm_squared(pt):
43-
return pt.x**2 + pt.y**2
39+
a = Point(0.0, 1.0)
40+
b = Point(2.0, 3.0)
4441
45-
jax.jit(jax.grad(norm))(b) # Point(2., 2.)
42+
a + 3 * b # Point(6.0, 10.0)
43+
jax.grad(lambda x, y: x @ y)(a, b) # Point(2.0, 3.0)
4644
```
4745
4846
Args:

0 commit comments

Comments
 (0)