-
EDIT: What I was trying to implement was a locally connected layer. The convolution itself works fine. Yes, it's a stupid question, but I cannot make it work despite having put an unreasonable amount of time into reading the docs. It seems that the Convolutions in JAX document only gives examples of applying a single kernel to every part of an image, which is not the case for a deep learning convolution layer.
In Convolutions in JAX, the meaning of each letter in the dimension numbers is described as N (batch), C (channels/features), H (spatial height), W (spatial width), I (kernel input channels), and O (kernel output channels). Below is a working example:

import jax
import jax.numpy as jnp
x = jnp.ones((1, 3, 100, 100), dtype='float32') # NCHW
w = jnp.ones((7, 3, 5, 5), dtype='float32') # OIHW
out1 = jax.lax.conv(x, w, window_strides=(1, 1), padding=((2, 2), (2, 2))) # same result as padding='SAME'
print(out1.shape)
# result: (1, 7, 100, 100)

However, my problem is that convolution in deep learning requires different kernels for different parts of an image.
What really boggles me is that I think I solved this problem before. I may be completely missing the point, but I have spent way too much time on this. It would be great to get an answer.
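To make the operation concrete, here is a naive loop-based sketch of what I mean (shapes are made up purely for illustration):

import numpy as np

# One distinct kernel per output position (illustrative shapes).
H, W, C_in, C_out, kh, kw = 8, 8, 2, 3, 3, 3
x = np.ones((H, W, C_in), dtype='float32')
w = np.ones((H, W, C_out, kh, kw, C_in), dtype='float32')

xp = np.pad(x, ((kh // 2, kh // 2), (kw // 2, kw // 2), (0, 0)))  # 'SAME'-style zero padding
out = np.empty((H, W, C_out), dtype='float32')
for i in range(H):
    for j in range(W):
        patch = xp[i:i + kh, j:j + kw, :]            # (kh, kw, C_in) window at (i, j)
        out[i, j] = np.tensordot(w[i, j], patch, 3)  # contract over (kh, kw, C_in)
print(out.shape)  # (8, 8, 3)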
-
It's not normally the case in deep learning that different parts of an image get different convolution kernels. Instead, the same kernel is advanced across the input image, and that is the operation that lax.conv performs. Can you point to an example of what you're trying to implement in another deep learning system? I don't think the operation you have described exists in any of the standard systems.
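As a quick illustration of the shared kernel (a minimal sketch with arbitrary shapes): shifting the input simply shifts the output, because the very same weights are applied at every position.

import jax.numpy as jnp
from jax import lax

x = jnp.zeros((1, 1, 8, 8)).at[0, 0, 2, 2].set(1.0)    # NCHW input with a single spike
w = jnp.arange(9, dtype='float32').reshape(1, 1, 3, 3)  # OIHW kernel

out = lax.conv(x, w, window_strides=(1, 1), padding='SAME')
out2 = lax.conv(jnp.roll(x, (1, 1), axis=(2, 3)), w, window_strides=(1, 1), padding='SAME')
# Away from the borders, the response is just shifted too:
print(jnp.allclose(out[0, 0, 1:-2, 1:-2], out2[0, 0, 2:-1, 2:-1]))  # True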
-
It turns out that what I was trying to implement was a locally connected layer. Here is a working implementation:

import numpy as np
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
@jax.jit
def locally_connected_2d(x, w, window_strides=(1, 1)):
    # Extract every (kh, kw) window of x; the output is NHWC with the
    # flattened patch laid out along the feature dimension.
    patches = jax.lax.conv_general_dilated_patches(
        lhs=x.reshape(1, *x.shape),
        filter_shape=(w.shape[3], w.shape[4]),
        window_strides=window_strides,
        padding='SAME',
        dimension_numbers=('NHWC', 'OIHW', 'NHWC')
    )
    # Flatten the per-position kernels and patches, then contract the
    # flattened axes while batching over the two spatial dimensions.
    w = w.reshape(w.shape[0], w.shape[1], w.shape[2], -1)
    patches = patches.reshape(patches.shape[1], patches.shape[2], -1)
    return jax.lax.dot_general(w, patches, dimension_numbers=(([3], [2]), ([0, 1], [0, 1])))
x = jnp.arange(100*100*5, dtype='float32').reshape(100, 100, 5)
w = jnp.arange(100*100*3*7*7*5, dtype='float32').reshape(100, 100, 3, 7, 7, 5)
print("x.shape=", x.shape)
print("w.shape=", w.shape)
r = locally_connected_2d(x, w)
print("r.shape=", r.shape)
plt.imshow(np.array(r/np.max(r)*255, dtype='uint8'))
plt.show()
%timeit locally_connected_2d(x, w).block_until_ready()
# x.shape= (100, 100, 5)
# w.shape= (100, 100, 3, 7, 7, 5)
# r.shape= (100, 100, 3)
# 1000 loops, best of 5: 607 µs per loop
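Since it is built entirely from jax.lax primitives, the layer composes with the usual transforms; for instance (a quick sketch), gradients with respect to the per-position weights come for free:

g = jax.grad(lambda w: locally_connected_2d(x, w).sum())(w)
print(g.shape)  # (100, 100, 3, 7, 7, 5)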