
You could write it this way if you like:

from jax import grad

def grad_n(n: int, f, **kwargs):
  # Wrap f in jax.grad n times; the result computes the n-th derivative of f.
  for _ in range(n):
    f = grad(f, **kwargs)
  return f

I think we'd be unlikely to add such a helper function to the JAX API, since it's straightforward to define a function like this if you need it.
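As a quick sanity check, the helper can be applied to functions whose higher derivatives are known in closed form. This is a minimal usage sketch: it repeats the same `grad_n` definition so it runs on its own, and the test functions (`jnp.sin` and `x ** 4`) are arbitrary choices for illustration.

```python
import jax
import jax.numpy as jnp

def grad_n(n: int, f, **kwargs):
  # Wrap f in jax.grad n times; the result computes the n-th derivative of f.
  for _ in range(n):
    f = jax.grad(f, **kwargs)
  return f

# The third derivative of sin(x) is -cos(x), so its value at x = 0 is -1.
third_sin = grad_n(3, jnp.sin)
print(third_sin(0.0))

# The fourth derivative of x**4 is the constant 24.
fourth_quartic = grad_n(4, lambda x: x ** 4)
print(fourth_quartic(2.0))
```

Any keyword arguments (for example `argnums`) are forwarded to every `jax.grad` call, so they apply at each order of differentiation. Inputs must be floats, since `jax.grad` rejects integer arguments.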

Answer selected by exenGT
Category
Q&A
2 participants