Which function should we prefer for CNN applications #8323
I am new to JAX, and while implementing 2D convolution (including strides and padding) I have come across a number of functions, such as:
Also:
So, inevitably, I am confused, since PyTorch and Keras each have a single 2D convolution counterpart. Which of these (or is there something else as well) should I use to get similar functionality and parameter support? Thanks!
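For comparison, a `torch.nn.Conv2d`-style forward pass can be expressed with `jax.lax.conv_general_dilated` directly. This is a sketch under stated assumptions, not an official JAX API: `conv2d_torch_like` is a hypothetical helper, it assumes PyTorch's NCHW input and OIHW weight layouts and symmetric integer zero padding, and (like PyTorch) it computes a cross-correlation, i.e. the kernel is not flipped.

```python
import jax.numpy as jnp
from jax import lax


def _pair(v):
    # Normalize an int or a 2-sequence to a (height, width) tuple, as PyTorch does.
    return (v, v) if isinstance(v, int) else tuple(v)


def conv2d_torch_like(x, w, b=None, stride=1, padding=0, dilation=1):
    """Hypothetical PyTorch-style conv: x is (N, C_in, H, W), w is (C_out, C_in, kH, kW)."""
    ph, pw = _pair(padding)
    out = lax.conv_general_dilated(
        x,
        w,
        window_strides=_pair(stride),
        padding=((ph, ph), (pw, pw)),  # symmetric zero padding per spatial dim
        rhs_dilation=_pair(dilation),  # kernel dilation, like torch's `dilation`
        dimension_numbers=("NCHW", "OIHW", "NCHW"),
    )
    if b is not None:
        out = out + b[None, :, None, None]  # broadcast bias over N, H, W
    return out


x = jnp.ones((1, 3, 32, 32))
w = jnp.ones((8, 3, 3, 3))
y = conv2d_torch_like(x, w, stride=2, padding=1)
print(y.shape)  # (1, 8, 16, 16)
```

With stride 2 and padding 1 on a 32×32 input the output is 16×16, the same spatial size `torch.nn.Conv2d(3, 8, 3, stride=2, padding=1)` would produce.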
In short: just use whichever of these works for you at first; they are all most likely equivalent. If you want to be sure about it, use …
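To illustrate why the frameworks (Flax, haiku, objax) end up roughly equivalent, here is a minimal sketch of what a framework "conv layer" boils down to: parameter initialization plus a call to the same `jax.lax` primitive. `init_conv` and `conv_apply` are hypothetical names, not the API of any of those frameworks; the scaled-normal initializer is an arbitrary choice, and the channel-last NHWC/HWIO layout is assumed because Flax and haiku layers default to channel-last inputs.

```python
import jax
import jax.numpy as jnp
from jax import lax


def init_conv(key, in_ch, out_ch, kernel_size=(3, 3)):
    """Hypothetical initializer: HWIO weights plus a zero bias."""
    kh, kw = kernel_size
    fan_in = in_ch * kh * kw
    # Scaled-normal init (an arbitrary choice for this sketch).
    w = jax.random.normal(key, (kh, kw, in_ch, out_ch)) / jnp.sqrt(fan_in)
    return {"w": w, "b": jnp.zeros((out_ch,))}


def conv_apply(params, x, stride=(1, 1), padding="SAME"):
    """Hypothetical forward pass on channel-last (NHWC) input."""
    y = lax.conv_general_dilated(
        x,
        params["w"],
        window_strides=stride,
        padding=padding,
        dimension_numbers=("NHWC", "HWIO", "NHWC"),
    )
    return y + params["b"]  # bias broadcasts over the last (channel) axis


# stride and padding must be concrete Python values, so mark them static for jit.
conv_jit = jax.jit(conv_apply, static_argnames=("stride", "padding"))

params = init_conv(jax.random.PRNGKey(0), in_ch=3, out_ch=16)
x = jnp.ones((2, 28, 28, 3))
y = conv_jit(params, x, stride=(2, 2))
print(y.shape)  # (2, 14, 14, 16)
```

Marking `stride` and `padding` as static lets `jax.jit` treat them as compile-time constants, which `lax.conv_general_dilated` needs for its window strides and padding mode.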
In general, `jax.numpy` and `jax.scipy` are built on top of `jax.lax` functions; they are largely syntactic sugar over `jax.lax`, whose functions are usually a bit harder to use and have stricter requirements on their arguments. They also mimic the NumPy/SciPy API, which is more widely known and generally easier to use.
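To make the "syntactic sugar" point concrete, the sketch below (with arbitrary example arrays) checks that `jax.scipy.signal.convolve2d` matches a direct `jax.lax.conv_general_dilated` call. One detail worth knowing: the SciPy-style function computes a true convolution, while `lax.conv_general_dilated` (like the CNN layers in PyTorch and Keras) computes a cross-correlation, so the kernel has to be flipped for the two to agree.

```python
import jax.numpy as jnp
from jax import lax
from jax.scipy.signal import convolve2d

img = jnp.arange(25.0).reshape(5, 5)
ker = jnp.array([[1.0, 2.0], [3.0, 4.0]])

# SciPy-style API: true convolution, "valid" mode gives shape (4, 4).
a = convolve2d(img, ker, mode="valid")

# Same result via the lower-level primitive: add batch and channel axes
# (default NCHW / OIHW layout) and flip the kernel to turn lax's
# cross-correlation into a true convolution.
b = lax.conv_general_dilated(
    img[None, None],
    ker[::-1, ::-1][None, None],
    window_strides=(1, 1),
    padding="VALID",
)[0, 0]

print(bool(jnp.allclose(a, b)))  # True
```

For CNN layers you would normally call the `lax` function (or a framework layer) directly, since `convolve2d` only accepts single 2D arrays and has no stride parameter, whereas `lax.conv_general_dilated` handles batches, channels, strides, padding and dilation.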
If you inspect their source code, both `jax.scipy.signal.convolve` and `jax.scipy.signal.convolve2d` call into `jax.lax.conv_general_dilated`: https://jax.readthedocs.io/en/latest/_modules/jax/_src/scipy/signal.html#_convolve_nd

Flax, haiku and objax are frameworks that help you construct and use neural networks, but they don't really provide any basic functionality of their own.
They are all built on top of bare JAX functions (in particular,
jax.lax.conv_…