Skip to content

Commit 86834ee

Browse files
tree-math authorsshoyer
authored andcommitted
Internal change
PiperOrigin-RevId: 418683539
1 parent 82ebc73 commit 86834ee

File tree

5 files changed

+13
-18
lines changed

5 files changed

+13
-18
lines changed

README.md

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,23 +14,20 @@ unlike the [implementation in JAX](https://github.com/google/jax/blob/b5aea7bc2d
1414
```python
1515
import functools
1616
from jax import lax
17-
import tree_math
17+
import tree_math as tm
1818
import tree_math.numpy as tnp
1919

20-
@functools.partial(tree_math.wrap, vector_argnames=['b', 'x0'])
20+
@functools.partial(tm.wrap, vector_argnames=['b', 'x0'])
2121
def cg(A, b, x0, M=lambda x: x, maxiter=5, tol=1e-5, atol=0.0):
2222
"""jax.scipy.sparse.linalg.cg, written with tree_math."""
23-
A = tree_math.unwrap(A)
24-
M = tree_math.unwrap(M)
23+
A = tm.unwrap(A)
24+
M = tm.unwrap(M)
2525

26-
# tolerance handling uses the "non-legacy" behavior of scipy.sparse.linalg.cg
27-
bs = b @ b
28-
atol2 = tnp.maximum(tol**2 * bs, atol**2)
26+
atol2 = tnp.maximum(tol**2 * (b @ b), atol**2)
2927

3028
def cond_fun(value):
3129
x, r, gamma, p, k = value
32-
rs = r @ r
33-
return (rs > atol2) & (k < maxiter)
30+
return (r @ r > atol2) & (k < maxiter)
3431

3532
def body_fun(value):
3633
x, r, gamma, p, k = value

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
]
2323
tests_requires = [
2424
'absl-py',
25+
'jaxlib',
2526
'numpy>=1.17',
2627
'pytest',
2728
]

tree_math/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,5 @@
2222
)
2323
from tree_math._src.vector import Vector
2424
import tree_math.numpy
25+
26+
__version__ = '0.1.0'

tree_math/integration_test.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -53,20 +53,15 @@ def test_cg(self):
5353

5454
@functools.partial(tm.wrap, vector_argnames=["b", "x0"])
5555
def cg(A, b, x0, M=lambda x: x, maxiter=5, tol=1e-5, atol=0.0):
56+
"""jax.scipy.sparse.linalg.cg, written with tree_math."""
5657
A = tm.unwrap(A)
5758
M = tm.unwrap(M)
5859

59-
# tolerance handling uses the "non-legacy" behavior of
60-
# scipy.sparse.linalg.cg
61-
bs = b @ b
62-
atol2 = tnp.maximum(tol**2 * bs, atol**2)
63-
64-
# https://en.wikipedia.org/wiki/Conjugate_gradient_method#The_preconditioned_conjugate_gradient_method
60+
atol2 = tnp.maximum(tol**2 * (b @ b), atol**2)
6561

6662
def cond_fun(value):
67-
_, r, _, _, k = value
68-
rs = r @ r
69-
return (rs > atol2) & (k < maxiter)
63+
x, r, gamma, p, k = value # pylint: disable=unused-variable
64+
return (r @ r > atol2) & (k < maxiter)
7065

7166
def body_fun(value):
7267
x, r, gamma, p, k = value
File renamed without changes.

0 commit comments

Comments
 (0)