# Best Practices for Scholar.Optimize Contributions

This document captures key patterns and requirements for implementing optimization algorithms in Scholar, based on review feedback from José Valim (PR #323, #327).

## Core Principle: JIT/GPU Compatibility

**All optimization algorithms must be JIT-compilable and GPU-compatible.**

This means the main entry point must be a `defn` function that can be called with `Nx.Defn.jit_apply/3`.
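
For example, an entry point that follows the patterns below can be JIT-compiled directly (the module name and option values here are illustrative):

```elixir
fun = fn x -> Nx.pow(Nx.subtract(x, 3), 2) end

# Hypothetical module following this document's conventions; any defn
# entry point with this shape can be compiled the same way.
result =
  Nx.Defn.jit_apply(&Scholar.Optimize.Brent.minimize/4, [0.0, 5.0, fun, [tol: 1.0e-6]])
```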

## Required Patterns

### 1. Use `defn` for Entry Points

```elixir
# GOOD - JIT compatible
defn minimize(a, b, fun, opts \\ []) do
  {tol, maxiter} = transform_opts(opts)
  minimize_n(a, b, fun, tol, maxiter)
end

# BAD - NOT JIT compatible
deftransform minimize(fun, opts) do
  # This prevents JIT compilation!
end
```

### 2. Expose Required Parameters as Function Arguments

Don't bury required parameters in options - expose them as explicit arguments:

```elixir
# GOOD - bounds as explicit args
defn minimize(a, b, fun, opts \\ [])

# BAD - bounds buried in options
deftransform minimize(fun, opts) do
  {a, b} = opts[:bracket]
end
```

**Why?** Converting the entry point from `deftransform` to `defn` means Nx checks and converts the input types automatically, eliminating the need for custom validation logic.

### 3. Use `deftransformp` Only for Option Validation

```elixir
deftransformp transform_opts(opts) do
  opts = NimbleOptions.validate!(opts, @opts_schema)
  {opts[:tol], opts[:maxiter]}
end
```

### 4. Use `Nx.select` for Branch-Free Conditionals

Never branch on tensor values with Elixir runtime conditionals (`if`, `cond`, `case`) inside `defn` functions:

```elixir
# GOOD - tensor-based conditional
new_a = Nx.select(condition, value_if_true, value_if_false)

# BAD - Elixir runtime conditional
new_a = if Nx.to_number(condition) == 1, do: value_if_true, else: value_if_false
```

For complex multi-way conditionals, use nested `Nx.select`:

```elixir
# Four-way conditional
result =
  Nx.select(
    cond1,
    value1,
    Nx.select(
      cond2,
      value2,
      Nx.select(cond3, value3, value4)
    )
  )
```

### 5. Use a `while` Loop (Not Recursion)

```elixir
# GOOD - while loop
{final_state, _} =
  while {state = initial_state, {tol, maxiter}},
        state.iter < maxiter and state.b - state.a >= tol do
    # Update state
    new_state = %{state | iter: state.iter + 1, ...}
    {new_state, {tol, maxiter}}
  end

# BAD - recursive call (not JIT-compatible)
defp loop(fun, state, tol, maxiter) do
  if converged?(state) do
    state
  else
    loop(fun, update(state), tol, maxiter)
  end
end
```

### 6. Never Use `Nx.to_number` in `defn`

All computation must stay as tensors for JIT compilation:

```elixir
# GOOD - all tensor operations
converged = state.b - state.a < tol

# BAD - converts to Elixir numbers
converged = Nx.to_number(state.b) - Nx.to_number(state.a) < Nx.to_number(tol)
```

### 7. Use Unsigned Types for Non-Negative Counters

```elixir
initial_state = %{
  iter: Nx.u32(0),   # u32 for iteration count
  f_evals: Nx.u32(2) # u32 for function evaluation count
}
```

### 8. Let Users Control Precision via Input Types

Don't force type conversions - let the tensor type propagate from the inputs:

```elixir
# GOOD - let the user decide precision
defn minimize(a, b, fun, opts \\ []) do
  # the types of a and b propagate through the computation
end

# BAD - forcing f64
a = Nx.tensor(a, type: :f64)
```

Document that users can use f64 tensors for higher precision:

```elixir
@doc """
For higher precision, use f64 tensors for the bounds:

    a = Nx.tensor(0.0, type: :f64)
    b = Nx.tensor(5.0, type: :f64)
    result = Brent.minimize(a, b, fun, tol: 1.0e-10)
"""
```

### 9. Use Module Constants Directly

```elixir
# GOOD - use @attr directly in defn
@phi 0.6180339887498949

defnp minimize_n(a, b, fun, tol, maxiter) do
  c = b - @phi * (b - a)
end

# BAD - wrapping in a tensor
defnp minimize_n(a, b, fun, tol, maxiter) do
  phi = Nx.tensor(@phi)
  c = b - phi * (b - a)
end
```

### 10. Keep Modules Self-Contained

Keep NimbleOptions validation in the same module - don't create wrapper modules:

```elixir
defmodule Scholar.Optimize.Brent do
  opts = [
    tol: [...],
    maxiter: [...]
  ]

  @opts_schema NimbleOptions.new!(opts)

  # Validation happens here, not in a separate module
end
```

## Module Structure Template

```elixir
defmodule Scholar.Optimize.AlgorithmName do
  @moduledoc """
  Description of the algorithm.

  ## Algorithm
  ...

  ## Convergence
  ...

  ## References
  ...
  """

  import Nx.Defn

  @derive {Nx.Container, containers: [:x, :fun, :converged, :iterations, :fun_evals]}
  defstruct [:x, :fun, :converged, :iterations, :fun_evals]

  @type t :: %__MODULE__{
          x: Nx.Tensor.t(),
          fun: Nx.Tensor.t(),
          converged: Nx.Tensor.t(),
          iterations: Nx.Tensor.t(),
          fun_evals: Nx.Tensor.t()
        }

  # Constants
  @some_constant 0.123456789

  # Options schema
  opts = [
    tol: [
      type: {:custom, Scholar.Options, :positive_number, []},
      default: 1.0e-5,
      doc: "..."
    ],
    maxiter: [
      type: :pos_integer,
      default: 500,
      doc: "..."
    ]
  ]

  @opts_schema NimbleOptions.new!(opts)

  @doc """
  Main entry point documentation...
  """
  defn minimize(a, b, fun, opts \\ []) do
    {tol, maxiter} = transform_opts(opts)
    minimize_n(a, b, fun, tol, maxiter)
  end

  deftransformp transform_opts(opts) do
    opts = NimbleOptions.validate!(opts, @opts_schema)
    {opts[:tol], opts[:maxiter]}
  end

  defnp minimize_n(a, b, fun, tol, maxiter) do
    # Implementation using a while loop and Nx.select
  end
end
```

## Test Requirements

Every optimization module must include:

1. **Basic functionality tests** - Verify correct results on standard test functions
2. **Option handling tests** - Test the tolerance and maxiter options
3. **JIT compatibility test** - Critical! Must pass:

```elixir
test "works with jit_apply" do
  fun = fn x -> Nx.pow(Nx.subtract(x, 3), 2) end
  opts = [tol: 1.0e-5, maxiter: 500]

  result = Nx.Defn.jit_apply(&AlgorithmName.minimize/4, [0.0, 5.0, fun, opts])

  assert Nx.to_number(result.converged) == 1
end
```

4. **Tensor bounds test** - Accept both numbers and tensors
5. **Precision test** - Higher precision with f64 bounds

## Validation Against SciPy

When implementing algorithms, validate results against SciPy:

```python
from scipy.optimize import minimize_scalar

result = minimize_scalar(func, bracket=(a, b), method='brent')
print(f"x: {result.x}, f(x): {result.fun}, iterations: {result.nit}")
```

Use these reference values in tests with appropriate tolerance (typically `atol: 1.0e-4` to `1.0e-6`).
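
For example, for the quadratic test function used throughout this document, `f(x) = (x - 3)**2`, SciPy recovers the minimum at `x = 3`:

```python
from scipy.optimize import minimize_scalar

# Reference run for the quadratic test function f(x) = (x - 3)^2
result = minimize_scalar(lambda x: (x - 3) ** 2, bracket=(0.0, 5.0), method="brent")

# The minimizer should sit at x = 3 well within Brent's default tolerance
assert abs(result.x - 3.0) < 1e-6

print(f"x: {result.x:.6f}, f(x): {result.fun:.3e}")
```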

## References

- PR #323: https://github.com/elixir-nx/scholar/pull/323 (original comprehensive optimizer)
- PR #327: https://github.com/elixir-nx/scholar/pull/327 (merged Golden Section)
- SciPy optimize: https://docs.scipy.org/doc/scipy/tutorial/optimize.html