https://github.com/google/jax/blob/main/jax/interpreters/partial_eval.py
I think this module can help you achieve your goal cleanly.
JAX uses it to implement linearize as follows:

You can declare x as a known value and then use pe.trace_to_jaxpr.
WDYT?
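For comparison, here is a minimal sketch using only the public `jax.linearize`, which (as mentioned above) is built on partial_eval. The toy function `f` and the values of `x` and `t` are hypothetical. The primal `x` is treated as known, so everything that depends only on `x` is evaluated up front, and the returned `f_jvp` takes just the tangent:

```python
import jax
import jax.numpy as jnp

# Hypothetical toy function; any differentiable function would do.
def f(x):
    return jnp.sin(x) * x

x = jnp.float32(2.0)  # the "known" primal value
t = jnp.float32(1.0)  # the "unknown" tangent input

# linearize evaluates f at x and returns a function of the tangent only.
y, f_jvp = jax.linearize(f, x)

# d/dx [sin(x) * x] = cos(x) * x + sin(x), applied to tangent t.
expected = (jnp.cos(x) * x + jnp.sin(x)) * t
print(jnp.allclose(f_jvp(t), expected))

# Inspect the residual program: sin/cos of x should already be
# folded into constants, leaving only tangent-dependent arithmetic.
print(jax.make_jaxpr(f_jvp)(t))
```

If you need the same known/unknown split for a custom transformation, pe.trace_to_jaxpr is the lower-level route, but note that partial_eval is an internal module and its signatures have changed between JAX versions.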

Answer selected by leakec
Category: Q&A