Skip to content

Commit b2fdb9a

Browse files
committed
docs: improve Nx.conv docs on convolution vs correlation
1 parent cb7fed4 commit b2fdb9a

File tree

1 file changed

+26
-9
lines changed

1 file changed

+26
-9
lines changed

nx/lib/nx.ex

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13007,24 +13007,41 @@ defmodule Nx do
1300713007
1300813008
> #### Convolution vs Correlation {: .tip}
1300913009
>
13010-
> `conv/3` does not perform reversion nor conjugation of the kernel.
13010+
> `conv/3` does not perform reversion of the kernel.
1301113011
> This means that if you come from a Signal Processing background,
13012-
> you might call this operation "correlation" instead of convolution.
13012+
> you might treat it as a cross-correlation operation instead of a convolution.
1301313013
>
13014-
> If you need the proper Signal Processing convolution, you can use
13015-
> `reverse/2` and `conjugate/1`, like in the example:
13014+
> This function is not exactly a cross-correlation function, as it does not
13015+
> perform conjugation of the kernel, as is done in [scipy.signal.correlate](https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.correlate.html).
13016+
> This can be remedied via `Nx.conjugate/1`, as seen below:
1301613017
>
1301713018
> ```elixir
13018-
> axes = Nx.axes(kernel) |> Enum.drop(2)
13019-
>
1302013019
> kernel =
1302113020
> if Nx.Type.complex?(Nx.type(kernel)) do
13022-
> Nx.conjugate(Nx.reverse(kernel, axes: axes))
13021+
> Nx.conjugate(kernel)
1302313022
> else
13024-
> Nx.reverse(kernel, axes: axes)
13023+
> kernel
13024+
> end
13025+
>
13026+
> Nx.conv(tensor, kernel)
13027+
> ```
13028+
>
13029+
> If you need the proper Signal Processing convolution, such as the one in
13030+
> [scipy.signal.convolve](https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.convolve.html),
13031+
> you can use `reverse/2`, like in the example:
13032+
>
13033+
> ```elixir
13034+
> reversal_axes =
13035+
> case Nx.rank(kernel) do
13036+
> 0 -> []
13037+
> 1 -> [1]
13038+
> 2 -> [0, 1]
13039+
> _ -> Enum.drop(Nx.axes(kernel), 2)
1302513040
> end
1302613041
>
13027-
> Nx.conv(img, kernel)
13042+
> kernel = Nx.reverse(kernel, axes: reversal_axes)
13043+
>
13044+
> Nx.conv(tensor, kernel)
1302813045
> ```
1302913046
1303013047
## Examples

0 commit comments

Comments
 (0)