Skip to content

Commit e58062b

Browse files
committed
docs: getting started section
1 parent 7335990 commit e58062b

File tree

3 files changed

+276
-0
lines changed

3 files changed

+276
-0
lines changed
Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
# Broadcasts
2+
3+
Often, the dimensions of tensors in an operator don't match.
4+
For example, you might want to subtract a `1` from every
5+
element of a `{2, 2}` tensor, like this:
6+
7+
$$
8+
\begin{bmatrix}
9+
1 & 2 \\\\
10+
3 & 4
11+
\end{bmatrix} - 1 =
12+
\begin{bmatrix}
13+
0 & 1 \\\\
14+
2 & 3
15+
\end{bmatrix}
16+
$$
17+
18+
Mathematically, it's the same as this:
19+
20+
$$
21+
\begin{bmatrix}
22+
1 & 2 \\\\
23+
3 & 4
24+
\end{bmatrix} -
25+
\begin{bmatrix}
26+
1 & 1 \\\\
27+
1 & 1
28+
\end{bmatrix} =
29+
\begin{bmatrix}
30+
0 & 1 \\\\
31+
2 & 3
32+
\end{bmatrix}
33+
$$
34+
35+
That means we need a way to convert `1` to a `{2, 2}` tensor.
36+
`Nx.broadcast/2` solves that problem. This function takes
37+
a tensor or a scalar and a shape.
38+
39+
```elixir
40+
Mix.install([
41+
{:nx, "~> 0.5"}
42+
])
43+
44+
45+
Nx.broadcast(1, {2, 2})
46+
```
47+
48+
This broadcast takes the scalar `1` and translates it
49+
to a compatible shape by copying it. Sometimes, it's easier
50+
to provide a tensor as the second argument, and let `broadcast/2`
51+
extract its shape:
52+
53+
```elixir
54+
tensor = Nx.tensor([[1, 2], [3, 4]])
55+
Nx.broadcast(1, tensor)
56+
```
57+
58+
The code broadcasts `1` to the shape of `tensor`. In many operators
59+
and functions, the broadcast happens automatically:
60+
61+
```elixir
62+
Nx.subtract(tensor, 1)
63+
```
64+
65+
This result is possible because Nx broadcasts _both tensors_
66+
in `subtract/2` to compatible shapes. That means you can provide
67+
scalar values as either argument:
68+
69+
```elixir
70+
Nx.subtract(10, tensor)
71+
```
72+
73+
Or subtract a row or column. Mathematically, it would look like this:
74+
75+
$$
76+
\begin{bmatrix}
77+
1 & 2 \\\\
78+
3 & 4
79+
\end{bmatrix} -
80+
\begin{bmatrix}
81+
1 & 2
82+
\end{bmatrix} =
83+
\begin{bmatrix}
84+
0 & 0 \\\\
85+
2 & 2
86+
\end{bmatrix}
87+
$$
88+
89+
which is the same as this:
90+
91+
$$
92+
\begin{bmatrix}
93+
1 & 2 \\\\
94+
3 & 4
95+
\end{bmatrix} -
96+
\begin{bmatrix}
97+
1 & 2 \\\\
98+
1 & 2
99+
\end{bmatrix} =
100+
\begin{bmatrix}
101+
0 & 0 \\\\
102+
2 & 2
103+
\end{bmatrix}
104+
$$
105+
106+
This rewrite happens in Nx too, also through a broadcast. We want to
107+
broadcast the tensor `[1, 2]` to match the `{2, 2}` shape, like this:
108+
109+
```elixir
110+
Nx.broadcast(Nx.tensor([1, 2]), {2, 2})
111+
```
112+
113+
The `subtract` function in `Nx` takes care of that broadcast
114+
implicitly, as before:
115+
116+
```elixir
117+
Nx.subtract(tensor, Nx.tensor([1, 2]))
118+
```
119+
120+
The broadcast worked as advertised, copying the `[1, 2]` row
121+
enough times to fill a `{2, 2}` tensor. A tensor with a
122+
dimension of `1` will broadcast to fill the tensor:
123+
124+
```elixir
125+
[[1], [2]] |> Nx.tensor() |> Nx.broadcast({1, 2, 2})
126+
```
127+
128+
```elixir
129+
[[[1, 2, 3]]]
130+
|> Nx.tensor()
131+
|> Nx.broadcast({4, 2, 3})
132+
```
133+
134+
Both of these examples copy parts of the tensor enough
135+
times to fill out the broadcast shape. You can check out the
136+
Nx broadcasting documentation for more details:
137+
138+
<!-- livebook:{"disable_formatting":true} -->
139+
140+
```elixir
141+
h Nx.broadcast
142+
```
143+
144+
Much of the time, you won't have to broadcast yourself. Many of
145+
the functions and operators Nx supports will do so automatically.
146+
147+
We can use tensor-aware operators via various `Nx` functions and
148+
many of them implicitly broadcast tensors.
149+
150+
Throughout this section, we have been invoking `Nx.subtract/2` and
151+
our code would be more expressive if we could use its equivalent
152+
mathematical operator. Fortunately, Nx provides a way. Next, we'll
153+
dive into numerical definitions using `defn`.
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
# Numerical definitions (defn)
2+
3+
The `defn` macro simplifies the expression of mathematical formulas
4+
containing tensors. Numerical definitions have two primary benefits
5+
over classic Elixir functions.
6+
7+
- They are _tensor-aware_. Nx replaces operators like `Kernel.-/2`
8+
with the `Defn` counterparts &mdash; which in turn use `Nx` functions
9+
optimized for tensors &mdash; so the formulas we express can use
10+
tensors out of the box.
11+
12+
- `defn` definitions allow for building computation graph of all the
13+
individual operations and using a just-in-time (JIT) compiler to emit
14+
highly specialized native code for the desired computation unit.
15+
16+
We don't have to do anything special to get access to
17+
get tensor awareness beyond importing `Nx.Defn` and writing
18+
our code within a `defn` block.
19+
20+
To use Nx in a Mix project or a notebook, we need to include
21+
the `:nx` dependency and import the `Nx.Defn` module,
22+
like this:
23+
24+
```elixir
25+
Mix.install([
26+
{:nx, "~> 0.5"}
27+
])
28+
```
29+
30+
```elixir
31+
import Nx.Defn
32+
```
33+
34+
Just as the Elixir language supports `def`, `defmacro`, and `defp`,
35+
Nx supports `defn`. There are a few restrictions. It allows only
36+
numerical arguments in the form of primitives or tensors as arguments
37+
or return values, and supports only a subset of the language.
38+
39+
The subset of Elixir allowed within `defn` is quite broad, though. We can
40+
use macros, pipes, and even conditionals, so we're not giving up
41+
much when you're declaring mathematical functions.
42+
43+
Additionally, despite these small concessions, `defn` provides huge benefits.
44+
Code in a `defn` block uses tensor aware operators and types, so the math
45+
beneath your functions has a better chance to shine through. Numerical
46+
definitions can also run on accelerated numerical processors like GPUs and
47+
TPUs. Here's an example numerical definition:
48+
49+
```elixir
50+
defmodule TensorMath do
51+
import Nx.Defn
52+
53+
defn subtract(a, b) do
54+
a - b
55+
end
56+
end
57+
```
58+
59+
This module has a numerical definition that will be compiled.
60+
If we wanted to specify a compiler for this module, we could add
61+
a module attribute before the `defn` clause. One of such compilers
62+
is [the EXLA compiler](https://github.com/elixir-nx/nx/tree/main/exla).
63+
You'd add the `mix` dependency for EXLA and do this:
64+
65+
<!-- livebook:{"force_markdown":true} -->
66+
67+
```elixir
68+
@defn_compiler EXLA
69+
defn subtract(a, b) do
70+
a - b
71+
end
72+
```
73+
74+
Now, it's your turn. Add a `defn` to `TensorMath`
75+
that accepts two tensors representing the lengths of sides of a
76+
right triangle and uses the pythagorean theorem to return the
77+
[length of the hypotenuse](https://www.mathsisfun.com/pythagoras.html).
78+
Add your function directly to the previous Code cell.
79+
80+
## deftransform
81+
82+
The defn macro in Nx allows you to define functions that compile to efficient
83+
numerical computations, but it comes with certain limitations—such as restrictions
84+
on argument types, return values, and the subset of Elixir that it supports.
85+
To overcome many of these limitations, Nx offers the deftransform macro.
86+
87+
deftransform lets you perform computations or execute code that isn't directly
88+
supported by defn, and then incorporate those results back into your numerical
89+
function. This separation lets you use standard Elixir features where necessary
90+
while keeping your core numerical logic optimized.
91+
92+
In the following example, we define a deftransform function called
93+
compute_tensor_from_list/1 that receives a list, which is not allowed
94+
inside defn. Inside this transform function, we convert the list to a tensor
95+
using Nx.tensor/1, and then pass it to a defn function called double_tensor/1,
96+
which performs the actual numerical computation.
97+
98+
```elixir
99+
defmodule MyMath do
100+
import Nx.Defn
101+
102+
defn double_tensor(tensor) do
103+
tensor * 2
104+
end
105+
106+
deftransform compute_tensor_from_list(list) do
107+
tensor = Nx.tensor(list)
108+
double_tensor(tensor)
109+
end
110+
end
111+
112+
```
113+
114+
```elixir
115+
input = [1, 2, 3, 4]
116+
result = MyMath.compute_tensor_from_list(input)
117+
```
118+
119+
This setup allows us to keep our defn code clean and focused only on tensor
120+
operations, while using deftransform to handle Elixir-native types and
121+
preprocessing.

nx/mix.exs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ defmodule Nx.MixProject do
6161
"guides/getting_started/introduction.md",
6262
"guides/getting_started/installation.md",
6363
"guides/getting_started/quickstart.livemd",
64+
"guides/getting_started/broadcast.livemd",
65+
"guides/getting_started/numerical_definitions.livemd",
6466
"guides/advanced/vectorization.livemd",
6567
"guides/advanced/aggregation.livemd",
6668
"guides/advanced/automatic_differentiation.livemd",

0 commit comments

Comments
 (0)