|
| 1 | +# Numerical definitions (defn) |
| 2 | + |
| 3 | +The `defn` macro simplifies the expression of mathematical formulas |
| 4 | +containing tensors. Numerical definitions have two primary benefits |
| 5 | +over classic Elixir functions. |
| 6 | + |
| 7 | +- They are _tensor-aware_. Nx replaces operators like `Kernel.-/2` |
| 8 | + with the `Defn` counterparts — which in turn use `Nx` functions |
| 9 | + optimized for tensors — so the formulas we express can use |
| 10 | + tensors out of the box. |
| 11 | + |
| 12 | +- `defn` definitions allow for building computation graph of all the |
| 13 | + individual operations and using a just-in-time (JIT) compiler to emit |
| 14 | + highly specialized native code for the desired computation unit. |
| 15 | + |
| 16 | +We don't have to do anything special to get access to |
| 17 | +get tensor awareness beyond importing `Nx.Defn` and writing |
| 18 | +our code within a `defn` block. |
| 19 | + |
| 20 | +To use Nx in a Mix project or a notebook, we need to include |
| 21 | +the `:nx` dependency and import the `Nx.Defn` module, |
| 22 | +like this: |
| 23 | + |
| 24 | +```elixir |
| 25 | +Mix.install([ |
| 26 | + {:nx, "~> 0.5"} |
| 27 | +]) |
| 28 | +``` |
| 29 | + |
| 30 | +```elixir |
| 31 | +import Nx.Defn |
| 32 | +``` |
| 33 | + |
| 34 | +Just as the Elixir language supports `def`, `defmacro`, and `defp`, |
| 35 | +Nx supports `defn`. There are a few restrictions. It allows only |
| 36 | +numerical arguments in the form of primitives or tensors as arguments |
| 37 | +or return values, and supports only a subset of the language. |
| 38 | + |
| 39 | +The subset of Elixir allowed within `defn` is quite broad, though. We can |
| 40 | +use macros, pipes, and even conditionals, so we're not giving up |
| 41 | +much when you're declaring mathematical functions. |
| 42 | + |
| 43 | +Additionally, despite these small concessions, `defn` provides huge benefits. |
| 44 | +Code in a `defn` block uses tensor aware operators and types, so the math |
| 45 | +beneath your functions has a better chance to shine through. Numerical |
| 46 | +definitions can also run on accelerated numerical processors like GPUs and |
| 47 | +TPUs. Here's an example numerical definition: |
| 48 | + |
| 49 | +```elixir |
| 50 | +defmodule TensorMath do |
| 51 | + import Nx.Defn |
| 52 | + |
| 53 | + defn subtract(a, b) do |
| 54 | + a - b |
| 55 | + end |
| 56 | +end |
| 57 | +``` |
| 58 | + |
| 59 | +This module has a numerical definition that will be compiled. |
| 60 | +If we wanted to specify a compiler for this module, we could add |
| 61 | +a module attribute before the `defn` clause. One of such compilers |
| 62 | +is [the EXLA compiler](https://github.com/elixir-nx/nx/tree/main/exla). |
| 63 | +You'd add the `mix` dependency for EXLA and do this: |
| 64 | + |
| 65 | +<!-- livebook:{"force_markdown":true} --> |
| 66 | + |
| 67 | +```elixir |
| 68 | +@defn_compiler EXLA |
| 69 | +defn subtract(a, b) do |
| 70 | + a - b |
| 71 | +end |
| 72 | +``` |
| 73 | + |
| 74 | +Now, it's your turn. Add a `defn` to `TensorMath` |
| 75 | +that accepts two tensors representing the lengths of sides of a |
| 76 | +right triangle and uses the pythagorean theorem to return the |
| 77 | +[length of the hypotenuse](https://www.mathsisfun.com/pythagoras.html). |
| 78 | +Add your function directly to the previous Code cell. |
| 79 | + |
| 80 | +## deftransform |
| 81 | + |
| 82 | +The defn macro in Nx allows you to define functions that compile to efficient |
| 83 | +numerical computations, but it comes with certain limitations—such as restrictions |
| 84 | +on argument types, return values, and the subset of Elixir that it supports. |
| 85 | +To overcome many of these limitations, Nx offers the deftransform macro. |
| 86 | + |
| 87 | +deftransform lets you perform computations or execute code that isn't directly |
| 88 | +supported by defn, and then incorporate those results back into your numerical |
| 89 | +function. This separation lets you use standard Elixir features where necessary |
| 90 | +while keeping your core numerical logic optimized. |
| 91 | + |
| 92 | +In the following example, we define a deftransform function called |
| 93 | +compute_tensor_from_list/1 that receives a list, which is not allowed |
| 94 | +inside defn. Inside this transform function, we convert the list to a tensor |
| 95 | +using Nx.tensor/1, and then pass it to a defn function called double_tensor/1, |
| 96 | +which performs the actual numerical computation. |
| 97 | + |
| 98 | +```elixir |
| 99 | +defmodule MyMath do |
| 100 | + import Nx.Defn |
| 101 | + |
| 102 | + defn double_tensor(tensor) do |
| 103 | + tensor * 2 |
| 104 | + end |
| 105 | + |
| 106 | + deftransform compute_tensor_from_list(list) do |
| 107 | + tensor = Nx.tensor(list) |
| 108 | + double_tensor(tensor) |
| 109 | + end |
| 110 | +end |
| 111 | + |
| 112 | +``` |
| 113 | + |
| 114 | +```elixir |
| 115 | +input = [1, 2, 3, 4] |
| 116 | +result = MyMath.compute_tensor_from_list(input) |
| 117 | +``` |
| 118 | + |
| 119 | +This setup allows us to keep our defn code clean and focused only on tensor |
| 120 | +operations, while using deftransform to handle Elixir-native types and |
| 121 | +preprocessing. |
0 commit comments