Skip to content

Commit 7905414

Browse files
committed
docs: improving deftransform example
1 parent 2129801 commit 7905414

File tree

1 file changed

+30
-9
lines changed

1 file changed

+30
-9
lines changed

nx/guides/getting_started/numerical_definitions.livemd

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
11
# Numerical definitions (defn)
22

3+
## Section
4+
35
The `defn` macro simplifies the expression of mathematical formulas
46
containing tensors. Numerical definitions have two primary benefits
57
over classic Elixir functions.
68

7-
- They are _tensor-aware_. Nx replaces operators like `Kernel.-/2`
9+
* They are _tensor-aware_. Nx replaces operators like `Kernel.-/2`
810
with the `Defn` counterparts — which in turn use `Nx` functions
911
optimized for tensors — so the formulas we express can use
1012
tensors out of the box.
1113

12-
- `defn` definitions allow for building computation graph of all the
14+
* `defn` definitions allow for building computation graph of all the
1315
individual operations and using a just-in-time (JIT) compiler to emit
1416
highly specialized native code for the desired computation unit.
1517

@@ -99,21 +101,40 @@ which performs the actual numerical computation.
99101
defmodule MyMath do
100102
import Nx.Defn
101103

102-
defn double_tensor(tensor) do
103-
tensor * 2
104+
# Numerical function that just multiplies the tensor by a scalar
105+
defn scale_tensor(tensor) do
106+
Nx.multiply(tensor, 10)
104107
end
105108

106-
deftransform compute_tensor_from_list(list) do
107-
tensor = Nx.tensor(list)
108-
double_tensor(tensor)
109+
# This transform receives a 2D list, validates it, reshapes it,
110+
# adds a new axis, and then passes it to a numerical function.
111+
deftransform compute_from_2d_list(list_2d) do
112+
# Validate that it's a proper matrix (all rows same length)
113+
lengths = Enum.map(list_2d, &length/1)
114+
if Enum.uniq(lengths) != [hd(lengths)] do
115+
raise ArgumentError, "All inner lists must have the same length"
116+
end
117+
118+
# Convert to tensor (e.g., shape {2, 3})
119+
tensor = Nx.tensor(list_2d)
120+
121+
# Add a new axis at the beginning: {2, 3} -> {1, 2, 3}
122+
reshaped = Nx.new_axis(tensor, 0)
123+
124+
# Pass to defn function
125+
scale_tensor(reshaped)
109126
end
110127
end
111128

112129
```
113130

114131
```elixir
115-
input = [1, 2, 3, 4]
116-
result = MyMath.compute_tensor_from_list(input)
132+
matrix = [
133+
[1, 2, 3],
134+
[4, 5, 6]
135+
]
136+
137+
result = MyMath.compute_from_2d_list(matrix)
117138
```
118139

119140
This setup allows us to keep our defn code clean and focused only on tensor

0 commit comments

Comments
 (0)