How is the regular 2-D convolution layer implemented in JAX (stax)? #6167
Dear all,

Hi! As I mentioned here, I'm searching for a function in JAX that does the same thing as … . I implemented functions with … , so I'm now interested in how the regular 2-D convolutional layer is implemented in stax.

Thanks in advance! :)

Best,

P.S. In case anyone is interested in why I need this: what I'm actually doing is adding some extra operations to the regular convolution (which can be seen as a sum of element-wise multiplications) to build a 'new convolution'. My idea is to transform the image into large arrays, flatten the weights, and then apply my new operation using vectorized operations.
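For anyone landing here with the same question, here is a minimal sketch of the approach in the P.S. It assumes NHWC images, HWIO kernels, stride 1 and 'VALID' padding, and the helpers im2col and conv2d_im2col are hypothetical names, not JAX APIs. To my knowledge, stax.Conv's apply function calls jax.lax.conv_general_dilated with dimension numbers ('NHWC', 'HWIO', 'NHWC') and then adds a bias, so the sketch checks its patch-based result against that call (without the bias).

```python
import jax
import jax.numpy as jnp
from jax import lax

HIGHEST = lax.Precision.HIGHEST  # avoid reduced-precision (e.g. TF32) math in the check


def im2col(x, kh, kw):
    """Hypothetical helper: extract every kh x kw patch of an NHWC image.

    Assumes stride 1 and 'VALID' padding. Returns shape
    (N, H_out, W_out, kh * kw * C), with the last axis ordered as
    (kernel row, kernel column, input channel).
    """
    n, h, w, c = x.shape
    h_out, w_out = h - kh + 1, w - kw + 1
    taps = []
    for i in range(kh):
        for j in range(kw):
            # Shifted window: entry (y, x) holds input pixel (y + i, x + j),
            # i.e. kernel tap (i, j) for every output location at once.
            taps.append(x[:, i:i + h_out, j:j + w_out, :])
    patches = jnp.stack(taps, axis=3)  # (N, H_out, W_out, kh*kw, C)
    return patches.reshape(n, h_out, w_out, kh * kw * c)


def conv2d_im2col(x, w):
    """Hypothetical helper: stride-1 'VALID' 2-D convolution via im2col.

    x: (N, H, W, C_in) image; w: (KH, KW, C_in, C_out) kernel (HWIO).
    Like lax convolutions, this is a cross-correlation (no kernel flip).
    """
    kh, kw, c_in, c_out = w.shape
    patches = im2col(x, kh, kw)                # (N, H_out, W_out, KH*KW*C_in)
    w_flat = w.reshape(kh * kw * c_in, c_out)  # same (row, col, channel) order
    # The "sum of element-wise multiplications" is a single matrix product
    # here; a custom operation on (patches, w_flat) could replace it.
    return jnp.matmul(patches, w_flat, precision=HIGHEST)


key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (2, 8, 8, 3))
w = jax.random.normal(key_w, (3, 3, 3, 4))

ours = conv2d_im2col(x, w)
# Reference: the built-in convolution with the NHWC/HWIO layout used by stax.Conv.
ref = lax.conv_general_dilated(
    x, w, window_strides=(1, 1), padding="VALID",
    dimension_numbers=("NHWC", "HWIO", "NHWC"), precision=HIGHEST)

print(ours.shape)  # (2, 6, 6, 4)
print(bool(jnp.allclose(ours, ref, atol=1e-4)))
```

Building the patch array costs roughly KH*KW times the input's memory, which is fine for small kernels but worth keeping in mind for large images. Newer JAX versions also provide jax.lax.conv_general_dilated_patches, which extracts patches in a single call; check its documentation for how the flattened patch axis is ordered before pairing it with reshaped weights.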
Problem already solved. Sorry for my stupid question again, and many thanks to @mattjj!!! :)