Skip to content

Commit cb7fed4

Browse files
docs: improve docs for Nx.conv/3 (#1564)
Co-authored-by: José Valim <[email protected]>
1 parent ed7a3b1 commit cb7fed4

File tree

3 files changed

+32
-6
lines changed

3 files changed

+32
-6
lines changed

nx/guides/advanced/aggregation.livemd

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ m = ~MAT[
9393
>
9494
```
9595

96-
First, we'll compute the full-tensor aggregation. The calculations are developed below. We calculate an "array product" (aka [Hadamard product](<https://en.wikipedia.org/wiki/Hadamard_product_(matrices)#:~:text=In%20mathematics%2C%20the%20Hadamard%20product,elements%20i%2C%20j%20of%20the>), an element-wise product) of our tensor with the tensor of weights, then sum all the elements and divide by the sum of the weights.
96+
First, we'll compute the full-tensor aggregation. The calculations are developed below. We calculate an "array product" (aka [Hadamard product](https://en.wikipedia.org/wiki/Hadamard_product_(matrices)#:~:text=In%20mathematics%2C%20the%20Hadamard%20product,elements%20i%2C%20j%20of%20the), an element-wise product) of our tensor with the tensor of weights, then sum all the elements and divide by the sum of the weights.
9797

9898
```elixir
9999
w = ~MAT[

nx/lib/nx.ex

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12919,6 +12919,11 @@ defmodule Nx do
1291912919
of summing the element-wise products in the window across
1292012920
each input channel.
1292112921
12922+
> #### Kernel Reflection {: .info}
12923+
>
12924+
> See the note at the end of this section for more details
12925+
> on the convention for kernel reflection and conjugation.
12926+
1292212927
The ranks of both `input` and `kernel` must match. By
1292312928
default, both `input` and `kernel` are expected to have shapes
1292412929
of the following form:
@@ -13000,6 +13005,28 @@ defmodule Nx do
1300013005
in the same way as with `:feature_group_size`, however, the input
1300113006
tensor will be split into groups along the batch dimension.
1300213007
13008+
> #### Convolution vs Correlation {: .tip}
13009+
>
13010+
> `conv/3` does not perform reversion nor conjugation of the kernel.
13011+
> This means that if you come from a Signal Processing background,
13012+
> you might call this operation "correlation" instead of convolution.
13013+
>
13014+
> If you need the proper Signal Processing convolution, you can use
13015+
> `reverse/2` and `conjugate/1`, like in the example:
13016+
>
13017+
> ```elixir
13018+
> axes = Nx.axes(kernel) |> Enum.drop(2)
13019+
>
13020+
> kernel =
13021+
> if Nx.Type.complex?(Nx.type(kernel)) do
13022+
> Nx.conjugate(Nx.reverse(kernel, axes: axes))
13023+
> else
13024+
> Nx.reverse(kernel, axes: axes)
13025+
> end
13026+
>
13027+
> Nx.conv(img, kernel)
13028+
> ```
13029+
1300313030
## Examples
1300413031
1300513032
iex> left = Nx.iota({1, 1, 3, 3})

nx/lib/nx/defn/kernel.ex

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1398,8 +1398,8 @@ defmodule Nx.Defn.Kernel do
13981398
## Named hooks
13991399
14001400
It is possible to give names to the hooks. This allows them
1401-
to be defined or overridden by calling `Nx.Defn.jit/2` or
1402-
`Nx.Defn.stream/2`. Let's see an example:
1401+
to be defined or overridden by calling `Nx.Defn.jit/2`.
1402+
Let's see an example:
14031403
14041404
defmodule Hooks do
14051405
import Nx.Defn
@@ -1437,9 +1437,8 @@ defmodule Nx.Defn.Kernel do
14371437
{add, mult}
14381438
end
14391439
1440-
If a hook with the same name is given to `Nx.Defn.jit/2`
1441-
or `Nx.Defn.stream/2`, then it will override the default
1442-
callback.
1440+
If a hook with the same name is given to `Nx.Defn.jit/2`,
1441+
then it will override the default callback.
14431442
14441443
## Hooks and tokens
14451444

0 commit comments

Comments
 (0)