```

As an alternative to manipulating `Vector` objects directly, you can also use
the functional transformations `wrap` and `unwrap` (see "Writing an
algorithm" below). Or you can create your own `Vector`-like objects from a
pytree with `VectorMixin` or `tree_math.struct` (see "Custom vector classes"
below).

One important difference between `tree_math` and `jax.numpy` is that dot
products in `tree_math` default to full precision on all platforms, rather
than defaulting to bfloat16 precision on TPUs. This is useful for writing most
numerical algorithms, and will likely be JAX's default behavior
[in the future](https://github.com/google/jax/pull/7859).
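
For comparison, this is the knob that controls dot-product precision in stock
`jax.numpy`: passing `lax.Precision.HIGHEST` explicitly requests full
precision on all platforms, comparable to `tree_math`'s default (a sketch for
illustration, not a description of `tree_math` internals):

```python
import jax.numpy as jnp
from jax import lax

x = jnp.ones(3)
y = jnp.arange(3.0)

# Without the precision argument, dot products on TPU default to bfloat16;
# lax.Precision.HIGHEST forces full-precision accumulation everywhere.
jnp.dot(x, y, precision=lax.Precision.HIGHEST)  # 3.0
```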

It would be nice to have a `Matrix` class to make it possible to use tree-math
for numerical algorithms such as
[L-BFGS](https://en.wikipedia.org/wiki/Limited-memory_BFGS), which use matrices
to represent stacks of vectors. If you're interested in contributing this
feature, please comment on
[this GitHub issue](https://github.com/google/tree-math/issues/6).

### Writing an algorithm

Here is how we could write the preconditioned conjugate gradient method.
Notice how similar the implementation is to the
[pseudocode from Wikipedia](https://en.wikipedia.org/wiki/Conjugate_gradient_method#The_preconditioned_conjugate_gradient_method),
unlike the
[implementation in JAX](https://github.com/google/jax/blob/b5aea7bc2da4fb5ef96c87a59bfd1486d8958dd7/jax/_src/scipy/sparse/linalg.py#L111-L121).
Both versions support arbitrary pytrees as input:

```python
import functools
# ... (remaining imports and setup elided in this excerpt) ...

def cg(A, b, x0, M=lambda x: x, maxiter=5, tol=1e-5, atol=0.0):
  # ... (definitions of cond_fun, body_fun, and initial_value elided) ...
  x_final, *_ = lax.while_loop(cond_fun, body_fun, initial_value)
  return x_final
```

### Custom vector classes

You can also make your own classes directly support math like `Vector`. To do
so, either inherit from `tree_math.VectorMixin` in your pytree class, or use
`tree_math.struct` (similar to
[`flax.struct`](https://flax.readthedocs.io/en/latest/api_reference/flax.struct.html))
to create a [dataclass](https://docs.python.org/3/library/dataclasses.html)
that supports both pytree flattening and tree math:

```python
import jax
import tree_math

@tree_math.struct
class Point:
  x: float | jax.Array
  y: float | jax.Array

a = Point(0.0, 1.0)
b = Point(2.0, 3.0)

a + 3 * b  # Point(6.0, 10.0)
jax.grad(lambda x, y: x @ y)(a, b)  # Point(2.0, 3.0)
```