Hi, after digging a little I found the problem in my JAX code: it comes from the way I was passing
action_fn and force_fn, as

@partial(jax.jit, static_argnums=(2, 3, 4))
def hmc_one_step(key, phi, action_fn, force_fn, n_steps=10, dt=0.1):
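The post does not say exactly why the static arguments were slow. One well-known pitfall with callables in static_argnums is that JAX caches compiled code keyed on the static values, and a function is compared by identity, so every new function object (a fresh lambda or functools.partial) forces a new trace and compile. The minimal sketch below illustrates this; apply_fn, action and trace_count are hypothetical names, not taken from the original code, and the counter only increments when JAX traces the Python function.

```python
from functools import partial

import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical counter: incremented only when JAX traces the function


@partial(jax.jit, static_argnums=(1,))
def apply_fn(x, fn):
    global trace_count
    trace_count += 1  # Python side effect: runs while tracing, not on cached calls
    return fn(x)


def action(x):
    return 0.5 * jnp.sum(x ** 2)


x = jnp.ones(3)
apply_fn(x, action)
apply_fn(x, action)  # same function object: the compiled version is reused
after_same = trace_count

# Three distinct lambdas (kept alive in a list): each one is a new static value
lambdas = [lambda y: 0.5 * jnp.sum(y ** 2) for _ in range(3)]
for f in lambdas:
    apply_fn(x, f)  # cache miss: traces and compiles again
after_lambdas = trace_count
```

Reusing the same function object keeps the cache warm; creating a new callable per call makes every call pay for tracing and compilation.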

I was forced to use static arguments for action_fn and force_fn. I have redesigned the code to avoid these static_argnums,
and now the JIT-compiled JAX code is 3 times faster than the corresponding torch code.
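The redesigned code is not shown in the thread. One common way to drop static_argnums for callables is to build the jitted step once inside a factory that closes over action_fn, so only arrays are passed at call time. The sketch below assumes that the force is the gradient of the action (obtained here with jax.grad) and uses a standard leapfrog integrator with a Metropolis accept/reject step; make_hmc_step, grad_action and the Gaussian test action are hypothetical illustrations, not the author's implementation.

```python
import jax
import jax.numpy as jnp


def make_hmc_step(action_fn, n_steps=10, dt=0.1):
    """Build a jitted HMC step that closes over action_fn (no static_argnums)."""
    # Assumption: the author's force_fn plays the role of this gradient
    grad_action = jax.grad(action_fn)

    def hamiltonian(phi, p):
        return action_fn(phi) + 0.5 * jnp.sum(p ** 2)

    @jax.jit
    def hmc_one_step(key, phi):
        key_p, key_acc = jax.random.split(key)
        p0 = jax.random.normal(key_p, phi.shape)

        # Leapfrog: half kick, (n_steps - 1) x (drift + full kick), drift, half kick
        def body(_, state):
            q, p = state
            q = q + dt * p
            p = p - dt * grad_action(q)
            return q, p

        p = p0 - 0.5 * dt * grad_action(phi)
        q, p = jax.lax.fori_loop(0, n_steps - 1, body, (phi, p))
        q = q + dt * p
        p = p - 0.5 * dt * grad_action(q)

        # Metropolis accept/reject on the change of the Hamiltonian
        delta_h = hamiltonian(q, p) - hamiltonian(phi, p0)
        accept = jax.random.uniform(key_acc) < jnp.exp(-delta_h)
        return jnp.where(accept, q, phi), accept

    return hmc_one_step


def action(phi):
    return 0.5 * jnp.sum(phi ** 2)  # illustrative free Gaussian action


step = make_hmc_step(action)  # build (and later compile) once, then reuse
key = jax.random.PRNGKey(0)
phi = jnp.zeros(4)
n_accept = 0
for _ in range(100):
    key, sub = jax.random.split(key)
    phi, acc = step(sub, phi)
    n_accept += int(acc)
```

Because each call to make_hmc_step returns a new jitted function, the factory should be called once per action and the returned step reused inside the sampling loop.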

Answer selected by jecampagne
Category: Q&A