Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ tempfile = "3"
itertools = "0.14"
syn = { version = "2", features = ["full"] }
quote = "1"
darling = "0.21"
darling = "0.23"
proc-macro2 = "1"
bindgen = "0.72"
cmake = "0.1"
Expand Down
2 changes: 1 addition & 1 deletion mlx-rs/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ Rust bindings for Apple's mlx machine learning library.

[![Discord](https://img.shields.io/discord/1176807732473495552.svg?color=7289da&&logo=discord)](https://discord.gg/jZvTsxDX49)
[![Current Crates.io Version](https://img.shields.io/crates/v/mlx-rs.svg)](https://crates.io/crates/mlx-rs)
[![Documentation](https://img.shields.io/badge/docs-latest-blue)]()
[![Documentation](https://img.shields.io/badge/docs-latest-blue)](https://oxideai.github.io/mlx-rs/mlx_rs/)
[![Test Status](https://github.com/oxideai/mlx-rs/actions/workflows/validate.yml/badge.svg)](https://github.com/oxideai/mlx-rs/actions/workflows/validate.yml)
[![Blaze](https://runblaze.dev/gh/307493885959233117281096297203102330146/badge.svg)](https://runblaze.dev)
[![Rust Version](https://img.shields.io/badge/Rust-1.82.0+-blue)](https://releases.rs/docs/1.82.0)
Expand Down
4 changes: 4 additions & 0 deletions mlx-rs/src/array/safetensors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,10 @@ impl<'a> TryFrom<&'a Array> for TensorView<'a> {
let data = value.as_slice::<f32>();
cast_slice(data)
},
Dtype::Float64 => {
let data = value.as_slice::<f64>();
cast_slice(data)
},
Dtype::Bfloat16 => {
let data = value.as_slice::<half::bf16>();
let bits: &[u16] = transmute(data);
Expand Down
2 changes: 2 additions & 0 deletions mlx-rs/src/nn/normalization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,11 @@ pub struct InstanceNorm {
pub eps: Array,

/// An optional trainable weight
#[param]
pub weight: Param<Option<Array>>,

/// An optional trainable bias
#[param]
pub bias: Param<Option<Array>>,
}

Expand Down
3 changes: 3 additions & 0 deletions mlx-rs/src/nn/quantized.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,15 @@ pub struct QuantizedEmbedding {
pub bits: i32,

/// Scales
#[param]
pub scales: Param<Array>,

/// Biases
#[param]
pub biases: Param<Array>,

/// Inner embedding
#[param]
pub inner: Embedding,
}

Expand Down
16 changes: 16 additions & 0 deletions mlx-rs/src/ops/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -742,6 +742,22 @@ pub fn atan_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>)
})
}

/// Element-wise inverse tangent of b/a choosing the quadrant correctly.
#[generate_macro]
#[default_device]
pub fn atan2_device(
a: impl AsRef<Array>,
b: impl AsRef<Array>,
#[optional] stream: impl AsRef<Stream>,
) -> Result<Array> {
let a = a.as_ref();
let b = b.as_ref();

Array::try_from_op(|res| unsafe {
mlx_sys::mlx_arctan2(res, a.as_ptr(), b.as_ptr(), stream.as_ref().as_ptr())
})
}

/// Element-wise inverse hyperbolic tangent.
#[generate_macro]
#[default_device]
Expand Down
9 changes: 8 additions & 1 deletion mlx-rs/src/transforms/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
use crate::{
error::{get_and_clear_closure_error, Result},
module::ModuleParamRef,
utils::{guard::Guarded, Closure, VectorArray},
utils::{guard::Guarded, Closure, VectorArray, SUCCESS},
Array,
};

Expand Down Expand Up @@ -224,7 +224,14 @@
self.c_closure_value_and_grad
}
}

Check warning on line 227 in mlx-rs/src/transforms/mod.rs

View workflow job for this annotation

GitHub Actions / checks

Diff in /Users/admin/actions-runner/_work/mlx-rs/mlx-rs/mlx-rs/src/transforms/mod.rs
impl Drop for ClosureValueAndGrad {
fn drop(&mut self) {
let status = unsafe { mlx_sys::mlx_closure_value_and_grad_free(self.c_closure_value_and_grad) };
debug_assert_eq!(status, SUCCESS);
}
}

fn value_and_gradient(
value_and_grad: mlx_closure_value_and_grad,
arrays: impl Iterator<Item = impl AsRef<Array>>,
Expand Down
4 changes: 3 additions & 1 deletion mlx-rs/src/utils/guard.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,9 @@ impl Guard<Vec<Array>> for MaybeUninitVectorArray {
self.init_success = success;
}

fn try_into_guarded(self) -> Result<Vec<Array>, Exception> {
fn try_into_guarded(mut self) -> Result<Vec<Array>, Exception> {
debug_assert!(self.init_success);
self.init_success = false; // mlx_vector_array still needs to be freed after we extracted its elements
unsafe {
let size = mlx_sys::mlx_vector_array_size(self.ptr);
(0..size)
Expand Down
25 changes: 22 additions & 3 deletions mlx-rs/src/utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,11 @@
let payload = raw as *mut std::ffi::c_void;

unsafe {
mlx_sys::mlx_closure_new_func_payload(Some(trampoline::<F>), payload, Some(noop_dtor))
mlx_sys::mlx_closure_new_func_payload(
Some(trampoline::<F>),

Check warning on line 260 in mlx-rs/src/utils/mod.rs

View workflow job for this annotation

GitHub Actions / checks

Diff in /Users/admin/actions-runner/_work/mlx-rs/mlx-rs/mlx-rs/src/utils/mod.rs
payload,
Some(closure_dtor::<F>),
)
}
}

Expand All @@ -272,7 +276,7 @@
mlx_sys::mlx_closure_new_func_payload(
Some(trampoline_fallible::<F>),
payload,
Some(noop_dtor),
Some(closure_dtor::<F>),
)
}
}
Expand Down Expand Up @@ -315,10 +319,13 @@
let arrays = match mlx_vector_array_values(vector_array) {
Ok(arrays) => arrays,
Err(_) => {
let _ = Box::into_raw(closure); // prevent premature drop
return FAILURE;
}
};
let result = closure(&arrays);
let _ = Box::into_raw(closure); // prevent premature drop

// We should probably keep using new_mlx_vector_array here instead of VectorArray
// since we probably don't want to drop the arrays in the closure
*ret = new_mlx_vector_array(result);
Expand All @@ -341,11 +348,14 @@
let arrays = match mlx_vector_array_values(vector_array) {
Ok(arrays) => arrays,
Err(e) => {
let _ = Box::into_raw(closure); // prevent premature drop
set_closure_error(e);
return FAILURE;
}
};
let result = closure(&arrays);
let _ = Box::into_raw(closure); // prevent premature drop

match result {
Ok(result) => {
*ret = new_mlx_vector_array(result);
Expand All @@ -359,7 +369,16 @@
}
}

extern "C" fn noop_dtor(_data: *mut std::ffi::c_void) {}
// extern "C" fn noop_dtor(_data: *mut std::ffi::c_void) {}

extern "C" fn closure_dtor<F>(payload: *mut std::ffi::c_void) {
if payload.is_null() {
return;
}
unsafe {
drop(Box::from_raw(payload as *mut F));
}
}

pub(crate) fn get_mut_or_insert_with<'a, T>(
map: &'a mut HashMap<Rc<str>, T>,
Expand Down
Loading