|
1 | 1 | # Numerical definitions (defn) |
2 | 2 |
|
| 3 | +## Section |
| 4 | + |
3 | 5 | The `defn` macro simplifies the expression of mathematical formulas |
4 | 6 | containing tensors. Numerical definitions have two primary benefits |
5 | 7 | over classic Elixir functions. |
6 | 8 |
|
7 | | -- They are _tensor-aware_. Nx replaces operators like `Kernel.-/2` |
| 9 | +* They are _tensor-aware_. Nx replaces operators like `Kernel.-/2` |
8 | 10 | with the `Defn` counterparts — which in turn use `Nx` functions |
9 | 11 | optimized for tensors — so the formulas we express can use |
10 | 12 | tensors out of the box. |
11 | 13 |
|
12 | | -- `defn` definitions allow for building computation graph of all the |
| 14 | +* `defn` definitions allow for building computation graph of all the |
13 | 15 | individual operations and using a just-in-time (JIT) compiler to emit |
14 | 16 | highly specialized native code for the desired computation unit. |
15 | 17 |
|
@@ -99,21 +101,40 @@ which performs the actual numerical computation. |
99 | 101 | defmodule MyMath do |
100 | 102 | import Nx.Defn |
101 | 103 |
|
102 | | - defn double_tensor(tensor) do |
103 | | - tensor * 2 |
| 104 | + # Numerical function that just multiplies the tensor by a scalar |
| 105 | + defn scale_tensor(tensor) do |
| 106 | + Nx.multiply(tensor, 10) |
104 | 107 | end |
105 | 108 |
|
106 | | - deftransform compute_tensor_from_list(list) do |
107 | | - tensor = Nx.tensor(list) |
108 | | - double_tensor(tensor) |
| 109 | + # This transform receives a 2D list, validates it, reshapes it, |
| 110 | + # adds a new axis, and then passes it to a numerical function. |
| 111 | + deftransform compute_from_2d_list(list_2d) do |
| 112 | + # Validate that it's a proper matrix (all rows same length) |
| 113 | + lengths = Enum.map(list_2d, &length/1) |
| 114 | + if Enum.uniq(lengths) != [hd(lengths)] do |
| 115 | + raise ArgumentError, "All inner lists must have the same length" |
| 116 | + end |
| 117 | + |
| 118 | + # Convert to tensor (e.g., shape {2, 3}) |
| 119 | + tensor = Nx.tensor(list_2d) |
| 120 | + |
| 121 | + # Add a new axis at the beginning: {2, 3} -> {1, 2, 3} |
| 122 | + reshaped = Nx.new_axis(tensor, 0) |
| 123 | + |
| 124 | + # Pass to defn function |
| 125 | + scale_tensor(reshaped) |
109 | 126 | end |
110 | 127 | end |
111 | 128 |
|
112 | 129 | ``` |
113 | 130 |
|
114 | 131 | ```elixir |
115 | | -input = [1, 2, 3, 4] |
116 | | -result = MyMath.compute_tensor_from_list(input) |
| 132 | +matrix = [ |
| 133 | + [1, 2, 3], |
| 134 | + [4, 5, 6] |
| 135 | +] |
| 136 | + |
| 137 | +result = MyMath.compute_from_2d_list(matrix) |
117 | 138 | ``` |
118 | 139 |
|
119 | 140 | This setup allows us to keep our defn code clean and focused only on tensor |
|
0 commit comments