
You are not using `jax.make_jaxpr` correctly. It transforms a function into another function whose output is a jaxpr; it cannot transform an "expression". Python evaluates the expression eagerly and passes the resulting value to `jax.make_jaxpr`, not the expression or its AST. Your code actually fails earlier, while evaluating the `vjp` expression, because that evaluation needs concrete arrays.
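A minimal sketch of the eager-evaluation point. The concrete inputs `x` and `y` are hypothetical example values, not taken from the question. `callable()` shows that evaluating `func(x, y)` yields an array rather than a function. Passing `func` itself, followed by its arguments, lets `make_jaxpr` trace it:

```python
import jax
import jax.numpy as jnp

def func(x, y):
    return x * y

# Hypothetical example inputs (not from the original question).
x = jnp.ones((3, 4), dtype=jnp.float32)
y = jnp.full((3, 4), 2.0, dtype=jnp.float32)

# Python evaluates func(x, y) before make_jaxpr ever sees it, so what would be
# handed over is an array of results, not a function JAX can trace.
value = func(x, y)
print(callable(value))  # False

# Passing the function itself, then its arguments, lets make_jaxpr trace it.
closed_jaxpr = jax.make_jaxpr(func)(x, y)
print(closed_jaxpr)
```

The same idea is why the `jax.vjp` call needs to be wrapped in a `lambda` (or a named function) before it is given to `make_jaxpr`.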
Try the following code:

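```python
import jax
import jax.numpy as jnp

def func(x, y):
    return x * y

# Describe the inputs abstractly (shape and dtype only). The original
# jax.abstract_arrays.ShapedArray import was removed in newer JAX releases;
# jax.ShapeDtypeStruct serves the same purpose here.
spec = jax.ShapeDtypeStruct((3, 4), jnp.float32)

# Wrap the vjp call in a lambda so make_jaxpr receives a function,
# not an already-evaluated result.
print(jax.make_jaxpr(lambda x, y: jax.vjp(func, x, y))(spec, spec))
```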

The output is:

```
{ lambda ; a:f32[3,4] b:f32[3,4]. let c:f32[3,4] = mul a b in (c, b,…
```
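The jaxpr above covers only the forward part: `jax.vjp` returns a pair of the primal output and a VJP function, and nothing applies that function yet. If the backward pass is also of interest, one option is to call the VJP function inside the traced function. This is a sketch that goes beyond the original answer. `fwd_and_bwd` and the cotangent argument `ct` are hypothetical names introduced for illustration:

```python
import jax
import jax.numpy as jnp

def func(x, y):
    return x * y

def fwd_and_bwd(x, y, ct):
    # Forward pass: primal output plus the VJP function.
    out, vjp_fn = jax.vjp(func, x, y)
    # Backward pass: pull the cotangent ct back to both inputs.
    dx, dy = vjp_fn(ct)
    return out, dx, dy

# Abstract inputs: shape and dtype only, no concrete values needed.
spec = jax.ShapeDtypeStruct((3, 4), jnp.float32)
closed_jaxpr = jax.make_jaxpr(fwd_and_bwd)(spec, spec, spec)
print(closed_jaxpr)

# Sanity check with concrete values: for x * y, dx = ct * y and dy = ct * x.
x = jnp.arange(12.0).reshape(3, 4)
y = jnp.full((3, 4), 2.0)
ct = jnp.ones((3, 4))
out, dx, dy = fwd_and_bwd(x, y, ct)
```

Because the cotangent is an explicit argument, the printed jaxpr includes the multiplications of the backward pass alongside the forward `mul`.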

Answer selected by jacobhess118
Category
Q&A