HOWTO: differentiable functions that depend on static arguments? #7278
Unanswered
smao-astro asked this question in Q&A
Replies: 1 comment
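Partials are OK. You can also create a Module-like structure:

    import jax
    from dataclasses import dataclass

    # frozen=True makes instances immutable and hashable,
    # which is safer for a callable handed to jax.jit.
    @dataclass(frozen=True)
    class G:
        a: float  # constant captured by the callable

        def __call__(self, x):
            return self.a + x

    g = jax.jit(G(a=1))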
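To make the two options concrete, here is a minimal sketch applied to the quadratic from the question. The names `quadratic`, `Quadratic`, `f_partial` and `f_module` and the values 2.0, 3.0, 1.0 are made up for illustration and are not from the original post; the frozen dataclass is an assumption chosen so that the instance is hashable.

```python
from dataclasses import dataclass
from functools import partial

import jax


def quadratic(a, b, c, x):
    # a, b, c stay fixed for the whole run; x is the input we differentiate.
    return a * x**2 + b * x + c


# Option 1: bind the constants up front with functools.partial.
f_partial = jax.jit(partial(quadratic, 2.0, 3.0, 1.0))


# Option 2: a Module-like callable that stores the constants as fields.
@dataclass(frozen=True)
class Quadratic:
    a: float
    b: float
    c: float

    def __call__(self, x):
        return self.a * x**2 + self.b * x + self.c


f_module = jax.jit(Quadratic(a=2.0, b=3.0, c=1.0))

# Both callables take only x, so derivatives never mention a, b, c.
df_partial = jax.grad(f_partial)
df_module = jax.grad(f_module)
```

In both cases the constants are baked in when the function is traced, so `df_partial(x)` and `df_module(x)` should each evaluate 2*a*x + b. For vector-valued outputs, `jax.jacfwd` or `jax.jacrev` can be used in place of `jax.grad` in the same way.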
Hi,
I am new to JAX and am looking for suggestions on the best way to build pure functions that depend on static arguments.
For example (over-simplified), I want a function f(x) that computes a*x**2 + b*x + c, where a, b, c are constants that do not change during one execution (they depend on the command-line input to the program). I also want easy-to-write-and-read Jacobians of f(x) that I can take without having to consider a, b, c.
I have two ways to do this, but I am wondering: 1. What is the difference between the two ways? 2. Which one is better? Please give me some help!
The first way:
The second way:
Thanks!