diff --git a/Cargo.toml b/Cargo.toml index a10c60bec..b3306a1e4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -57,7 +57,7 @@ tempfile = "3" itertools = "0.14" syn = { version = "2", features = ["full"] } quote = "1" -darling = "0.21" +darling = "0.23" proc-macro2 = "1" bindgen = "0.72" cmake = "0.1" diff --git a/mlx-rs/README.md b/mlx-rs/README.md index e21b3c962..3ed5f757c 100644 --- a/mlx-rs/README.md +++ b/mlx-rs/README.md @@ -5,7 +5,7 @@ Rust bindings for Apple's mlx machine learning library. [![Discord](https://img.shields.io/discord/1176807732473495552.svg?color=7289da&&logo=discord)](https://discord.gg/jZvTsxDX49) [![Current Crates.io Version](https://img.shields.io/crates/v/mlx-rs.svg)](https://crates.io/crates/mlx-rs) -[![Documentation](https://img.shields.io/badge/docs-latest-blue)]() +[![Documentation](https://img.shields.io/badge/docs-latest-blue)](https://oxideai.github.io/mlx-rs/mlx_rs/) [![Test Status](https://github.com/oxideai/mlx-rs/actions/workflows/validate.yml/badge.svg)](https://github.com/oxideai/mlx-rs/actions/workflows/validate.yml) [![Blaze](https://runblaze.dev/gh/307493885959233117281096297203102330146/badge.svg)](https://runblaze.dev) [![Rust Version](https://img.shields.io/badge/Rust-1.82.0+-blue)](https://releases.rs/docs/1.82.0) diff --git a/mlx-rs/src/array/safetensors.rs b/mlx-rs/src/array/safetensors.rs index bd148102c..2340f627c 100644 --- a/mlx-rs/src/array/safetensors.rs +++ b/mlx-rs/src/array/safetensors.rs @@ -85,6 +85,10 @@ impl<'a> TryFrom<&'a Array> for TensorView<'a> { let data = value.as_slice::(); cast_slice(data) }, + Dtype::Float64 => { + let data = value.as_slice::(); + cast_slice(data) + }, Dtype::Bfloat16 => { let data = value.as_slice::(); let bits: &[u16] = transmute(data); diff --git a/mlx-rs/src/nn/normalization.rs b/mlx-rs/src/nn/normalization.rs index 41360cc45..f351c1f10 100644 --- a/mlx-rs/src/nn/normalization.rs +++ b/mlx-rs/src/nn/normalization.rs @@ -80,9 +80,11 @@ pub struct InstanceNorm { pub eps: Array, /// An optional trainable weight + #[param] pub weight: Param>, /// An optional trainable bias + #[param] pub bias: Param>, } diff --git a/mlx-rs/src/nn/quantized.rs b/mlx-rs/src/nn/quantized.rs index ed7f487eb..609b3a08b 100644 --- a/mlx-rs/src/nn/quantized.rs +++ b/mlx-rs/src/nn/quantized.rs @@ -70,12 +70,15 @@ pub struct QuantizedEmbedding { pub bits: i32, /// Scales + #[param] pub scales: Param, /// Biases + #[param] pub biases: Param, /// Inner embedding + #[param] pub inner: Embedding, } diff --git a/mlx-rs/src/ops/arithmetic.rs b/mlx-rs/src/ops/arithmetic.rs index a556ef273..0ee5a6265 100644 --- a/mlx-rs/src/ops/arithmetic.rs +++ b/mlx-rs/src/ops/arithmetic.rs @@ -742,6 +742,22 @@ pub fn atan_device(a: impl AsRef, #[optional] stream: impl AsRef) }) } +/// Element-wise inverse tangent of b/a choosing the quadrant correctly. +#[generate_macro] +#[default_device] +pub fn atan2_device( + a: impl AsRef, + b: impl AsRef, + #[optional] stream: impl AsRef, +) -> Result { + let a = a.as_ref(); + let b = b.as_ref(); + + Array::try_from_op(|res| unsafe { + mlx_sys::mlx_arctan2(res, a.as_ptr(), b.as_ptr(), stream.as_ref().as_ptr()) + }) +} + /// Element-wise inverse hyperbolic tangent. #[generate_macro] #[default_device] diff --git a/mlx-rs/src/transforms/mod.rs b/mlx-rs/src/transforms/mod.rs index 8ba0f201d..3a256b23d 100644 --- a/mlx-rs/src/transforms/mod.rs +++ b/mlx-rs/src/transforms/mod.rs @@ -50,7 +50,7 @@ use mlx_sys::mlx_closure_value_and_grad; use crate::{ error::{get_and_clear_closure_error, Result}, module::ModuleParamRef, - utils::{guard::Guarded, Closure, VectorArray}, + utils::{guard::Guarded, Closure, VectorArray, SUCCESS}, Array, }; @@ -225,6 +225,13 @@ impl ClosureValueAndGrad { } } +impl Drop for ClosureValueAndGrad { + fn drop(&mut self) { + let status = unsafe { mlx_sys::mlx_closure_value_and_grad_free(self.c_closure_value_and_grad) }; + debug_assert_eq!(status, SUCCESS); + } +} + fn value_and_gradient( value_and_grad: mlx_closure_value_and_grad, arrays: impl Iterator>, diff --git a/mlx-rs/src/utils/guard.rs b/mlx-rs/src/utils/guard.rs index 661b602ce..ade865a3f 100644 --- a/mlx-rs/src/utils/guard.rs +++ b/mlx-rs/src/utils/guard.rs @@ -144,7 +144,9 @@ impl Guard> for MaybeUninitVectorArray { self.init_success = success; } - fn try_into_guarded(self) -> Result, Exception> { + fn try_into_guarded(mut self) -> Result, Exception> { + debug_assert!(self.init_success); + self.init_success = false; // mlx_vector_array still needs to be freed after we extracted its elements unsafe { let size = mlx_sys::mlx_vector_array_size(self.ptr); (0..size) diff --git a/mlx-rs/src/utils/mod.rs b/mlx-rs/src/utils/mod.rs index 25fce6a5e..e42a9d196 100644 --- a/mlx-rs/src/utils/mod.rs +++ b/mlx-rs/src/utils/mod.rs @@ -256,7 +256,11 @@ where let payload = raw as *mut std::ffi::c_void; unsafe { - mlx_sys::mlx_closure_new_func_payload(Some(trampoline::), payload, Some(noop_dtor)) + mlx_sys::mlx_closure_new_func_payload( + Some(trampoline::), + payload, + Some(closure_dtor::), + ) } } @@ -272,7 +276,7 @@ where mlx_sys::mlx_closure_new_func_payload( Some(trampoline_fallible::), payload, - Some(noop_dtor), + Some(closure_dtor::), ) } } @@ -315,10 +319,13 @@ where let arrays = match mlx_vector_array_values(vector_array) { Ok(arrays) => arrays, Err(_) => { + let _ = Box::into_raw(closure); // prevent premature drop return FAILURE; } }; let result = closure(&arrays); + let _ = Box::into_raw(closure); // prevent premature drop + // We should probably keep using new_mlx_vector_array here instead of VectorArray // since we probably don't want to drop the arrays in the closure *ret = new_mlx_vector_array(result); @@ -341,11 +348,14 @@ where let arrays = match mlx_vector_array_values(vector_array) { Ok(arrays) => arrays, Err(e) => { + let _ = Box::into_raw(closure); // prevent premature drop set_closure_error(e); return FAILURE; } }; let result = closure(&arrays); + let _ = Box::into_raw(closure); // prevent premature drop + match result { Ok(result) => { *ret = new_mlx_vector_array(result); @@ -359,7 +369,16 @@ where } } -extern "C" fn noop_dtor(_data: *mut std::ffi::c_void) {} +// extern "C" fn noop_dtor(_data: *mut std::ffi::c_void) {} + +extern "C" fn closure_dtor(payload: *mut std::ffi::c_void) { + if payload.is_null() { + return; + } + unsafe { + drop(Box::from_raw(payload as *mut F)); + } +} pub(crate) fn get_mut_or_insert_with<'a, T>( map: &'a mut HashMap, T>,