@@ -4,9 +4,84 @@ tree-math makes it easy to implement numerical algorithms that work on
4
4
[ JAX pytrees] ( https://jax.readthedocs.io/en/latest/pytrees.html ) , such as
5
5
iterative methods for optimization and equation solving. It does so by providing
6
6
a wrapper class ` tree_math.Vector ` that defines array operations such as
7
- infix arithmetic and dot-products on pytrees.
7
+ infix arithmetic and dot-products on pytrees as if they were vectors .
8
8
9
- For example, here's how we could write the preconditioned conjugate gradient
9
+ ## Why tree-math
10
+
11
+ In a library like SciPy, numerical algorithms are typically written to handle
12
+ fixed-rank arrays, e.g., [ ` scipy.integrate.solve_ivp ` ] ( https://docs.scipy.org/doc/scipy/reference/generated/scipy.integrate.solve_ivp.html )
13
+ requires inputs of shape ` (n,) ` . This is convenient for implementors of
14
+ numerical methods, but not for users, because 1d arrays are typically not the
15
+ best way to keep track of state for non-trivial functions (e.g., neural networks
16
+ or PDE solvers).
17
+
18
+ tree-math provides an alternative to flattening and unflattening these more
19
+ complex data structures ("pytrees") for use in numerical algorithms. Instead,
20
+ the numerical algorithm itself can be written in way to handle arbitrary
21
+ collections of arrays stored in pytrees. This avoids unnecessary memory copies,
22
+ and gives the user more control over the memory layouts used in computation.
23
+ In practice, this can often makes a big difference for computational efficiency
24
+ as well, which is why support for flexible data structures is so prevalent
25
+ inside libraries that use JAX.
26
+
27
+ ## Installation
28
+
29
+ tree-math is implemented in pure Python, and only depends upon JAX.
30
+
31
+ You can install it from PyPI: ` pip install tree-math ` .
32
+
33
+ ## User guide
34
+
35
+ tree-math is simple to use. Just pass arbitrary pytree objects into
36
+ ` tree_math.Vector ` to create an a object that arithmetic as if all leaves of
37
+ the pytree were flattened and concatenated together:
38
+ ```
39
+ >>> import tree_math as tm
40
+ >>> import jax.numpy as jnp
41
+ >>> v = tm.Vector({'x': 1, 'y': jnp.arange(2, 4)})
42
+ >>> v
43
+ tree_math.Vector({'x': 1, 'y': DeviceArray([2, 3], dtype=int32)})
44
+ >>> v + 1
45
+ tree_math.Vector({'x': 2, 'y': DeviceArray([3, 4], dtype=int32)})
46
+ >>> v.sum()
47
+ DeviceArray(6, dtype=int32)
48
+ ```
49
+
50
+ You can also find a few functions defined on vectors in ` tree_math.numpy ` , which
51
+ implements a very restricted subset of ` jax.numpy ` . If you're interested in more
52
+ functionality, please open an issue to discuss before sending a pull request.
53
+ (In the long term, this separate module might disappear if we can support
54
+ ` Vector ` objects directly inside ` jax.numpy ` .)
55
+
56
+ Vector objects are pytrees themselves, which means the are compatible with JAX
57
+ transformations like ` jit ` , ` vmap ` and ` grad ` , and control flow like
58
+ ` while_loop ` and ` cond ` .
59
+
60
+ When you're done manipulating vectors, you can pull out the underlying pytrees
61
+ from the ` .tree ` property:
62
+ ```
63
+ >>> v.tree
64
+ {'x': 1, 'y': DeviceArray([2, 3], dtype=int32)}
65
+ ```
66
+
67
+ As an alternative to manipulating ` Vector ` objects directly, you can also use
68
+ the functional transformations ` wrap ` and ` unwrap ` (see the "Example usage"
69
+ below).
70
+
71
+ One important difference between ` tree_math ` and ` jax.numpy ` is that dot
72
+ products in ` tree_math ` default to full precision on all platforms, rather
73
+ than defaulting to bfloat16 precision on TPUs. This is useful for writing most
74
+ numerical algorithms, and will likely be JAX's default behavior
75
+ [ in the future] ( https://github.com/google/jax/pull/7859 ) .
76
+
77
+ In the near-term, we also plan to add a ` Matrix ` class that will make it
78
+ possible to use tree-math for numerical algorithms such as
79
+ [ L-BFGS] ( https://en.wikipedia.org/wiki/Limited-memory_BFGS ) which use matrices
80
+ to represent stacks of vectors.
81
+
82
+ ## Example usage
83
+
84
+ Here is how we could write the preconditioned conjugate gradient
10
85
method. Notice how similar the implementation is to the [ pseudocode from
11
86
Wikipedia] ( https://en.wikipedia.org/wiki/Conjugate_gradient_method#The_preconditioned_conjugate_gradient_method ) ,
12
87
unlike the [ implementation in JAX] ( https://github.com/google/jax/blob/b5aea7bc2da4fb5ef96c87a59bfd1486d8958dd7/jax/_src/scipy/sparse/linalg.py#L111-L121 ) :
@@ -49,17 +124,3 @@ def cg(A, b, x0, M=lambda x: x, maxiter=5, tol=1e-5, atol=0.0):
49
124
x_final, * _ = lax.while_loop(cond_fun, body_fun, initial_value)
50
125
return x_final
51
126
```
52
-
53
- For most operations, we recommend working directly with ` Vector ` objects or the
54
- ` wrap ` and ` unwrap ` helper functions. You can also find a few functions defined
55
- on vectors in ` tree_math.numpy ` , which implements a very restricted subset
56
- of ` jax.numpy ` . (In the long term, this separate module might dissappear if we
57
- can support ` Vector ` objects directly inside ` jax.numpy ` ).
58
-
59
- One important different between ` tree_math ` and ` jax.numpy ` is that dot
60
- products in ` tree_math ` default to full precision on all platforms, rather
61
- than defaulting to bfloat16 precision on TPUs. This is useful for writing most
62
- numerical algorithms, and will likely be JAX's default behavior
63
- [ in the future] ( https://github.com/google/jax/pull/7859 ) .
64
-
65
- TODO(shoyer): add a full tutorial!
0 commit comments