conv_transpose2D with groups option in JAX #9887
-
Hello, I wanted to know if there is a correct (and clever) way to implement conv_transpose2D with the groups option in JAX. I am trying to implement an invertible transform using convolutions, in particular a wavelet kernel (demo paper | code). The authors use a nifty kernel, similar to the Haar matrix, in conjunction with the groups option. I have a demo Colab notebook here to give a bit better context. The forward works fine, but the inverse is completely incorrect.
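For the title question itself: `lax.conv_transpose` has no `feature_group_count` argument, so a grouped transposed convolution has to be emulated. A minimal sketch under that assumption (the `grouped_conv_transpose` helper below is hypothetical, not a JAX API): split the channels into groups, transpose each group separately, and concatenate.

```python
import jax.numpy as jnp
from jax import lax

def grouped_conv_transpose(x, kernel, strides, padding, groups):
    # Hypothetical helper: emulates a grouped transposed conv by
    # running one lax.conv_transpose per channel group.
    # x: NCHW activations of the forward grouped conv (O channels);
    # kernel: OIHW filters of that forward conv.
    xs = jnp.split(x, groups, axis=1)       # split activations on C
    ks = jnp.split(kernel, groups, axis=0)  # split filters on O
    dn = ('NCHW', 'OIHW', 'NCHW')
    # transpose_kernel=True flips the spatial axes and swaps the I/O
    # channel axes, giving the adjoint of the forward conv.
    ys = [lax.conv_transpose(xi, ki, strides, padding,
                             dimension_numbers=dn, transpose_kernel=True)
          for xi, ki in zip(xs, ks)]
    return jnp.concatenate(ys, axis=1)
```

The reply below sidesteps this bookkeeping entirely with `jax.linear_transpose`.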
-
I think you can leverage `jax.linear_transpose`, and never need to struggle with how to transpose a linear function!

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import lax
x = jax.random.uniform(jax.random.PRNGKey(0), (1, 3, 512, 512))

# Build the four 2x2 Haar analysis filters, then repeat the bank once
# per input channel for the grouped convolution below.
kernel = np.ones((4, 1, 2, 2))
kernel[1, 0, 0, 1] = -1
kernel[1, 0, 1, 1] = -1
kernel[2, 0, 1, 0] = -1
kernel[2, 0, 1, 1] = -1
kernel[3, 0, 1, 0] = -1
kernel[3, 0, 0, 1] = -1
kernel *= 0.5
kernel = np.concatenate([kernel] * 3, 0)
kernel = jnp.asarray(kernel, dtype=jnp.float32)
def fwd(x):
    dn = lax.conv_dimension_numbers(
        x.shape,       # only ndim matters, not shape
        kernel.shape,  # only ndim matters, not shape
        ('NCHW', 'OIHW', 'NCHW'),
    )
    return lax.conv_general_dilated(
        lhs=x,                  # lhs = image tensor
        rhs=kernel,             # rhs = conv kernel tensor
        window_strides=(2, 2),  # window strides
        padding='VALID',        # padding mode
        lhs_dilation=(1, 1),    # lhs/image dilation
        rhs_dilation=(1, 1),    # rhs/kernel dilation
        dimension_numbers=dn,
        feature_group_count=3,  # groups: one Haar filter bank per channel
    )
x_squeeze = fwd(x)
inv = jax.linear_transpose(fwd, x)  # only uses x.shape and x.dtype here
x_ori, = inv(x_squeeze)  # note the comma: inv returns a tuple
print(x.shape, x_ori.shape)
print(jnp.max(lax.abs(x_ori - x)))
```
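Why the plain transpose acts as an exact inverse here: flattened row-major, the four 2x2 filters above form an orthogonal 4x4 Haar-like matrix, so the adjoint of this non-overlapping stride-2 convolution equals its inverse and the final printed difference should be ~0. A quick sanity check of that claim (a sketch added for context, not part of the original reply):

```python
import numpy as np

# Rows are the four 2x2 filters above, flattened row-major and scaled
# by 0.5. H @ H.T == I, so transposing the grouped conv undoes it.
H = 0.5 * np.array([
    [1,  1,  1,  1],
    [1, -1,  1, -1],
    [1,  1, -1, -1],
    [1, -1, -1,  1],
])
assert np.allclose(H @ H.T, np.eye(4))
```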
-
There is