Skip to content
Discussion options

You must be logged in to vote

Here is the solution:

import jax
import jax.numpy as jnp
import numpy as np

def fun0(p, offset, data):
    """Advance ``x`` and ``y`` by their momenta scaled with ``data[offset]``.

    Args:
        p: Four-component sequence unpacked as ``(x, px, y, py)``.
        offset: Index into ``data`` of the scale factor used by this step.
        data: Indexable container of parameters.

    Returns:
        A ``jnp`` array ``[x + px * s, px, y + py * s, py]`` where
        ``s = data[offset]``; ``px`` and ``py`` pass through unchanged.
    """
    # NOTE(review): presumably a length (e.g. a drift) -- confirm with author.
    length = data[offset]
    x, px, y, py = p
    shifted_x = x + px * length
    shifted_y = y + py * length
    return jnp.asarray([shifted_x, px, shifted_y, py])

def fun1(p, offset, data):
    """Update ``px`` and ``py`` linearly in ``x`` and ``y``.

    Args:
        p: Four-component sequence unpacked as ``(x, px, y, py)``.
        offset: Index into ``data`` of the first of two consecutive
            coefficients used by this step.
        data: Indexable container of parameters.

    Returns:
        A ``jnp`` array ``[x, px - k1 * x, y, py + k2 * y]`` where
        ``k1 = data[offset]`` and ``k2 = data[offset + 1]``; ``x`` and ``y``
        pass through unchanged.
    """
    # `+ 0` is kept so both index expressions are built the same way.
    k1, k2 = data[offset + 0], data[offset + 1]
    x, px, y, py = p
    kicked_px = px - k1 * x
    kicked_py = py + k2 * y
    return jnp.asarray([x, kicked_px, y, kicked_py])

def fswitch(pdata, fa):
    """Apply one step, chosen by index, to the carried state.

    Has the ``(carry, x) -> (carry, y)`` shape of a ``jax.lax.scan`` body.

    Args:
        pdata: Carry tuple ``(p, data)``; ``data`` is passed through untouched.
        fa: Pair ``(fun, offset)``; ``fun`` is the branch index handed to
            ``jax.lax.switch`` over ``[fun0, fun1]``, and ``offset`` is
            forwarded to the chosen branch.

    Returns:
        ``((new_p, data), None)`` -- the updated carry and no per-step output.
    """
    state, params = pdata
    branch_index, param_offset = fa
    branches = [fun0, fun1]
    new_state = jax.lax.switch(
        branch_index, branches, state, param_offset, params
    )
    return (new_state, params), None

@jax.jit
def loop(p, lst, data):
    """Run every step in ``lst`` over ``p`` and return the final ``p``.

    Args:
        p: Initial state, used as the first element of the scan carry.
        lst: Sequence scanned over; each element is handed to ``fswitch``
            as its ``fa`` argument.
        data: Parameters carried unchanged through the scan.

    Returns:
        The first element of the final carry produced by ``jax.lax.scan``.
    """
    final_carry, _ = jax.lax.scan(fswitch, (p, data), lst)
    final_p, _ = final_carry
    return final_p

# Forward-mode Jacobian of `loop`, built with jax.jacfwd and its default
# settings (no explicit argnums is given here).
dloop = jax.jacfwd(loop)

# Four-component starting vector, in the (x, px, y, py) order that fun0 and
# fun1 unpack.
# NOTE(review): presumably passed as `p` to loop/dloop -- the usage is cut off
# in this view.
p0 = jnp.asarray([0.3, 0.1, 0.2, 0.4])

#first test
lst = jnp.asarray([(0, 0), (1, 1), (0, 0

Replies: 3 comments 10 replies

Comment options

You must be logged in to vote
0 replies
Comment options

You must be logged in to vote
9 replies
@rdemaria
Comment options

@jakevdp
Comment options

@rdemaria
Comment options

@YouJiacheng
Comment options

@rdemaria
Comment options

Comment options

You must be logged in to vote
1 reply
@YouJiacheng
Comment options

Answer selected by rdemaria
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Category
Q&A
Labels
None yet
3 participants