-
I would like to implement the following loop:
where …. Here is a partial implementation in which I restrict …
I don't quite understand how I could extend it to add a function like
|
Beta Was this translation helpful? Give feedback.
Replies: 3 comments 10 replies
-
I don't think there is any way to do this directly. For example, you could write it like this:
import jax
import jax.numpy as jnp
def fun0(p, l, _):
    """Advance both positions by ``l`` times their momenta.

    ``p`` is unpacked as ``(x, px, y, py)`` and the result is
    ``[x + px*l, px, y + py*l, py]``. The third parameter is ignored. It
    exists so that this branch accepts the same operands as ``fun1`` when
    both are dispatched through ``jax.lax.switch``.
    """
    x_pos, x_mom, y_pos, y_mom = p
    new_x = x_pos + x_mom * l
    new_y = y_pos + y_mom * l
    return jnp.asarray([new_x, x_mom, new_y, y_mom])
def fun1(p, k1, k2):
    """Shift the momenta linearly in the positions and keep the positions unchanged.

    For ``p = (x, px, y, py)`` this returns ``[x, px - k1*x, y, py + k2*y]``.
    Note the opposite signs used for the two planes.
    """
    x_pos, x_mom, y_pos, y_mom = p
    kicked_x_mom = x_mom - k1 * x_pos
    kicked_y_mom = y_mom + k2 * y_pos
    return jnp.asarray([x_pos, kicked_x_mom, y_pos, kicked_y_mom])
def fswitch(p, fa):
    """Run one ``jax.lax.scan`` step by dispatching on the row ``fa``.

    ``fa[0]`` selects the branch (0 -> ``fun0``, 1 -> ``fun1``). It is cast to
    ``int`` because the rows of ``lst`` are stored in a float array. The
    remaining entries of ``fa`` are forwarded as positional arguments to the
    chosen branch.

    Returns ``(new_p, None)``: the new state is the scan carry, and no
    per-step output is collected.
    """
    branch_index = fa[0].astype(int)
    branch_args = fa[1:]
    new_p = jax.lax.switch(branch_index, [fun0, fun1], p, *branch_args)
    return new_p, None
@jax.jit
def loop(p, lst):
    """Apply each row of ``lst`` to ``p`` in order and return the final state.

    Each row is consumed by ``fswitch``. Only the final carry is returned;
    the stacked per-step outputs, which are ``None``, are discarded.
    """
    final_p, _ignored = jax.lax.scan(fswitch, p, lst)
    return final_p
# Forward-mode Jacobian of the final state w.r.t. the initial state ``p``.
dloop = jax.jacfwd(loop)
# Each row is (branch index, arg1, arg2). Branch 0 (fun0) uses arg1 as ``l``
# and ignores arg2. Branch 1 (fun1) uses arg1, arg2 as ``k1``, ``k2``.
# The array is float because of the mixed literals, which is why fswitch
# casts the index to int.
lst = jnp.asarray([(0, 1.2, 0.0), (1, 0.8, 1.0), (0, 1.2, 0.0), (1, -0.7, 1.0)])
p0 = jnp.asarray([0.3, 0.1, 0.2, 0.4])
p1 = loop(p0, lst)
jac = dloop(p0, lst)
print(jac)
Beta Was this translation helpful? Give feedback.
-
Thanks! For the realistic use case, I need to be flexible with the arguments. Does the list passed to … I tried an alternative, but the tracer complains on …
|
Beta Was this translation helpful? Give feedback.
-
Here is the solution:
import jax
import jax.numpy as jnp
import numpy as np
def fun0(p, offset, data):
    """Advance both positions by ``l = data[offset]`` times their momenta.

    ``p`` is unpacked as ``(x, px, y, py)`` and the result is
    ``[x + px*l, px, y + py*l, py]``. This branch reads exactly one entry of
    the shared parameter vector ``data``.
    """
    step = data[offset]
    x_pos, x_mom, y_pos, y_mom = p
    return jnp.asarray([x_pos + x_mom * step, x_mom, y_pos + y_mom * step, y_mom])
def fun1(p, offset, data):
    """Shift the momenta linearly in the positions using two entries of ``data``.

    Reads ``k1 = data[offset]`` and ``k2 = data[offset + 1]``. For
    ``p = (x, px, y, py)`` it returns ``[x, px - k1*x, y, py + k2*y]``.
    """
    strength_x, strength_y = data[offset], data[offset + 1]
    x_pos, x_mom, y_pos, y_mom = p
    return jnp.asarray(
        [x_pos, x_mom - strength_x * x_pos, y_pos, y_mom + strength_y * y_pos]
    )
def fswitch(pdata, fa):
    """Run one ``jax.lax.scan`` step that keeps ``data`` in the carry.

    ``pdata`` is ``(p, data)``, and ``data`` is passed through unchanged.
    ``fa`` is one row of ``lst``. ``fa[0]`` picks the branch (0 -> ``fun0``,
    1 -> ``fun1``) and ``fa[1]`` is the offset into ``data`` where that
    branch reads its parameters.

    Returns ``((new_p, data), None)``, so no per-step output is collected.
    """
    state, params = pdata
    branch_index, param_offset = fa[0], fa[1]
    next_state = jax.lax.switch(
        branch_index, [fun0, fun1], state, param_offset, params
    )
    return (next_state, params), None
@jax.jit
def loop(p, lst, data):
    """Apply each ``(branch, offset)`` row of ``lst`` to ``p`` in order.

    ``data`` holds every element's parameters and travels in the scan carry.
    Only the final state is returned. The carried ``data`` and the ``None``
    per-step outputs are discarded.
    """
    (final_p, _carried_data), _ignored = jax.lax.scan(fswitch, (p, data), lst)
    return final_p
# Forward-mode Jacobian of the final state w.r.t. the initial state ``p``.
dloop = jax.jacfwd(loop)
p0 = jnp.asarray([0.3, 0.1, 0.2, 0.4])


def _transfer_matrices(l, k1, k2):
    """Return the hand-derived Jacobians (a, b, c) used to check ``dloop``.

    ``a`` is the Jacobian of ``fun0`` with parameter ``l``. ``b`` and ``c`` are
    Jacobians of ``fun1`` when both of its parameters equal ``k1`` or ``k2``
    respectively, which matches how ``data`` is filled below.
    """
    a = np.array([[1, l, 0, 0], [0, 1, 0, 0], [0, 0, 1, l], [0, 0, 0, 1]])
    b = np.array([[1, 0, 0, 0], [-k1, 1, 0, 0], [0, 0, 1, 0], [0, 0, k1, 1]])
    c = np.array([[1, 0, 0, 0], [-k2, 1, 0, 0], [0, 0, 1, 0], [0, 0, k2, 1]])
    return a, b, c


# First test: parameters are packed contiguously in ``data``.
# Rows of ``lst`` are (branch index, offset into data):
#   (0, 0) -> fun0 reads l = data[0]
#   (1, 1) -> fun1 reads data[1], data[2]
#   (1, 3) -> fun1 reads data[3], data[4]
lst = jnp.asarray([(0, 0), (1, 1), (0, 0), (1, 3)])
l, k1, k2 = 1.1, 0.1, -0.4
data = jnp.asarray([l, k1, k1, k2, k2])
jac = dloop(p0, lst, data)
# Rows are applied first-to-last, so by the chain rule the expected Jacobian
# is the product in reverse order. The difference should print as ~0.
a, b, c = _transfer_matrices(l, k1, k2)
print(c @ a @ b @ a - jac)

# Second test: same sequence, but the last element's parameters start at
# offset 5. data[3:5] is unused padding, which shows offsets can be arbitrary.
lst = jnp.asarray([(0, 0), (1, 1), (0, 0), (1, 5)])
l, k1, k2 = 1.2, 0.1, -0.5
data = jnp.asarray([l, k1, k1, 0, 0, k2, k2])
jac = dloop(p0, lst, data)
a, b, c = _transfer_matrices(l, k1, k2)
print(c @ a @ b @ a - jac)
Beta Was this translation helpful? Give feedback.
Here is the solution: