InconclusiveDimensionOperation with complicated vjp #6966
I'm trying to run a vjp on a function that concatenates (then flattens) multiple copies of a vector:

```python
from jax import numpy as npj
from jax import vjp, vmap

N0 = 15
N1 = 10

def mapping(x):
    return npj.vstack([x for k in range(N1)]).flatten()

x = npj.ones((N0,))
dJ_du = 0.5*npj.ones((N0*N1,))
y, vjp_fun = vjp(mapping, x)
mg, = vmap(vjp_fun)(dJ_du)
print(mg)
```

I get an `InconclusiveDimensionOperation` error.

Note we can do …

It seems like something is off with the dimension trace (I admit, the forward function is a bit wonky). Is this expected, or is there a way around it? I realize my MWE seems rather trivial and there are much better ways to do the same thing. However, I'm hoping to extend the functionality (e.g. such that each row is uniquely transformed).
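As a quick sanity check on the shapes involved (an illustrative sketch, not part of the original post), the forward function maps a length-`N0` vector to a length-`N0*N1` vector. For this particular case one of the "better ways" is `npj.tile`; that it is what the poster had in mind is only an assumption. The non-constant input `npj.arange(N0)` is also just an illustration choice, so that the equality check is meaningful:

```python
from jax import numpy as npj

N0 = 15
N1 = 10

def mapping(x):
    # Stack N1 copies of x into an (N1, N0) array, then flatten row-major
    return npj.vstack([x for k in range(N1)]).flatten()

x = npj.arange(N0, dtype=npj.float32)
y = mapping(x)
print(y.shape)  # (150,)

# For this mapping, the output is x repeated N1 times end to end,
# which is what npj.tile produces for a 1-D input.
print(bool(npj.array_equal(y, npj.tile(x, N1))))  # True
```

The key number for the VJP discussion is the output shape `(150,)`: any cotangent passed to the backward function has to match it.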
Good question – I think the issue is that your input to `vjp_fun` is not of the expected size. The VJP is a backward pass, so it expects a cotangent input with the same shape as the output `y`. Since you've defined `dJ_du` to have the same shape as `y`, this should work:

```python
y, vjp_fun = vjp(mapping, x)
print(vjp_fun(dJ_du))
```

But if you `vmap` the `vjp_fun` before calling it on a one-dimensional input, `vmap` maps over the leading axis, so each call is the equivalent of passing a scalar to `vjp_fun`, which will fail.

Does that answer your question?
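To make the shape argument concrete, here is a minimal sketch reusing the names from the question (`mapping`, `N0`, `N1`, `dJ_du`). It shows the unbatched call working, and one situation where `vmap(vjp_fun)` does make sense: a *batch* of cotangents with shape `(batch, N0*N1)`. Feeding the rows of an identity matrix (our choice for illustration) recovers the full Jacobian, which is compared against `jax.jacrev`:

```python
from jax import numpy as npj
from jax import jacrev, vjp, vmap

N0 = 15
N1 = 10

def mapping(x):
    # Stack N1 copies of x into an (N1, N0) array and flatten to shape (N0*N1,)
    return npj.vstack([x for k in range(N1)]).flatten()

x = npj.ones((N0,))
dJ_du = 0.5 * npj.ones((N0 * N1,))  # a single cotangent, same shape as y

y, vjp_fun = vjp(mapping, x)

# Unbatched call: the cotangent matches y's shape, so this works.
mg, = vjp_fun(dJ_du)
print(mg.shape)  # (15,)
# Each x[i] appears N1 = 10 times in y, each copy weighted by 0.5, so mg is 5.0 everywhere.

# vmap maps over the leading axis, so it needs a stack of cotangents with
# shape (batch, N0*N1). Rows of the identity give one Jacobian row per output element.
cotangents = npj.eye(N0 * N1)
jac, = vmap(vjp_fun)(cotangents)
print(jac.shape)  # (150, 15)
print(bool(npj.allclose(jac, jacrev(mapping)(x))))  # True
```

With a 1-D `dJ_du`, by contrast, `vmap` hands `vjp_fun` one scalar at a time, which is the mismatch described above.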