
Hi. As far as I understand, this is not the case. To be able to calculate the value and gradient of a function, all you need is a 'recipe' for differentiation, which is either inferred from the computational graph or defined with a custom_vjp. Below is an example of a function that cannot be jitted, because it calls plain NumPy functions, but can still be differentiated because a custom_vjp is defined for it:

import jax
import numpy as np
import jax.numpy as jnp
from jax import custom_vjp

@custom_vjp
def f(x):
    # The function deliberately uses plain NumPy functions, which are not jittable
    return -(x[0] + np.sin(x[0])) * np.exp(-x[0]**2.0)


def f_fwd(x):
    # Forward pass: return the primal output and save x as a residual
    # for the backward pass.
    return f(x), (x,)


def f_bwd(res, g):…
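The reply is cut off at this point in the archived page, so the body of f_bwd is not shown. Below is a minimal sketch of a matching backward pass, with the derivative of f worked out by hand, plus the defvjp registration the example needs; this completion is an assumption, not the author's original code:

def f_bwd(res, g):
    # res holds the residuals saved by f_fwd; g is the cotangent of the
    # scalar output.
    (x,) = res
    # Hand-derived derivative of -(x[0] + sin(x[0])) * exp(-x[0]**2)
    # with respect to x[0] (an assumption about what the truncated body did).
    df_dx0 = np.exp(-x[0]**2.0) * (
        2.0 * x[0] * (x[0] + np.sin(x[0])) - (1.0 + np.cos(x[0]))
    )
    grad_x = np.zeros_like(x)  # f depends only on x[0]
    grad_x[0] = df_dx0
    return (g * grad_x,)


f.defvjp(f_fwd, f_bwd)

# Per the reply above: grad works eagerly, because concrete values reach the
# NumPy calls, while jit fails, because np.sin/np.exp cannot handle JAX tracers.
x = np.array([0.5])
print(jax.grad(f)(x))   # works, uses the custom VJP recipe
# jax.jit(f)(x)         # would raise a tracer error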
