Hi, I think you're right that the documentation is incomplete when it comes to support for general pytree arguments. This is probably because it's difficult to capture the full flexibility of the API in a docstring that's simple enough to be clear for the most common cases. Would you like to put together a pull request to address this? If not, someone from the team can take a look.
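A minimal sketch of what general pytree arguments look like in practice, assuming a recent JAX release. The differentiated argument (the default `argnums=0`) is a dict holding an array and a list of Python floats, and the gradient comes back with the same tree structure. `loss`, `params`, and `data` are illustrative names, not part of any API. The last part shows that an integer leaf is rejected, which matches the "inexact" requirement discussed in the question below.

```python
import jax
import jax.numpy as jnp


def loss(params, data):
    # params is a nested pytree: a dict with an array leaf and a list of scalar leaves
    w = params["w"]
    b0, b1 = params["b"]
    return jnp.sum(w * data) + b0 * b1


params = {"w": jnp.array([1.0, 2.0]), "b": [3.0, 4.0]}
data = jnp.array([0.5, -1.0])

# Differentiates with respect to argnums=0 (params) by default
grads = jax.grad(loss)(params, data)

# The gradient mirrors the structure of params: a dict with "w" and a 2-element list "b"
print(jax.tree_util.tree_structure(grads) == jax.tree_util.tree_structure(params))
print(grads["w"])     # d/dw of sum(w * data) is data
print(grads["b"])     # d/db0 = b1, d/db1 = b0

# Leaves at the differentiated positions must be inexact (float or complex);
# an integer array leaf makes jax.grad raise a TypeError
try:
    jax.grad(loss)({"w": jnp.array([1, 2]), "b": [3.0, 4.0]}, data)
except TypeError as err:
    print("integer leaf rejected:", err)
```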
-
Hi, I am currently trying to reproduce the fundamentals of Equinox in a minimal example. For this, I took another look at the documentation of `jax.grad` and have two questions.

The documentation states for the `fun` argument:

More generally, the arguments of `fun` at the positions specified by `argnums` can be pytrees with inexact arrays and scalars as leaf nodes, correct? If so, should this be added to the documentation? I would be able to open a quick PR.

The documentation states for the return value:
Consider the following simple example:

with output

In particular, the gradient has the same type (`MyTree`) as the positional argument indicated by `argnums`. Still, the type of the leaf node `x` changed from `float` to a `jax.Array` of dtype `float32`. Is this close enough, or should such "arrayifications" be mentioned in the documentation?
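A minimal sketch reproducing the behaviour described above, assuming default JAX settings (64-bit mode off). The exact `MyTree` from the question is not shown, so this `MyTree` is a hypothetical stand-in: a class with a single field `x`, registered as a custom pytree via `jax.tree_util.register_pytree_node_class`, and `f` is an illustrative function.

```python
import jax
import jax.numpy as jnp


# Hypothetical custom pytree with one leaf `x`; registration lets jax.grad
# flatten it into leaves and rebuild a MyTree from the gradient leaves.
@jax.tree_util.register_pytree_node_class
class MyTree:
    def __init__(self, x):
        self.x = x

    def tree_flatten(self):
        # (children that JAX treats as leaves, static auxiliary data)
        return (self.x,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

    def __repr__(self):
        return f"MyTree(x={self.x!r})"


def f(tree):
    return tree.x ** 2


tree = MyTree(x=3.0)            # the leaf starts out as a plain Python float
grad_tree = jax.grad(f)(tree)   # gradient w.r.t. argnums=0

print(type(grad_tree).__name__)            # MyTree: the container type is preserved
print(type(tree.x).__name__)               # float
print(isinstance(grad_tree.x, jax.Array))  # True: the leaf came back as an array
print(grad_tree.x.dtype)                   # float32 with 64-bit mode off
```

Only the container type carries over to the gradient; each leaf is returned as a `jax.Array`, which is the "arrayification" the question refers to.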