I have tried to evaluate the output shape of `lax.conv_general_dilated` with `jax.eval_shape`:

```python
import jax
import jax.numpy as jnp
from jax import lax
from functools import partial

x = jax.ShapeDtypeStruct((1, 1, 28, 28), jnp.float32)
kernel = jax.ShapeDtypeStruct((32, 1, 3, 3), jnp.float32)

# does not work
jax.eval_shape(lax.conv_general_dilated, x, kernel, (1, 1), "SAME")

# use a custom jvp to declare which args are non-diff
@partial(jax.custom_jvp, nondiff_argnums=(2, 3))
def convolve_wrap(x, kernel, stride, padding):
    """Convolve 2D."""
    return lax.conv_general_dilated(
        x,
        kernel,
        window_strides=(stride, stride),  # stride is a single int here
        padding=padding,
    )

# does not work either
jax.eval_shape(convolve_wrap, x, kernel, 1, "SAME")
```

Both fail with:
I'd appreciate any hint on how to use it correctly. Thanks!
Answered by jakevdp, Dec 17, 2023
Answer selected by adonath

All arguments of `jax.eval_shape` will be treated as dynamic; any static arguments should be bound to the function using a closure or a partial. For example:

```python
jax.eval_shape(
    partial(lax.conv_general_dilated, window_strides=(1, 1), padding="SAME"),
    x, kernel)
```
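A minimal sketch showing both the partial and the closure variants side by side, reusing the `x` and `kernel` structs from the question. `make_conv` is a hypothetical helper name introduced only for illustration. As far as I understand it, `nondiff_argnums` on `jax.custom_jvp` only affects differentiation; it does not make an argument static, so `jax.eval_shape` still tries to abstract `"SAME"` as an array, which is why the wrapper in the question fails as well.

```python
import jax
import jax.numpy as jnp
from functools import partial
from jax import lax

# Abstract inputs: NCHW image batch and OIHW kernel (the defaults for conv_general_dilated)
x = jax.ShapeDtypeStruct((1, 1, 28, 28), jnp.float32)
kernel = jax.ShapeDtypeStruct((32, 1, 3, 3), jnp.float32)

# Option 1: bind the static arguments with functools.partial,
# so eval_shape only sees the two array arguments
conv_same = partial(lax.conv_general_dilated, window_strides=(1, 1), padding="SAME")
out_partial = jax.eval_shape(conv_same, x, kernel)

# Option 2: bind them with a closure (make_conv is a hypothetical helper)
def make_conv(stride, padding):
    def conv(x, kernel):
        # stride and padding are plain Python values captured from the enclosing scope
        return lax.conv_general_dilated(
            x, kernel, window_strides=(stride, stride), padding=padding
        )
    return conv

out_closure = jax.eval_shape(make_conv(1, "SAME"), x, kernel)

print(out_partial.shape, out_partial.dtype)
print(out_closure.shape)
```

With stride 1 and `"SAME"` padding the spatial size is preserved, so both variants should report `(1, 32, 28, 28)` with dtype `float32`. The closure form is handy when the static values come from a config object; `partial` is shorter when you only need to fix keyword arguments.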