If you want to compose the gradients manually in this case, you could use jax.vjp, but you'd need to refactor objective_fn and constraint_fn to take y (the output of the expensive function) as their argument instead of x. Here's what it might look like:

import jax
from jax import grad, jit
import jax.numpy as jnp

# Dummy expensive operation
def expensive_function(x):
    return jnp.square(x)

# Objective (takes y = expensive_function(x) rather than x)
def objective_fn(y):
    return jnp.sum(y)

# Constraint (also takes y)
def constraint_fn(y):
    return jnp.sum(y * 2)

# Inputs
x = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0], dtype=jnp.float32)

# Gradients: run the expensive function once and keep its VJP,
# which maps a cotangent on y back to a cotangent on x
y, vjp_fn = jax.vjp(expensive_function, x)
objective_ygrad = grad(objecti…
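A self-contained sketch of where this approach leads, assuming the goal is the gradients of both the objective and the constraint with respect to x. This is one way to finish the idea, not necessarily how the truncated snippet above continues; the helper name `objective_and_constraint_grads` is hypothetical. The gradients are taken with respect to y and then pulled back to x through the single `vjp_fn`, and the result is checked against differentiating the composed functions directly.

```python
import jax
import jax.numpy as jnp


# Dummy expensive operation (same as in the snippet above)
def expensive_function(x):
    return jnp.square(x)


# Objective and constraint are written in terms of y, not x
def objective_fn(y):
    return jnp.sum(y)


def constraint_fn(y):
    return jnp.sum(y * 2)


@jax.jit
def objective_and_constraint_grads(x):  # hypothetical helper name
    # Evaluate the expensive function once and keep its VJP closure.
    y, vjp_fn = jax.vjp(expensive_function, x)
    # Gradients of the cheap functions with respect to y.
    objective_ygrad = jax.grad(objective_fn)(y)
    constraint_ygrad = jax.grad(constraint_fn)(y)
    # Pull each y-gradient back to x. vjp_fn returns a tuple with one
    # entry per primal input, so unpack the single element.
    (objective_xgrad,) = vjp_fn(objective_ygrad)
    (constraint_xgrad,) = vjp_fn(constraint_ygrad)
    return objective_fn(y), constraint_fn(y), objective_xgrad, constraint_xgrad


x = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0], dtype=jnp.float32)
obj, con, obj_grad, con_grad = objective_and_constraint_grads(x)

# Sanity check against differentiating the composed functions directly.
direct_obj_grad = jax.grad(lambda x: objective_fn(expensive_function(x)))(x)
direct_con_grad = jax.grad(lambda x: constraint_fn(expensive_function(x)))(x)
print(jnp.allclose(obj_grad, direct_obj_grad), jnp.allclose(con_grad, direct_con_grad))
```

The point of this structure is that `expensive_function` is evaluated only once per call: each `vjp_fn` call reuses what was saved during that single forward pass, whereas calling `jax.grad` on the two composed functions separately would run the expensive part twice.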

Replies: 1 comment 2 replies

@itk22
@davisyoshida
Answer selected by itk22
Category
Q&A
2 participants