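I think something like this will work: wrapping the function in `jax.tree_util.Partial` makes it a pytree, so it can be passed as an argument to a `jit`-compiled function.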

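```python
from jax import jit
from jax.tree_util import Partial

# Partial makes the function a pytree: the function itself is static metadata,
# and any arguments bound into the Partial (none here) would be traced leaves.
th = (Partial(lambda params, x: params['x'] + x), dict(x=1))

@jit
def app(t, x):
  # t is a tuple pytree: (Partial-wrapped function, params dict)
  f, params = t
  return f(params, x)

app(th, 1)  # evaluates to 2 (returned as a JAX array)
```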
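A variation on the same idea, sketched under assumptions: instead of carrying `params` in a separate tuple element, the parameters can be bound into the `Partial` itself, because `Partial` treats its bound arguments as pytree leaves that `jit` traces. The function name `add_param` and the float parameter value are illustrative only, not part of the original answer. The last part shows, for contrast, that a bare Python function is rejected as a `jit` argument.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import Partial, tree_leaves


def add_param(params, x):
    # Hypothetical example function: adds the stored parameter to the input
    return params['x'] + x


# Bind params into the Partial: the dict's values become the pytree's leaves,
# while add_param itself is kept as static (non-traced) metadata.
f = Partial(add_param, {'x': jnp.float32(1.0)})


@jax.jit
def apply(fn, x):
    # fn is flattened on the way in and rebuilt inside the trace, so it is callable
    return fn(x)


result = apply(f, 2.0)
print(result)          # 3.0
print(tree_leaves(f))  # the bound parameter is the only leaf

# By contrast, a bare Python function is not a valid JAX type for a jit argument.
try:
    jax.jit(lambda g, x: g(x))(add_param, 2.0)
    plain_fn_rejected = False
except TypeError:
    plain_fn_rejected = True
print(plain_fn_rejected)
```

Since the wrapped function is part of the static structure, reusing the same function object (here `add_param`) with new parameter values of the same shape and dtype should hit the compilation cache, whereas building a fresh `lambda` on every call would likely trigger a retrace.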

shoyer Mar 11, 2022
Collaborator
Answer selected by mchagneux
Category
Q&A
4 participants