You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: mlx-rs/README.md
+51Lines changed: 51 additions & 0 deletions
Display the source diff
Display the rich diff
Original file line number
Diff line number
Diff line change
@@ -66,6 +66,57 @@ mlx-rs = "0.21.0"
66
66
*`metal` - enables metal (GPU) usage in MLX
67
67
*`accelerate` - enables using the accelerate framework in MLX
68
68
69
+
## Important Notes on Automatic Differentiation
70
+
71
+
When using automatic differentiation in mlx-rs, there's an important difference in how closures work compared to Python's MLX. In Python, variables are implicitly captured and properly traced in the compute graph. However, in Rust, we need to be more explicit about which arrays should be traced.
inputs[0] =&inputs[0] -Array::from_float(learning_rate) *grad; // Update the weight array
114
+
inputs[0].eval()?;
115
+
}
116
+
```
117
+
118
+
We are actively working on improving this API to make it more ergonomic and closer to Python's behavior. For now, explicitly passing all required arrays as shown above is the recommended approach.
119
+
69
120
## Versioning
70
121
71
122
For simplicity, the main crate `mls-rs` follows MLX’s versioning, allowing you to easily see which MLX version you’re using under the hood. The `mlx-sys` crate follows the versioning of `mlx-c`, as that is the version from which the API is generated.
0 commit comments