
Commit 57aaf8c

add notes so hackers can see
1 parent a49ade3 commit 57aaf8c

File tree: 2 files changed, +79 -21 lines


pytensor/optimise/fixed_point.py

Lines changed: 3 additions & 2 deletions
@@ -1,3 +1,4 @@
+from collections.abc import Callable
 from functools import partial
 
 import pytensor
@@ -37,8 +38,8 @@ def newton_solver(x_prev, *args, func, tol):
 
 
 def fixed_point_solver(
-    f: callable,
-    solver: callable,
+    f: Callable,
+    solver: Callable,
     x0: pt.TensorVariable,
     *args: tuple[pt.Variable, ...],
     max_iter: int = 1000,
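
To show how the new Callable annotations are exercised, here is a minimal, hypothetical call site modelled on the tests below. The import of fwd_solver from this module, the name of the second return value, and passing W and b positionally through *args are assumptions, not part of this commit.

# hypothetical call site, mirroring the tests below (not part of this commit)
import pytensor.tensor as pt

from pytensor.optimise.fixed_point import fixed_point_solver, fwd_solver  # fwd_solver export assumed

W = pt.matrix("W")
b = pt.vector("b")


def g(x, W, b):
    # the fixed point x* satisfies x* = tanh(W x* + b)
    return pt.tanh(pt.dot(W, x) + b)


# f and solver are plain Callables; W, b are forwarded to f via *args
x_star, updates = fixed_point_solver(
    g,                 # f: Callable
    fwd_solver,        # solver: Callable
    pt.zeros_like(b),  # x0: pt.TensorVariable
    W, b,              # *args: tuple[pt.Variable, ...]
    max_iter=1000,
)

The switch from the builtin callable to collections.abc.Callable makes the annotation a real type: callable is a predicate function, so type checkers reject it when used as an annotation.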

tests/optimise/test_fixed_point.py

Lines changed: 76 additions & 19 deletions
@@ -1,5 +1,3 @@
-import functools as ft
-
 import jax
 import numpy as np
 from numpy.testing import assert_array_almost_equal
@@ -11,21 +9,26 @@
 jax.config.update("jax_enable_x64", True)
 
 
-def jax_newton_solver(f, z_init):
-    def f_root(z):
-        return f(z) - z
+def jax_newton_solver(f, x0):
+    def f_root(x):
+        return f(x) - x
+
+    def g(x):
+        return x - jax.numpy.linalg.solve(jax.jacobian(f_root)(x), f_root(x))
+
+    return jax_fwd_solver(g, x0)
 
-    def g(z):
-        return z - jax.numpy.linalg.solve(jax.jacobian(f_root)(z), f_root(z))
 
-    return jax_fwd_solver(g, z_init)
+def jax_fwd_solver(f, x0, tol=1e-5):
+    x_prev, x = x0, f(x0)
+    while jax.numpy.linalg.norm(x_prev - x) > tol:
+        x_prev, x = x, f(x)
+    return x
 
 
-def jax_fwd_solver(f, z_init, tol=1e-5):
-    z_prev, z = z_init, f(z_init)
-    while jax.numpy.linalg.norm(z_prev - z) > tol:
-        z_prev, z = z, f(z)
-    return z
+def jax_fixed_point_solver(solver, f, params, x0, **solver_kwargs):
+    x_star = solver(lambda x: f(x, *params), x0=x0, **solver_kwargs)
+    return x_star
 
 
 def test_fixed_point_forward():
@@ -34,7 +37,7 @@ def test_fixed_point_forward():
     def g(x, W, b):
         return pt.tanh(pt.dot(W, x) + b)
 
-    def _jax_g(x, W, b):
+    def jax_g(x, W, b):
         return jax.numpy.tanh(jax.numpy.dot(W, x) + b)
 
     ndim = 10
@@ -43,9 +46,13 @@ def _jax_g(x, W, b):
 
     W, b = np.asarray(W), np.asarray(b)
 
-    jax_g = ft.partial(_jax_g, W=W, b=b)
+    jax_solution = jax_fixed_point_solver(
+        jax_fwd_solver,
+        jax_g,
+        (W, b),
+        x0=jax.numpy.zeros_like(b),
+    )
 
-    jax_solution = jax_fwd_solver(jax_g, jax.numpy.zeros_like(b))
     pytensor_solution, _ = fixed_point_solver(
         g,
         fwd_solver,
@@ -60,7 +67,7 @@ def test_fixed_point_newton():
     def g(x, W, b):
         return pt.tanh(pt.dot(W, x) + b)
 
-    def _jax_g(x, W, b):
+    def jax_g(x, W, b):
         return jax.numpy.tanh(jax.numpy.dot(W, x) + b)
 
     ndim = 10
@@ -69,9 +76,13 @@ def _jax_g(x, W, b):
 
     W, b = np.asarray(W), np.asarray(b)
 
-    jax_g = ft.partial(_jax_g, W=W, b=b)
+    jax_solution = jax_fixed_point_solver(
+        jax_newton_solver,
+        jax_g,
+        (W, b),
+        x0=jax.numpy.zeros_like(b),
+    )
 
-    jax_solution = jax_newton_solver(jax_g, jax.numpy.zeros_like(b))
     pytensor_solution, _ = fixed_point_solver(
         g,
         newton_solver,
@@ -86,3 +97,49 @@ def _jax_g(x, W, b):
 # and adjoint implicit function theorem rewritten grad
 # see the [notes](https://theorashid.github.io/notes/fixed-point-iteration
 # and the [Deep Implicit Layers workshop](https://implicit-layers-tutorial.org/implicit_functions/)
+
+# %%
+# import jax
+# import numpy as np
+
+# def grad_test_fixed_point_forward():
+# def jax_g(x, W, b):
+# return jax.numpy.tanh(jax.numpy.dot(W, x) + b)
+
+# ndim = 10
+# W = jax.random.normal(jax.random.PRNGKey(0), (ndim, ndim)) / jax.numpy.sqrt(ndim)
+# b = jax.random.normal(jax.random.PRNGKey(1), (ndim,))
+
+# W, b = np.asarray(W), np.asarray(b) # params
+
+# # gradient of the sum of the outputs with respect to the parameter matrix
+# jax_grad = jax.grad(
+# lambda W: jax_fixed_point_solver(
+# jax_fwd_solver,
+# jax_g,
+# (W, b), # wrt W
+# x0=jax.numpy.zeros_like(b),
+# ).sum()
+# )(W)
+# print(jax_grad[0])
+
+# grad_test_fixed_point_forward()
+
+# # params -> W
+# # z -> x
+# # x -> b
+# # f = lambda W, b, x: jnp.tanh(jnp.dot(W, x) + b)
+# # x_star = solver(lambda x: f(params, b, x), x_init=jnp.zeros_like(b))
+# # x_star = fixed_point_layer(fwd_solver, f, W, b)
+# # g = jax.grad(lambda W: fixed_point_layer(fwd_solver, f, W, b).sum())(W)
+# %%
+# def implicit_gradients_vjp(solver, f, res, x_soln):
+# params, x, x_star = res
+# # find adjoint u^T via solver
+# # u^T = w^T + u^T \delta_{x_star} f(x_star, params)
+# _, vjp_x = jax.vjp(lambda : f(x, *params), x_star) # diff wrt x
+# _, vjp_par = jax.vjp(lambda params: f(x, *params), *params) # diff wrt params
+# u = solver(lambda u: vjp_x(u)[0] + x_soln, x0=jax.numpy.zeros_like(x_soln))
+
+# # then compute vjp u^T \delta_{params} f(x_star, params)
+# return vjp_par(u)
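
The commented-out notes above sketch the adjoint trick; below is a runnable interpretation of them as a hedged example, wired into jax.custom_vjp in the style of the Deep Implicit Layers workshop linked above. The names fixed_point_layer and fwd_solver, the (params, x0) calling convention, and the zero cotangent returned for x0 are assumptions layered on the notes rather than part of this commit, and the missing lambda argument in the notes' jax.vjp call is filled in here.

# hedged sketch (not part of this commit): implicit-function-theorem gradients
# via jax.custom_vjp, following the Deep Implicit Layers workshop pattern
from functools import partial

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def fwd_solver(f, x0, tol=1e-5):
    # naive forward iteration, same shape as jax_fwd_solver above
    x_prev, x = x0, f(x0)
    while jnp.linalg.norm(x_prev - x) > tol:
        x_prev, x = x, f(x)
    return x


@partial(jax.custom_vjp, nondiff_argnums=(0, 1))
def fixed_point_layer(solver, f, params, x0):
    # forward pass: run the solver to the fixed point x* = f(x*, *params)
    return solver(lambda x: f(x, *params), x0=x0)


def fixed_point_layer_fwd(solver, f, params, x0):
    x_star = fixed_point_layer(solver, f, params, x0)
    return x_star, (params, x_star)


def fixed_point_layer_bwd(solver, f, res, x_star_bar):
    params, x_star = res
    # vjps of f at the solution, wrt the state x and wrt the parameters
    _, vjp_x = jax.vjp(lambda x: f(x, *params), x_star)
    _, vjp_par = jax.vjp(lambda params: f(x_star, *params), params)
    # adjoint fixed point u = (d_x f)^T u + x_star_bar, reusing the same solver
    u = solver(lambda u: vjp_x(u)[0] + x_star_bar, x0=jnp.zeros_like(x_star))
    # pull u back through the parameters; x0 only sets the starting point,
    # so it gets a zero cotangent (x0 and x_star share a shape here)
    return vjp_par(u)[0], jnp.zeros_like(x_star)


fixed_point_layer.defvjp(fixed_point_layer_fwd, fixed_point_layer_bwd)


def g(x, W, b):
    # same contraction as the tests: x* = tanh(W x* + b)
    return jnp.tanh(jnp.dot(W, x) + b)


ndim = 10
W = jax.random.normal(jax.random.PRNGKey(0), (ndim, ndim)) / jnp.sqrt(ndim)
b = jax.random.normal(jax.random.PRNGKey(1), (ndim,))

# gradient of the summed fixed point wrt W, as in the commented-out notes
W_grad = jax.grad(
    lambda W: fixed_point_layer(fwd_solver, g, (W, b), jnp.zeros_like(b)).sum()
)(W)
print(W_grad[0])

The backward pass solves the adjoint fixed point u = (d_x f)^T u + x_star_bar with the same solver used in the forward pass, which is the u^T = w^T + u^T \delta_{x_star} f(x_star, params) relation written in the notes.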
