
Commit cc4258c

feat(examples): add linear regression (#199)
1 parent 9d196e5 commit cc4258c

File tree

2 files changed: +110 −0 lines changed


mlx-rs/README.md

Lines changed: 51 additions & 0 deletions
@@ -66,6 +66,57 @@ mlx-rs = "0.21.0"
* `metal` - enables metal (GPU) usage in MLX
* `accelerate` - enables using the accelerate framework in MLX

## Important Notes on Automatic Differentiation

When using automatic differentiation in mlx-rs, there is an important difference in how closures work compared to Python's MLX. In Python, variables are implicitly captured and properly traced in the compute graph. In Rust, however, you must be explicit about which arrays should be traced.

❌ This approach may cause segfaults:

```rust
// Don't do this
let x = random::normal::<f32>(&[num_examples, num_features], None, None, None)?;
let y = x.matmul(&w_star)? + eps;

let loss_fn = |w: &Array| -> Result<Array, Exception> {
    let y_pred = x.matmul(w)?; // x and y are captured from the outer scope
    let loss = Array::from_float(0.5) * ops::mean(&ops::square(&(y_pred - &y))?, None, None)?;
    Ok(loss)
};

let grad_fn = transforms::grad(loss_fn, &[0]);
```

✅ Instead, pass all required arrays as inputs to ensure proper tracing:

```rust
let loss_fn = |inputs: &[Array]| -> Result<Array, Exception> {
    let w = &inputs[0];
    let x = &inputs[1];
    let y = &inputs[2];

    let y_pred = x.matmul(w)?;
    let loss = Array::from_float(0.5) * ops::mean(&ops::square(y_pred - y)?, None, None)?;
    Ok(loss)
};

let argnums = &[0]; // Specify which argument to differentiate with respect to

// Pass all required arrays in the inputs slice
let inputs = vec![w, x, y];
let grad = transforms::grad(loss_fn, argnums)(&inputs)?;
```

When using gradients in a training loop, remember to update the appropriate array in your inputs:

```rust
let mut inputs = vec![w, x, y];

for _ in 0..num_iterations {
    let grad = transforms::grad(loss_fn, argnums)(&inputs)?;
    inputs[0] = &inputs[0] - Array::from_float(learning_rate) * grad; // Update the weight array
    inputs[0].eval()?;
}
```

We are actively working on improving this API to make it more ergonomic and closer to Python's behavior. For now, explicitly passing all required arrays as shown above is the recommended approach.

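The explicit-inputs pattern is not specific to mlx-rs. As a dependency-free sketch (plain `Vec<f64>` stands in for `Array`, and a central finite difference stands in for `transforms::grad` — both are illustrative stand-ins, not mlx-rs APIs), the same "every traced value travels through the inputs slice" shape looks like this:

```rust
// Sketch only: Vec<f64> stands in for an mlx Array, and a central
// finite difference stands in for transforms::grad. The point is the
// API shape: everything the loss depends on is passed through the
// `inputs` slice instead of being captured by the closure.

fn loss_fn(inputs: &[Vec<f64>]) -> f64 {
    let (w, x, y) = (&inputs[0], &inputs[1], &inputs[2]);
    // 0.5 * mean((x * w - y)^2) for a single-feature model
    let n = y.len() as f64;
    let sum: f64 = x
        .iter()
        .zip(y)
        .map(|(xi, yi)| (xi * w[0] - yi).powi(2))
        .sum();
    0.5 * sum / n
}

// Numerical gradient of loss_fn with respect to inputs[0][0].
fn grad_w(inputs: &[Vec<f64>]) -> f64 {
    let h = 1e-6;
    let mut plus = inputs.to_vec();
    let mut minus = inputs.to_vec();
    plus[0][0] += h;
    minus[0][0] -= h;
    (loss_fn(&plus) - loss_fn(&minus)) / (2.0 * h)
}

fn train() -> f64 {
    let w = vec![0.0];
    let x = vec![1.0, 2.0, 3.0];
    let y = vec![2.0, 4.0, 6.0]; // labels generated with true weight 2.0
    let mut inputs = vec![w, x, y];
    let learning_rate = 0.1;
    for _ in 0..200 {
        let grad = grad_w(&inputs);
        inputs[0][0] -= learning_rate * grad; // update the weight in the inputs slice
    }
    inputs[0][0]
}

fn main() {
    println!("learned w = {:.4}", train());
}
```

Because the loss reads its operands from `inputs` rather than from captured variables, updating `inputs[0]` in the loop is all that is needed for the next gradient call to see the new weights.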
## Versioning

For simplicity, the main crate `mlx-rs` follows MLX's versioning, allowing you to easily see which MLX version you're using under the hood. The `mlx-sys` crate follows the versioning of `mlx-c`, as that is the version from which the API is generated.
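
As an illustration (the version number is taken from the snippet earlier in this README and may be out of date), a `Cargo.toml` entry pinning the crate while enabling one of the feature flags listed above might look like:

```toml
[dependencies]
# mlx-rs 0.21.x tracks MLX 0.21; the "metal" feature enables GPU usage in MLX
mlx-rs = { version = "0.21.0", features = ["metal"] }
```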
Lines changed: 59 additions & 0 deletions
@@ -0,0 +1,59 @@
```rust
use mlx_rs::error::Exception;
use mlx_rs::{ops, random, transforms, Array};
use std::error::Error;

fn main() -> Result<(), Box<dyn Error>> {
    let num_features: i32 = 100;
    let num_examples: i32 = 1000;
    let num_iterations: i32 = 10000;
    let learning_rate: f32 = 0.01;

    // True weight vector
    let w_star = random::normal::<f32>(&[num_features], None, None, None)?;

    // Input examples (design matrix)
    let x = random::normal::<f32>(&[num_examples, num_features], None, None, None)?;

    // Noisy labels
    let eps = random::normal::<f32>(&[num_examples], None, None, None)? * 1e-2;
    let y = x.matmul(&w_star)? + eps;

    // Initialize random weights
    let w = random::normal::<f32>(&[num_features], None, None, None)? * 1e-2;

    let loss_fn = |inputs: &[Array]| -> Result<Array, Exception> {
        let w = &inputs[0];
        let x = &inputs[1];
        let y = &inputs[2];

        let y_pred = x.matmul(w)?;
        let loss = Array::from_float(0.5) * ops::mean(&ops::square(y_pred - y)?, None, None)?;
        Ok(loss)
    };

    let mut grad_fn = transforms::grad(loss_fn, &[0]);

    let now = std::time::Instant::now();
    let mut inputs = [w, x, y];

    for _ in 0..num_iterations {
        let grad = grad_fn(&inputs)?;
        inputs[0] = &inputs[0] - Array::from_float(learning_rate) * grad;
        inputs[0].eval()?;
    }

    let elapsed = now.elapsed();

    let loss = loss_fn(&inputs)?;
    let error_norm = ops::sum(&ops::square(&(&inputs[0] - &w_star))?, None, None)?.sqrt()?;
    let throughput = num_iterations as f32 / elapsed.as_secs_f32();

    println!(
        "Loss {:.5}, L2 distance: |w-w*| = {:.5}, Throughput {:.5} (it/s)",
        loss.item::<f32>(),
        error_norm.item::<f32>(),
        throughput
    );

    Ok(())
}
```
