Skip to content

Commit 1bb23a7

Browse files
shoyertree-math authors
authored andcommitted
README updates for tree-math
PiperOrigin-RevId: 418696227
1 parent 86834ee commit 1bb23a7

File tree

2 files changed

+78
-17
lines changed

2 files changed

+78
-17
lines changed

README.md

Lines changed: 77 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,84 @@ tree-math makes it easy to implement numerical algorithms that work on
44
[JAX pytrees](https://jax.readthedocs.io/en/latest/pytrees.html), such as
55
iterative methods for optimization and equation solving. It does so by providing
66
a wrapper class `tree_math.Vector` that defines array operations such as
7-
infix arithmetic and dot-products on pytrees.
7+
infix arithmetic and dot-products on pytrees as if they were vectors.
88

9-
For example, here's how we could write the preconditioned conjugate gradient
9+
## Why tree-math
10+
11+
In a library like SciPy, numerical algorithms are typically written to handle
12+
fixed-rank arrays, e.g., [`scipy.integrate.solve_ivp`](https://docs.scipy.org/doc/scipy/reference/generated/scipy.integrate.solve_ivp.html)
13+
requires inputs of shape `(n,)`. This is convenient for implementors of
14+
numerical methods, but not for users, because 1d arrays are typically not the
15+
best way to keep track of state for non-trivial functions (e.g., neural networks
16+
or PDE solvers).
17+
18+
tree-math provides an alternative to flattening and unflattening these more
19+
complex data structures ("pytrees") for use in numerical algorithms. Instead,
20+
the numerical algorithm itself can be written in way to handle arbitrary
21+
collections of arrays stored in pytrees. This avoids unnecessary memory copies,
22+
and gives the user more control over the memory layouts used in computation.
23+
In practice, this can often makes a big difference for computational efficiency
24+
as well, which is why support for flexible data structures is so prevalent
25+
inside libraries that use JAX.
26+
27+
## Installation
28+
29+
tree-math is implemented in pure Python, and only depends upon JAX.
30+
31+
You can install it from PyPI: `pip install tree-math`.
32+
33+
## User guide
34+
35+
tree-math is simple to use. Just pass arbitrary pytree objects into
36+
`tree_math.Vector` to create an a object that arithmetic as if all leaves of
37+
the pytree were flattened and concatenated together:
38+
```
39+
>>> import tree_math as tm
40+
>>> import jax.numpy as jnp
41+
>>> v = tm.Vector({'x': 1, 'y': jnp.arange(2, 4)})
42+
>>> v
43+
tree_math.Vector({'x': 1, 'y': DeviceArray([2, 3], dtype=int32)})
44+
>>> v + 1
45+
tree_math.Vector({'x': 2, 'y': DeviceArray([3, 4], dtype=int32)})
46+
>>> v.sum()
47+
DeviceArray(6, dtype=int32)
48+
```
49+
50+
You can also find a few functions defined on vectors in `tree_math.numpy`, which
51+
implements a very restricted subset of `jax.numpy`. If you're interested in more
52+
functionality, please open an issue to discuss before sending a pull request.
53+
(In the long term, this separate module might disappear if we can support
54+
`Vector` objects directly inside `jax.numpy`.)
55+
56+
Vector objects are pytrees themselves, which means the are compatible with JAX
57+
transformations like `jit`, `vmap` and `grad`, and control flow like
58+
`while_loop` and `cond`.
59+
60+
When you're done manipulating vectors, you can pull out the underlying pytrees
61+
from the `.tree` property:
62+
```
63+
>>> v.tree
64+
{'x': 1, 'y': DeviceArray([2, 3], dtype=int32)}
65+
```
66+
67+
As an alternative to manipulating `Vector` objects directly, you can also use
68+
the functional transformations `wrap` and `unwrap` (see the "Example usage"
69+
below).
70+
71+
One important difference between `tree_math` and `jax.numpy` is that dot
72+
products in `tree_math` default to full precision on all platforms, rather
73+
than defaulting to bfloat16 precision on TPUs. This is useful for writing most
74+
numerical algorithms, and will likely be JAX's default behavior
75+
[in the future](https://github.com/google/jax/pull/7859).
76+
77+
In the near-term, we also plan to add a `Matrix` class that will make it
78+
possible to use tree-math for numerical algorithms such as
79+
[L-BFGS](https://en.wikipedia.org/wiki/Limited-memory_BFGS) which use matrices
80+
to represent stacks of vectors.
81+
82+
## Example usage
83+
84+
Here is how we could write the preconditioned conjugate gradient
1085
method. Notice how similar the implementation is to the [pseudocode from
1186
Wikipedia](https://en.wikipedia.org/wiki/Conjugate_gradient_method#The_preconditioned_conjugate_gradient_method),
1287
unlike the [implementation in JAX](https://github.com/google/jax/blob/b5aea7bc2da4fb5ef96c87a59bfd1486d8958dd7/jax/_src/scipy/sparse/linalg.py#L111-L121):
@@ -49,17 +124,3 @@ def cg(A, b, x0, M=lambda x: x, maxiter=5, tol=1e-5, atol=0.0):
49124
x_final, *_ = lax.while_loop(cond_fun, body_fun, initial_value)
50125
return x_final
51126
```
52-
53-
For most operations, we recommend working directly with `Vector` objects or the
54-
`wrap` and `unwrap` helper functions. You can also find a few functions defined
55-
on vectors in `tree_math.numpy`, which implements a very restricted subset
56-
of `jax.numpy`. (In the long term, this separate module might dissappear if we
57-
can support `Vector` objects directly inside `jax.numpy`).
58-
59-
One important different between `tree_math` and `jax.numpy` is that dot
60-
products in `tree_math` default to full precision on all platforms, rather
61-
than defaulting to bfloat16 precision on TPUs. This is useful for writing most
62-
numerical algorithms, and will likely be JAX's default behavior
63-
[in the future](https://github.com/google/jax/pull/7859).
64-
65-
TODO(shoyer): add a full tutorial!

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818

1919
base_requires = [
2020
'jax',
21-
'numpy',
2221
]
2322
tests_requires = [
2423
'absl-py',
@@ -29,6 +28,7 @@
2928

3029
setuptools.setup(
3130
name='tree-math',
31+
description='Mathematical operations for JAX pytrees',
3232
version='0.1.0 ',
3333
license='Apache 2.0',
3434
author='Google LLC',

0 commit comments

Comments
 (0)