@@ -1,5 +1,3 @@
-import functools as ft
-
 import jax
 import numpy as np
 from numpy.testing import assert_array_almost_equal
@@ -11,21 +9,26 @@
 jax.config.update("jax_enable_x64", True)
 
 
-def jax_newton_solver(f, z_init):
-    def f_root(z):
-        return f(z) - z
+def jax_newton_solver(f, x0):
+    def f_root(x):
+        return f(x) - x
+
+    def g(x):
+        return x - jax.numpy.linalg.solve(jax.jacobian(f_root)(x), f_root(x))
+
+    return jax_fwd_solver(g, x0)
 
-    def g(z):
-        return z - jax.numpy.linalg.solve(jax.jacobian(f_root)(z), f_root(z))
 
-    return jax_fwd_solver(g, z_init)
+def jax_fwd_solver(f, x0, tol=1e-5):
+    x_prev, x = x0, f(x0)
+    while jax.numpy.linalg.norm(x_prev - x) > tol:
+        x_prev, x = x, f(x)
+    return x
 
 
-def jax_fwd_solver(f, z_init, tol=1e-5):
-    z_prev, z = z_init, f(z_init)
-    while jax.numpy.linalg.norm(z_prev - z) > tol:
-        z_prev, z = z, f(z)
-    return z
+def jax_fixed_point_solver(solver, f, params, x0, **solver_kwargs):
+    x_star = solver(lambda x: f(x, *params), x0=x0, **solver_kwargs)
+    return x_star
 
 
 def test_fixed_point_forward():
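
An editor's illustration of the new calling convention (not part of the commit): `jax_fixed_point_solver` now takes the parameters as a tuple and forwards them to `f`, rather than binding them up front with `functools.partial`.

# hypothetical usage sketch: Babylonian square-root iteration, fixed point at sqrt(a)
import jax

def damped(x, a):
    return 0.5 * (x + a / x)

x_star = jax_fixed_point_solver(
    jax_fwd_solver,
    damped,
    (2.0,),  # params tuple, here just `a`
    x0=jax.numpy.ones(1),
)
# x_star ≈ [1.41421356]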
@@ -34,7 +37,7 @@ def test_fixed_point_forward():
     def g(x, W, b):
         return pt.tanh(pt.dot(W, x) + b)
 
-    def _jax_g(x, W, b):
+    def jax_g(x, W, b):
         return jax.numpy.tanh(jax.numpy.dot(W, x) + b)
 
     ndim = 10
@@ -43,9 +46,13 @@ def _jax_g(x, W, b):
 
     W, b = np.asarray(W), np.asarray(b)
 
-    jax_g = ft.partial(_jax_g, W=W, b=b)
+    jax_solution = jax_fixed_point_solver(
+        jax_fwd_solver,
+        jax_g,
+        (W, b),
+        x0=jax.numpy.zeros_like(b),
+    )
 
-    jax_solution = jax_fwd_solver(jax_g, jax.numpy.zeros_like(b))
     pytensor_solution, _ = fixed_point_solver(
         g,
         fwd_solver,
@@ -60,7 +67,7 @@ def test_fixed_point_newton():
     def g(x, W, b):
         return pt.tanh(pt.dot(W, x) + b)
 
-    def _jax_g(x, W, b):
+    def jax_g(x, W, b):
         return jax.numpy.tanh(jax.numpy.dot(W, x) + b)
 
     ndim = 10
@@ -69,9 +76,13 @@ def _jax_g(x, W, b):
 
     W, b = np.asarray(W), np.asarray(b)
 
-    jax_g = ft.partial(_jax_g, W=W, b=b)
+    jax_solution = jax_fixed_point_solver(
+        jax_newton_solver,
+        jax_g,
+        (W, b),
+        x0=jax.numpy.zeros_like(b),
+    )
 
-    jax_solution = jax_newton_solver(jax_g, jax.numpy.zeros_like(b))
     pytensor_solution, _ = fixed_point_solver(
         g,
         newton_solver,
@@ -86,3 +97,49 @@ def _jax_g(x, W, b):
 # and adjoint implicit function theorem rewritten grad
 # see the [notes](https://theorashid.github.io/notes/fixed-point-iteration)
 # and the [Deep Implicit Layers workshop](https://implicit-layers-tutorial.org/implicit_functions/)
+
+# %%
+# import jax
+# import numpy as np
+
+# def grad_test_fixed_point_forward():
+#     def jax_g(x, W, b):
+#         return jax.numpy.tanh(jax.numpy.dot(W, x) + b)
+
+#     ndim = 10
+#     W = jax.random.normal(jax.random.PRNGKey(0), (ndim, ndim)) / jax.numpy.sqrt(ndim)
+#     b = jax.random.normal(jax.random.PRNGKey(1), (ndim,))
+
+#     W, b = np.asarray(W), np.asarray(b)  # params
+
+#     # gradient of the sum of the outputs with respect to the parameter matrix
+#     jax_grad = jax.grad(
+#         lambda W: jax_fixed_point_solver(
+#             jax_fwd_solver,
+#             jax_g,
+#             (W, b),  # wrt W
+#             x0=jax.numpy.zeros_like(b),
+#         ).sum()
+#     )(W)
+#     print(jax_grad[0])

+# grad_test_fixed_point_forward()
+
+# # params -> W
+# # z -> x
+# # x -> b
+# # f = lambda W, b, x: jnp.tanh(jnp.dot(W, x) + b)
+# # x_star = solver(lambda x: f(params, b, x), x_init=jnp.zeros_like(b))
+# # x_star = fixed_point_layer(fwd_solver, f, W, b)
+# # g = jax.grad(lambda W: fixed_point_layer(fwd_solver, f, W, b).sum())(W)
+# %%
+# def implicit_gradients_vjp(solver, f, res, x_soln):
+#     params, x, x_star = res
+#     # find the adjoint u^T via the solver:
+#     # u^T = w^T + u^T \partial_{x_star} f(x_star, params)
+#     _, vjp_x = jax.vjp(lambda x: f(x, *params), x_star)  # diff wrt x
+#     _, vjp_par = jax.vjp(lambda *params: f(x_star, *params), *params)  # diff wrt params
+#     u = solver(lambda u: vjp_x(u)[0] + x_soln, x0=jax.numpy.zeros_like(x_soln))
+
+#     # then compute the vjp u^T \partial_{params} f(x_star, params)
+#     return vjp_par(u)
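
To make the trailing sketch concrete, here is one way the commented-out pieces could be wired together with `jax.custom_vjp`, following the Deep Implicit Layers tutorial linked above. This is an editor's sketch against the signatures used in this file (`f(x, *params)`, solvers taking `x0=`), not code from the commit; `fixed_point_layer` is the hypothetical name borrowed from the comments.

import functools

import jax


# sketch: solver and f are static (non-differentiable) arguments
@functools.partial(jax.custom_vjp, nondiff_argnums=(0, 1))
def fixed_point_layer(solver, f, params, x0):
    return solver(lambda x: f(x, *params), x0=x0)


def fixed_point_layer_fwd(solver, f, params, x0):
    x_star = fixed_point_layer(solver, f, params, x0)
    return x_star, (params, x_star)  # residuals for the backward pass


def fixed_point_layer_bwd(solver, f, res, x_star_bar):
    params, x_star = res
    _, vjp_par = jax.vjp(lambda *p: f(x_star, *p), *params)
    _, vjp_x = jax.vjp(lambda x: f(x, *params), x_star)
    # adjoint fixed point u^T = w^T + u^T \partial_{x_star} f, reusing the same solver
    u = solver(lambda u: vjp_x(u)[0] + x_star_bar, x0=jax.numpy.zeros_like(x_star))
    # cotangents for (params, x0); the fixed point does not depend on x0
    return vjp_par(u), jax.numpy.zeros_like(x_star)


fixed_point_layer.defvjp(fixed_point_layer_fwd, fixed_point_layer_bwd)

# e.g. gradient of the summed fixed point with respect to W:
# jax.grad(lambda W: fixed_point_layer(jax_fwd_solver, jax_g, (W, b), jax.numpy.zeros_like(b)).sum())(W)

As written this relies on Python-level loops in the solvers, so it works under `jax.grad` in eager mode but would need `jax.lax.while_loop` inside the solvers to be `jit`-compatible.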