jax.vjp() with abstract-only primals (i.e. ShapedArray)?
#9779
Wondering if there is a way to make `jax.vjp()` work with abstract-only primals, e.g.:

```python
import jax
from jax.abstract_arrays import ShapedArray
import numpy as np

def func(x, y):
    return x * y

print(jax.make_jaxpr(
    jax.vjp(
        func,
        ShapedArray([3, 4], np.float32),
        ShapedArray([3, 4], np.float32)
    )
))
```

But it throws an error:
Ideally I would like to use this approach to get a Jaxpr-based graph for the backward pass of …
The reason I want to avoid needing to pass in …
Thanks!
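You are not using `jax.make_jaxpr` correctly. It transforms a function into another function that returns a jaxpr; it cannot transform an "expression". Python evaluates the expression eagerly and passes the resulting value to `jax.make_jaxpr`, not the expression/AST, so your code actually fails while evaluating the `jax.vjp(...)` call, which needs concrete arrays. Try the following code:

```python
import jax
import jax.numpy as jnp
from jax.abstract_arrays import ShapedArray

def func(x, y):
    return x * y

# Wrap the vjp call in a lambda so make_jaxpr receives a function to trace
print(jax.make_jaxpr(lambda x, y: jax.vjp(func, x, y))(
    ShapedArray((3, 4), jnp.float32), ShapedArray((3, 4), jnp.float32)))
```

and the output: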
However, you want:

Thus you should use the following code:

```python
import jax
import jax.numpy as jnp
from jax.abstract_arrays import ShapedArray

def func(x, y):
    return x * y

abs_arr = ShapedArray((3, 4), jnp.float32)
# jax.vjp(...)[1] is the pullback; applying it to grad_out traces the backward pass
print(jax.make_jaxpr(lambda x, y, grad_out: jax.vjp(func, x, y)[1](grad_out))(
    abs_arr, abs_arr, abs_arr))
```

and the output:

This is exactly what you need.
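The `jax.abstract_arrays` module no longer exists in recent JAX releases. A minimal sketch of the same idea for newer versions, assuming `jax.ShapeDtypeStruct` as the shape-only stand-in for `ShapedArray`, could look like this. The helper name `backward_jaxpr` is hypothetical, and using `jax.eval_shape` to derive the cotangent's shape/dtype from the function's output (rather than reusing the input shape, as `abs_arr` does above) is an addition, not part of the original answer.

```python
import jax
import jax.numpy as jnp


def func(x, y):
    return x * y


def backward_jaxpr(f, *primal_specs):
    """Hypothetical helper: trace the VJP (backward pass) of f.

    primal_specs are shape/dtype-only stand-ins (jax.ShapeDtypeStruct),
    so no concrete arrays are created.
    """
    # Abstractly evaluate f to get its output shape/dtype; the output
    # cotangent fed to the pullback must have this shape/dtype.
    out_spec = jax.eval_shape(f, *primal_specs)

    # make_jaxpr needs a *function*, so wrap the vjp call in one.
    def pullback(primals, cotangent):
        _, vjp_fn = jax.vjp(f, *primals)
        return vjp_fn(cotangent)

    return jax.make_jaxpr(pullback)(primal_specs, out_spec)


spec = jax.ShapeDtypeStruct((3, 4), jnp.float32)
closed = backward_jaxpr(func, spec, spec)
print(closed)  # inputs: x, y, grad_out; outputs: cotangents for x and y
```

Because the cotangent spec comes from `jax.eval_shape`, the same helper also covers functions whose output shape differs from their inputs, e.g. `backward_jaxpr(lambda x, w: jnp.sum(x @ w), spec, jax.ShapeDtypeStruct((4, 5), jnp.float32))`, where the cotangent is a scalar.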