# Broadcasting

The dimensions of tensors in an operator don't always match.
For example, you might want to subtract `1` from every
element of a `{2, 2}`-shaped tensor, like this:

$$
\begin{bmatrix}
  1 & 2 \\\\
  3 & 4
\end{bmatrix} - 1 =
\begin{bmatrix}
  0 & 1 \\\\
  2 & 3
\end{bmatrix}
$$

Mathematically, this is the same as:

$$
\begin{bmatrix}
  1 & 2 \\\\
  3 & 4
\end{bmatrix} -
\begin{bmatrix}
  1 & 1 \\\\
  1 & 1
\end{bmatrix} =
\begin{bmatrix}
  0 & 1 \\\\
  2 & 3
\end{bmatrix}
$$
| 34 | + |
This means we need a way to convert `1` to a `{2, 2}`-shaped tensor.
`Nx.broadcast/2` solves that problem. This function takes
a tensor (or a scalar) and a target shape:
| 38 | + |
```elixir
Mix.install([
  {:nx, "~> 0.9"}
])

Nx.broadcast(1, {2, 2})
```
| 47 | + |
This call takes the scalar `1` and translates it
to a compatible shape by copying it. Sometimes, it's easier
to provide a tensor as the second argument, and let `broadcast/2`
extract its shape:

```elixir
tensor = Nx.tensor([[1, 2], [3, 4]])
Nx.broadcast(1, tensor)
```

The code broadcasts `1` to the shape of `tensor`. In many operators
and functions, the broadcast happens automatically:

```elixir
Nx.subtract(tensor, 1)
```

This result is possible because Nx broadcasts _both tensors_
in `subtract/2` to compatible shapes. That means you can provide
scalar values as either argument:

```elixir
Nx.subtract(10, tensor)
```
| 72 | + |
You can also subtract a row or a column. Mathematically, subtracting a row looks like this:
| 74 | + |
$$
\begin{bmatrix}
  1 & 2 \\\\
  3 & 4
\end{bmatrix} -
\begin{bmatrix}
  1 & 2
\end{bmatrix} =
\begin{bmatrix}
  0 & 0 \\\\
  2 & 2
\end{bmatrix}
$$

which is the same as this:

$$
\begin{bmatrix}
  1 & 2 \\\\
  3 & 4
\end{bmatrix} -
\begin{bmatrix}
  1 & 2 \\\\
  1 & 2
\end{bmatrix} =
\begin{bmatrix}
  0 & 0 \\\\
  2 & 2
\end{bmatrix}
$$

This rewrite happens in Nx as well, through a broadcast operation. We want to
broadcast the tensor `[1, 2]` to match the `{2, 2}` shape:

```elixir
Nx.broadcast(Nx.tensor([1, 2]), {2, 2})
```

The `subtract` function in `Nx` takes care of that broadcast
implicitly, as discussed above:

```elixir
Nx.subtract(tensor, Nx.tensor([1, 2]))
```
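The implicit broadcast covers the column case mentioned earlier as well. As a sketch, a `{2, 1}`-shaped tensor is stretched across the columns before the subtraction:

```elixir
tensor = Nx.tensor([[1, 2], [3, 4]])

# The {2, 1} column is copied across to match the {2, 2} shape,
# i.e. [[1], [2]] becomes [[1, 1], [2, 2]] before subtracting.
Nx.subtract(tensor, Nx.tensor([[1], [2]]))
# => [[0, 1], [1, 2]]
```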
| 119 | + |
The broadcast worked as expected, copying the `[1, 2]` row
enough times to fill a `{2, 2}`-shaped tensor. More generally, any
dimension of size `1` can be stretched to fill the broadcast shape:
| 123 | + |
```elixir
[[1], [2]] |> Nx.tensor() |> Nx.broadcast({1, 2, 2})
```

```elixir
[[[1, 2, 3]]]
|> Nx.tensor()
|> Nx.broadcast({4, 2, 3})
```

Both of these examples copy parts of the tensor enough
times to fill out the broadcast shape. You can check out the
Nx broadcasting documentation for more details:

<!-- livebook:{"disable_formatting":true} -->

```elixir
h Nx.broadcast
```
| 143 | + |
Much of the time, you won't have to broadcast yourself: many of
the tensor-aware functions and operators Nx supports broadcast
their arguments implicitly.
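For instance, `Nx.add/2` and `Nx.multiply/2` broadcast their arguments the same way `Nx.subtract/2` does:

```elixir
tensor = Nx.tensor([[1, 2], [3, 4]])

# The {2}-shaped row [10, 20] is broadcast to {2, 2},
# multiplying each row of the tensor element-wise.
Nx.multiply(tensor, Nx.tensor([10, 20]))
# => [[10, 40], [30, 80]]
```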
| 149 | + |
Throughout this section, we have been invoking `Nx.subtract/2`, and
our code would be more expressive if we could use its equivalent
mathematical operator. Fortunately, Nx provides a way. Next, we'll
dive into numerical definitions using `defn`.