Commit 4c7a699 (parent e369861) - Create agents.md (#330)

1 file changed: agents.md, +289 -0 lines
# Best Practices for Scholar.Optimize Contributions

This document captures key patterns and requirements for implementing optimization algorithms in Scholar, based on review feedback from José Valim (PR #323, PR #327).

## Core Principle: JIT/GPU Compatibility

**All optimization algorithms must be JIT-compilable and GPU-compatible.**

In practice, this means the main entry point must be a `defn` function that can be called with `Nx.Defn.jit_apply/3`.
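As a sketch of what that looks like from the caller's side (using the `Scholar.Optimize.Brent` module named later in this document; the exact option names are assumptions):

```elixir
# Hypothetical usage sketch: minimize f(x) = (x - 3)^2 over [0, 5]
fun = fn x -> Nx.pow(Nx.subtract(x, 3), 2) end

# jit_apply compiles minimize/4 once and runs it on the configured backend
result =
  Nx.Defn.jit_apply(
    &Scholar.Optimize.Brent.minimize/4,
    [Nx.tensor(0.0), Nx.tensor(5.0), fun, [tol: 1.0e-5, maxiter: 500]]
  )
```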
## Required Patterns

### 1. Use `defn` for Entry Points

```elixir
# GOOD - JIT compatible
defn minimize(a, b, fun, opts \\ []) do
  {tol, maxiter} = transform_opts(opts)
  minimize_n(a, b, fun, tol, maxiter)
end

# BAD - NOT JIT compatible
deftransform minimize(fun, opts) do
  # This prevents JIT compilation!
end
```
### 2. Expose Required Parameters as Function Arguments

Don't bury required parameters in options - expose them as explicit arguments:

```elixir
# GOOD - bounds as explicit arguments
defn minimize(a, b, fun, opts \\ [])

# BAD - bounds buried in options
deftransform minimize(fun, opts) do
  {a, b} = opts[:bracket]
end
```

**Why?** Converting the entry point from `deftransform` to `defn` means the input types are checked automatically, eliminating the need for custom validation logic.
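A sketch of what this buys you, assuming the `Scholar.Optimize.Brent.minimize/4` entry point described in this document: because `defn` lifts numeric arguments into tensors, callers can pass either plain numbers or tensors without any hand-written validation.

```elixir
fun = fn x -> Nx.pow(Nx.subtract(x, 3), 2) end

# Both calls work - defn lifts plain numbers into tensors automatically
Scholar.Optimize.Brent.minimize(0.0, 5.0, fun)
Scholar.Optimize.Brent.minimize(Nx.tensor(0.0), Nx.tensor(5.0), fun)
```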
### 3. Use `deftransformp` Only for Option Validation

```elixir
deftransformp transform_opts(opts) do
  opts = NimbleOptions.validate!(opts, @opts_schema)
  {opts[:tol], opts[:maxiter]}
end
```
### 4. Use `Nx.select` for Branch-Free Conditionals

Never use Elixir runtime conditionals (`if`, `cond`, `case`) inside `defn` functions:

```elixir
# GOOD - tensor-based conditional
new_a = Nx.select(condition, value_if_true, value_if_false)

# BAD - Elixir runtime conditional
new_a = if Nx.to_number(condition) == 1, do: value_if_true, else: value_if_false
```

For complex multi-way conditionals, nest `Nx.select`:

```elixir
# Four-way conditional
result =
  Nx.select(
    cond1,
    value1,
    Nx.select(
      cond2,
      value2,
      Nx.select(cond3, value3, value4)
    )
  )
```
### 5. Use a `while` Loop (Not Recursion)

```elixir
# GOOD - while loop
{final_state, _} =
  while {state = initial_state, {tol, maxiter}},
        state.iter < maxiter and state.b - state.a >= tol do
    # Update state
    new_state = %{state | iter: state.iter + 1, ...}
    {new_state, {tol, maxiter}}
  end

# BAD - recursive call (not JIT-compatible)
defp loop(fun, state, tol, maxiter) do
  if converged?(state) do
    state
  else
    loop(fun, update(state), tol, maxiter)
  end
end
```
### 6. Never Use `Nx.to_number` in `defn`

All computation must stay in tensors for JIT compilation:

```elixir
# GOOD - all tensor operations
converged = state.b - state.a < tol

# BAD - converts to Elixir numbers, breaking JIT compilation
converged = Nx.to_number(state.b) - Nx.to_number(state.a) < Nx.to_number(tol)
```
### 7. Use Unsigned Types for Non-Negative Counters

```elixir
initial_state = %{
  # u32 for iteration count
  iter: Nx.u32(0),
  # u32 for function evaluation count
  f_evals: Nx.u32(2)
}
```
### 8. Let Users Control Precision via Input Types

Don't force type conversions - let the tensor type propagate from the inputs:

```elixir
# GOOD - let the user decide precision
defn minimize(a, b, fun, opts \\ []) do
  # the types of a and b propagate through the computation
end

# BAD - forcing f64
a = Nx.tensor(a, type: :f64)
```

Document that users can pass f64 tensors for higher precision:

```elixir
@doc """
For higher precision, use f64 tensors for the bounds:

    a = Nx.tensor(0.0, type: :f64)
    b = Nx.tensor(5.0, type: :f64)
    result = Brent.minimize(a, b, fun, tol: 1.0e-10)
"""
```
### 9. Use Module Constants Directly

```elixir
# GOOD - use @attr directly in defn
@phi 0.6180339887498949

defnp minimize_n(a, b, fun, tol, maxiter) do
  c = b - @phi * (b - a)
end

# BAD - wrapping the constant in a tensor
defnp minimize_n(a, b, fun, tol, maxiter) do
  phi = Nx.tensor(@phi)
  c = b - phi * (b - a)
end
```
### 10. Keep Modules Self-Contained

Keep the NimbleOptions validation in the same module - don't create wrapper modules:

```elixir
defmodule Scholar.Optimize.Brent do
  opts = [
    tol: [...],
    maxiter: [...]
  ]

  @opts_schema NimbleOptions.new!(opts)

  # Validation happens here, not in a separate module
end
```
## Module Structure Template

```elixir
defmodule Scholar.Optimize.AlgorithmName do
  @moduledoc """
  Description of the algorithm.

  ## Algorithm
  ...

  ## Convergence
  ...

  ## References
  ...
  """

  import Nx.Defn

  @derive {Nx.Container, containers: [:x, :fun, :converged, :iterations, :fun_evals]}
  defstruct [:x, :fun, :converged, :iterations, :fun_evals]

  @type t :: %__MODULE__{
          x: Nx.Tensor.t(),
          fun: Nx.Tensor.t(),
          converged: Nx.Tensor.t(),
          iterations: Nx.Tensor.t(),
          fun_evals: Nx.Tensor.t()
        }

  # Constants
  @some_constant 0.123456789

  # Options schema
  opts = [
    tol: [
      type: {:custom, Scholar.Options, :positive_number, []},
      default: 1.0e-5,
      doc: "..."
    ],
    maxiter: [
      type: :pos_integer,
      default: 500,
      doc: "..."
    ]
  ]

  @opts_schema NimbleOptions.new!(opts)

  @doc """
  Main entry point documentation...
  """
  defn minimize(a, b, fun, opts \\ []) do
    {tol, maxiter} = transform_opts(opts)
    minimize_n(a, b, fun, tol, maxiter)
  end

  deftransformp transform_opts(opts) do
    opts = NimbleOptions.validate!(opts, @opts_schema)
    {opts[:tol], opts[:maxiter]}
  end

  defnp minimize_n(a, b, fun, tol, maxiter) do
    # Implementation using a while loop and Nx.select
  end
end
```
## Test Requirements

Every optimization module must include:

1. **Basic functionality tests** - verify correct results on standard test functions
2. **Option handling tests** - exercise the `tol` and `maxiter` options
3. **JIT compatibility test** - critical! Must pass:

   ```elixir
   test "works with jit_apply" do
     fun = fn x -> Nx.pow(Nx.subtract(x, 3), 2) end
     opts = [tol: 1.0e-5, maxiter: 500]

     result = Nx.Defn.jit_apply(&AlgorithmName.minimize/4, [0.0, 5.0, fun, opts])

     assert Nx.to_number(result.converged) == 1
   end
   ```

4. **Tensor bounds test** - accept both numbers and tensors as bounds
5. **Precision test** - higher precision with f64 bounds
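Requirements 4 and 5 might look like the following sketch. `AlgorithmName` is a placeholder, and `assert_all_close` is assumed to be available from the project's test case helpers:

```elixir
test "accepts both number and tensor bounds" do
  fun = fn x -> Nx.pow(Nx.subtract(x, 3), 2) end

  r1 = AlgorithmName.minimize(0.0, 5.0, fun)
  r2 = AlgorithmName.minimize(Nx.tensor(0.0), Nx.tensor(5.0), fun)

  # Both call styles must converge to the same minimizer
  assert_all_close(r1.x, r2.x)
end

test "f64 bounds give higher precision" do
  fun = fn x -> Nx.pow(Nx.subtract(x, 3), 2) end

  a = Nx.tensor(0.0, type: :f64)
  b = Nx.tensor(5.0, type: :f64)

  result = AlgorithmName.minimize(a, b, fun, tol: 1.0e-10)

  # The input type propagates to the result (see pattern 8)
  assert Nx.type(result.x) == {:f, 64}
end
```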
## Validation Against SciPy

When implementing algorithms, validate results against SciPy:

```python
from scipy.optimize import minimize_scalar

result = minimize_scalar(func, bracket=(a, b), method='brent')
print(f"x: {result.x}, f(x): {result.fun}, iterations: {result.nit}")
```

Use these reference values in tests with an appropriate tolerance (typically `atol: 1.0e-4` to `1.0e-6`).
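For example, the minimizer of f(x) = (x - 2)^2 is exactly x = 2, so a SciPy reference value can be asserted like this (a sketch; `AlgorithmName` is a placeholder and `assert_all_close` is assumed to come from the test helpers):

```elixir
test "matches the SciPy reference value" do
  fun = fn x -> Nx.pow(Nx.subtract(x, 2), 2) end

  result = AlgorithmName.minimize(0.0, 5.0, fun, tol: 1.0e-6)

  # Reference: scipy.optimize.minimize_scalar gives x = 2.0
  assert_all_close(result.x, Nx.tensor(2.0), atol: 1.0e-4)
end
```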
## References

- PR #323: https://github.com/elixir-nx/scholar/pull/323 (original comprehensive optimizer)
- PR #327: https://github.com/elixir-nx/scholar/pull/327 (merged Golden Section)
- SciPy optimize tutorial: https://docs.scipy.org/doc/scipy/tutorial/optimize.html
