diff --git a/Cargo.toml b/Cargo.toml index d41185e4a..706868cca 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,8 +12,9 @@ license = "MIT" readme = "README.md" [dependencies] -mlx-sys = { version = "0.10.0-alpha.0", path = "mlx-sys" } +derive-new = "0.6.0" half = "2" +mlx-sys = { version = "0.10.0-alpha.0", path = "mlx-sys" } num-complex = "0.4" num_enum = "0.7.2" diff --git a/src/array/kind.rs b/src/array/kind.rs new file mode 100644 index 000000000..84a6aafd0 --- /dev/null +++ b/src/array/kind.rs @@ -0,0 +1,226 @@ +use crate::array::wrapper::Array; +use crate::sealed::Sealed; +use num_complex::{Complex, Complex32}; + +/// Array element type +#[derive( + Debug, Clone, Copy, PartialEq, Eq, num_enum::IntoPrimitive, num_enum::TryFromPrimitive, +)] +#[repr(u32)] +pub enum Kind { + Bool = mlx_sys::mlx_array_dtype__MLX_BOOL, + Uint8 = mlx_sys::mlx_array_dtype__MLX_UINT8, + Uint16 = mlx_sys::mlx_array_dtype__MLX_UINT16, + Uint32 = mlx_sys::mlx_array_dtype__MLX_UINT32, + Uint64 = mlx_sys::mlx_array_dtype__MLX_UINT64, + Int8 = mlx_sys::mlx_array_dtype__MLX_INT8, + Int16 = mlx_sys::mlx_array_dtype__MLX_INT16, + Int32 = mlx_sys::mlx_array_dtype__MLX_INT32, + Int64 = mlx_sys::mlx_array_dtype__MLX_INT64, + Float16 = mlx_sys::mlx_array_dtype__MLX_FLOAT16, + Float32 = mlx_sys::mlx_array_dtype__MLX_FLOAT32, + Bfloat16 = mlx_sys::mlx_array_dtype__MLX_BFLOAT16, + Complex64 = mlx_sys::mlx_array_dtype__MLX_COMPLEX64, +} + +/// Kinds for tensor elements +/// +/// # Safety +/// The specified Kind must be for a type that has the same length as Self. +pub unsafe trait Element: Clone { + const KIND: Kind; + const ZERO: Self; + + fn array_item(array: &Array) -> Self; + fn array_data(array: &Array) -> *const Self; +} + +impl Sealed for bool {} +unsafe impl Element for bool { + const KIND: Kind = Kind::Bool; + const ZERO: Self = false; + + fn array_item(array: &Array) -> Self { + unsafe { mlx_sys::mlx_array_item_bool(array.c_array) } + } + + fn array_data(array: &Array) -> *const Self { + unsafe { mlx_sys::mlx_array_data_bool(array.c_array) } + } +} + +impl Sealed for u8 {} +unsafe impl Element for u8 { + const KIND: Kind = Kind::Uint8; + const ZERO: Self = 0; + + fn array_item(array: &Array) -> Self { + unsafe { mlx_sys::mlx_array_item_uint8(array.c_array) } + } + + fn array_data(array: &Array) -> *const Self { + unsafe { mlx_sys::mlx_array_data_uint8(array.c_array) } + } +} + +impl Sealed for u16 {} +unsafe impl Element for u16 { + const KIND: Kind = Kind::Uint16; + const ZERO: Self = 0; + + fn array_item(array: &Array) -> Self { + unsafe { mlx_sys::mlx_array_item_uint16(array.c_array) } + } + + fn array_data(array: &Array) -> *const Self { + unsafe { mlx_sys::mlx_array_data_uint16(array.c_array) } + } +} + +impl Sealed for u32 {} +unsafe impl Element for u32 { + const KIND: Kind = Kind::Uint32; + const ZERO: Self = 0; + + fn array_item(array: &Array) -> Self { + unsafe { mlx_sys::mlx_array_item_uint32(array.c_array) } + } + + fn array_data(array: &Array) -> *const Self { + unsafe { mlx_sys::mlx_array_data_uint32(array.c_array) } + } +} + +impl Sealed for u64 {} +unsafe impl Element for u64 { + const KIND: Kind = Kind::Uint64; + const ZERO: Self = 0; + + fn array_item(array: &Array) -> Self { + unsafe { mlx_sys::mlx_array_item_uint64(array.c_array) } + } + + fn array_data(array: &Array) -> *const Self { + unsafe { mlx_sys::mlx_array_data_uint64(array.c_array) } + } +} + +impl Sealed for i8 {} +unsafe impl Element for i8 { + const KIND: Kind = Kind::Int8; + const ZERO: Self = 0; + + fn array_item(array: &Array) -> Self { + unsafe { mlx_sys::mlx_array_item_int8(array.c_array) } + } + + fn array_data(array: &Array) -> *const Self { + unsafe { mlx_sys::mlx_array_data_int8(array.c_array) } + } +} + +impl Sealed for i16 {} +unsafe impl Element for i16 { + const KIND: Kind = Kind::Int16; + const ZERO: Self = 0; + + fn array_item(array: &Array) -> Self { + unsafe { mlx_sys::mlx_array_item_int16(array.c_array) } + } + + fn array_data(array: &Array) -> *const Self { + unsafe { mlx_sys::mlx_array_data_int16(array.c_array) } + } +} + +impl Sealed for i32 {} +unsafe impl Element for i32 { + const KIND: Kind = Kind::Int32; + const ZERO: Self = 0; + + fn array_item(array: &Array) -> Self { + unsafe { mlx_sys::mlx_array_item_int32(array.c_array) } + } + + fn array_data(array: &Array) -> *const Self { + unsafe { mlx_sys::mlx_array_data_int32(array.c_array) } + } +} + +impl Sealed for i64 {} +unsafe impl Element for i64 { + const KIND: Kind = Kind::Int64; + const ZERO: Self = 0; + + fn array_item(array: &Array) -> Self { + unsafe { mlx_sys::mlx_array_item_int64(array.c_array) } + } + + fn array_data(array: &Array) -> *const Self { + unsafe { mlx_sys::mlx_array_data_int64(array.c_array) } + } +} + +impl Sealed for f32 {} +unsafe impl Element for f32 { + const KIND: Kind = Kind::Float32; + const ZERO: Self = 0.; + + fn array_item(array: &Array) -> Self { + unsafe { mlx_sys::mlx_array_item_float32(array.c_array) } + } + + fn array_data(array: &Array) -> *const Self { + unsafe { mlx_sys::mlx_array_data_float32(array.c_array) } + } +} + +impl Sealed for Complex32 {} +unsafe impl Element for Complex32 { + const KIND: Kind = Kind::Complex64; + const ZERO: Self = Complex::new(0., 0.); + + fn array_item(array: &Array) -> Self { + bindgen_complex_to_complex(unsafe { mlx_sys::mlx_array_item_complex64(array.c_array) }) + } + + fn array_data(array: &Array) -> *const Self { + unsafe { mlx_sys::mlx_array_data_complex64(array.c_array) as *const Self } + } +} + +impl Sealed for half::f16 {} +unsafe impl Element for half::f16 { + const KIND: Kind = Kind::Float16; + const ZERO: Self = half::f16::ZERO; + + fn array_item(array: &Array) -> Self { + Self::from_bits(unsafe { mlx_sys::mlx_array_item_float16(array.c_array).0 }) + } + + fn array_data(array: &Array) -> *const Self { + unsafe { mlx_sys::mlx_array_data_float16(array.c_array) as *const Self } + } +} + +impl Sealed for half::bf16 {} +unsafe impl Element for half::bf16 { + const KIND: Kind = Kind::Bfloat16; + const ZERO: Self = half::bf16::ZERO; + + fn array_item(array: &Array) -> Self { + Self::from_bits(unsafe { mlx_sys::mlx_array_item_bfloat16(array.c_array) }) + } + + fn array_data(array: &Array) -> *const Self { + unsafe { mlx_sys::mlx_array_data_bfloat16(array.c_array) as *const Self } + } +} + +#[inline] +fn bindgen_complex_to_complex(item: mlx_sys::__BindgenComplex) -> Complex { + Complex { + re: item.re, + im: item.im, + } +} diff --git a/src/array/mod.rs b/src/array/mod.rs new file mode 100644 index 000000000..7e3a86afd --- /dev/null +++ b/src/array/mod.rs @@ -0,0 +1,37 @@ +use crate::array::shape::Shape; + +mod kind; +pub mod ops; +mod shape; +mod wrapper; + +pub struct MLXArray { + pub tensor: wrapper::Array, + phantom: std::marker::PhantomData, +} + +impl MLXArray { + pub fn eval(&mut self) { + self.tensor.eval(); + } + + pub fn shape(&self) -> Shape { + Shape::from(self.tensor.shape()) + } + + pub fn as_slice(&self) -> Option<&[E]> { + self.tensor.as_slice() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::stream::StreamOrDevice; + + #[test] + fn test_shape() { + let array: MLXArray = MLXArray::zeros([2, 3], StreamOrDevice::default()); + assert_eq!(array.shape().dims, [2, 3]); + } +} diff --git a/src/array/ops.rs b/src/array/ops.rs new file mode 100644 index 000000000..8f63d580d --- /dev/null +++ b/src/array/ops.rs @@ -0,0 +1,30 @@ +use crate::array::shape::Shape; +use crate::array::{kind, wrapper, MLXArray}; +use crate::stream::StreamOrDevice; + +impl MLXArray { + pub fn zeros>>(shape: S, stream: StreamOrDevice) -> Self { + let shape = shape.into(); + let tensor = wrapper::Array::zeros(&shape.dims, E::KIND, stream); + + Self { + tensor, + phantom: std::marker::PhantomData, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::stream::StreamOrDevice; + + #[test] + fn test_zeros() { + let mut array: MLXArray = MLXArray::zeros([2, 3], StreamOrDevice::default()); + array.eval(); + let data = array.as_slice().unwrap(); + + assert_eq!(data, &[0.0; 6]); + } +} diff --git a/src/array/shape.rs b/src/array/shape.rs new file mode 100644 index 000000000..83e749747 --- /dev/null +++ b/src/array/shape.rs @@ -0,0 +1,53 @@ +use derive_new::new; + +#[derive(new, Debug, Clone, PartialEq, Eq)] +pub struct Shape { + /// The dimensions of the tensor. + pub dims: [usize; D], +} + +impl From<[usize; D]> for Shape { + fn from(dims: [usize; D]) -> Self { + Shape::new(dims) + } +} + +impl From> for Shape { + fn from(shape: Vec) -> Self { + let mut dims = [1; D]; + for (i, dim) in shape.into_iter().enumerate() { + dims[i] = dim as usize; + } + Self::new(dims) + } +} + +impl From> for Shape { + fn from(shape: Vec) -> Self { + let mut dims = [1; D]; + for (i, dim) in shape.into_iter().enumerate() { + dims[i] = dim as usize; + } + Self::new(dims) + } +} + +impl From> for Shape { + fn from(shape: Vec) -> Self { + let mut dims = [1; D]; + for (i, dim) in shape.into_iter().enumerate() { + dims[i] = dim; + } + Self::new(dims) + } +} + +impl From<&Vec> for Shape { + fn from(shape: &Vec) -> Self { + let mut dims = [1; D]; + for (i, dim) in shape.iter().enumerate() { + dims[i] = *dim; + } + Self::new(dims) + } +} diff --git a/src/array.rs b/src/array/wrapper/mod.rs similarity index 67% rename from src/array.rs rename to src/array/wrapper/mod.rs index 3a4af69d1..4e7b06a2b 100644 --- a/src/array.rs +++ b/src/array/wrapper/mod.rs @@ -1,140 +1,13 @@ -use std::ffi::c_void; - -use half::{bf16, f16}; -use mlx_sys::mlx_array; -use num_complex::Complex; - -use crate::{dtype::Dtype, sealed::Sealed}; - -// TODO: camel case? -// Not using Complex64 because `num_complex::Complex64` is actually Complex -#[allow(non_camel_case_types)] -pub type complex64 = Complex; - -/// A marker trait for array elements. -pub trait ArrayElement: Sealed { - const DTYPE: Dtype; - - fn scalar_array_item(array: &Array) -> Self; - - fn array_data(array: &Array) -> *const Self; -} - -macro_rules! impl_array_element { - ($type:ty, $dtype:expr, $mlx_item_fn:ident, $mlx_data_fn:ident) => { - impl Sealed for $type {} - impl ArrayElement for $type { - const DTYPE: Dtype = $dtype; - - fn scalar_array_item(array: &Array) -> Self { - unsafe { mlx_sys::$mlx_item_fn(array.c_array) } - } - - fn array_data(array: &Array) -> *const Self { - unsafe { mlx_sys::$mlx_data_fn(array.c_array) } - } - } - }; -} - -impl_array_element!(bool, Dtype::Bool, mlx_array_item_bool, mlx_array_data_bool); -impl_array_element!(u8, Dtype::Uint8, mlx_array_item_uint8, mlx_array_data_uint8); -impl_array_element!( - u16, - Dtype::Uint16, - mlx_array_item_uint16, - mlx_array_data_uint16 -); -impl_array_element!( - u32, - Dtype::Uint32, - mlx_array_item_uint32, - mlx_array_data_uint32 -); -impl_array_element!( - u64, - Dtype::Uint64, - mlx_array_item_uint64, - mlx_array_data_uint64 -); -impl_array_element!(i8, Dtype::Int8, mlx_array_item_int8, mlx_array_data_int8); -impl_array_element!( - i16, - Dtype::Int16, - mlx_array_item_int16, - mlx_array_data_int16 -); -impl_array_element!( - i32, - Dtype::Int32, - mlx_array_item_int32, - mlx_array_data_int32 -); -impl_array_element!( - i64, - Dtype::Int64, - mlx_array_item_int64, - mlx_array_data_int64 -); -impl_array_element!( - f32, - Dtype::Float32, - mlx_array_item_float32, - mlx_array_data_float32 -); - -impl Sealed for f16 {} - -impl ArrayElement for f16 { - const DTYPE: Dtype = Dtype::Float16; - - fn scalar_array_item(array: &Array) -> Self { - let val = unsafe { mlx_sys::mlx_array_item_float16(array.c_array) }; - f16::from_bits(val.0) - } - - fn array_data(array: &Array) -> *const Self { - unsafe { mlx_sys::mlx_array_data_float16(array.c_array) as *const Self } - } -} - -impl Sealed for bf16 {} - -impl ArrayElement for bf16 { - const DTYPE: Dtype = Dtype::Bfloat16; +mod ops; - fn scalar_array_item(array: &Array) -> Self { - let val = unsafe { mlx_sys::mlx_array_item_bfloat16(array.c_array) }; - bf16::from_bits(val) - } - - fn array_data(array: &Array) -> *const Self { - unsafe { mlx_sys::mlx_array_data_bfloat16(array.c_array) as *const Self } - } -} - -impl Sealed for complex64 {} - -impl ArrayElement for complex64 { - const DTYPE: Dtype = Dtype::Complex64; - - fn scalar_array_item(array: &Array) -> Self { - let bindgen_complex64 = unsafe { mlx_sys::mlx_array_item_complex64(array.c_array) }; - - Self { - re: bindgen_complex64.re, - im: bindgen_complex64.im, - } - } - - fn array_data(array: &Array) -> *const Self { - // complex64 has the same memory layout as __BindgenComplex - unsafe { mlx_sys::mlx_array_data_complex64(array.c_array) as *const Self } - } -} +use crate::{array::kind, array::kind::Kind}; +use num_complex::Complex32; +use std::ffi::c_void; +// TODO: Clone should probably NOT be implemented because the underlying pointer is atomically +// reference counted but not guarded by a mutex. pub struct Array { - c_array: mlx_array, + pub(super) c_array: mlx_sys::mlx_array, } impl std::fmt::Debug for Array { @@ -145,9 +18,6 @@ impl std::fmt::Debug for Array { } } -// TODO: Clone should probably NOT be implemented because the underlying pointer is atomically -// reference counted but not guarded by a mutex. - impl Drop for Array { fn drop(&mut self) { // TODO: check memory leak with some tool? @@ -164,12 +34,12 @@ impl Array { /// /// The caller must ensure the reference count of the array is properly incremented with /// `mlx_sys::mlx_retain`. - pub unsafe fn from_ptr(c_array: mlx_array) -> Array { + pub unsafe fn from_ptr(c_array: mlx_sys::mlx_array) -> Array { Self { c_array } } // TODO: should this be unsafe? - pub fn as_ptr(&self) -> mlx_array { + pub fn as_ptr(&self) -> mlx_sys::mlx_array { self.c_array } @@ -192,7 +62,7 @@ impl Array { } /// New array from a complex scalar. - pub fn from_complex(val: complex64) -> Array { + pub fn from_complex(val: Complex32) -> Array { let c_array = unsafe { mlx_sys::mlx_array_from_complex(val.re, val.im) }; Array { c_array } } @@ -208,7 +78,7 @@ impl Array { /// /// - Panics if the product of the shape is not equal to the length of the data. /// - Panics if the shape is too large. - pub fn from_slice(data: &[T], shape: &[i32]) -> Self { + pub fn from_slice(data: &[T], shape: &[i32]) -> Self { let dim = if shape.len() > i32::MAX as usize { panic!("Shape is too large") } else { @@ -223,7 +93,7 @@ impl Array { data.as_ptr() as *const c_void, shape.as_ptr(), dim, - T::DTYPE as u32, + T::KIND.into(), ) }; @@ -291,9 +161,9 @@ impl Array { } /// The array element type. - pub fn dtype(&self) -> Dtype { + pub fn dtype(&self) -> Kind { let dtype = unsafe { mlx_sys::mlx_array_get_dtype(self.c_array) }; - Dtype::try_from(dtype).unwrap() + Kind::try_from(dtype).unwrap() } // TODO: document that mlx is lazy @@ -304,15 +174,15 @@ impl Array { } /// Access the value of a scalar array. - pub fn item(&self) -> T { + pub fn item(&self) -> T { // TODO: check and perform type conversion from the inner type to the desired output type - T::scalar_array_item(self) + T::array_item(self) } /// Returns a pointer to the array data /// /// Returns `None` if the array is not evaluated. - pub fn as_slice(&self) -> Option<&[T]> { + pub fn as_slice(&self) -> Option<&[T]> { // TODO: type conversion from the inner type to the desired output type let data = T::array_data(self); @@ -356,7 +226,7 @@ mod tests { assert_eq!(array.nbytes(), 1); assert_eq!(array.ndim(), 0); assert!(array.shape().is_empty()); - assert_eq!(array.dtype(), Dtype::Bool); + assert_eq!(array.dtype(), Kind::Bool); } #[test] @@ -369,7 +239,7 @@ mod tests { assert_eq!(array.nbytes(), 4); assert_eq!(array.ndim(), 0); assert!(array.shape().is_empty()); - assert_eq!(array.dtype(), Dtype::Int32); + assert_eq!(array.dtype(), Kind::Int32); } #[test] @@ -382,21 +252,21 @@ mod tests { assert_eq!(array.nbytes(), 4); assert_eq!(array.ndim(), 0); assert!(array.shape().is_empty()); - assert_eq!(array.dtype(), Dtype::Float32); + assert_eq!(array.dtype(), Kind::Float32); } #[test] fn new_scalar_array_from_complex() { - let val = complex64 { re: 1.0, im: 2.0 }; + let val = Complex32::new(1.0, 2.0); let array = Array::from_complex(val); - assert_eq!(array.item::(), val); + assert_eq!(array.item::(), val); assert_eq!(array.item_size(), 8); assert_eq!(array.size(), 1); assert!(array.strides().is_empty()); assert_eq!(array.nbytes(), 8); assert_eq!(array.ndim(), 0); assert!(array.shape().is_empty()); - assert_eq!(array.dtype(), Dtype::Complex64); + assert_eq!(array.dtype(), Kind::Complex64); } #[test] @@ -412,7 +282,7 @@ mod tests { assert_eq!(array.ndim(), 1); assert_eq!(array.dim(0), 1); assert_eq!(array.shape(), &[1]); - assert_eq!(array.dtype(), Dtype::Int32); + assert_eq!(array.dtype(), Kind::Int32); } #[test] @@ -427,7 +297,7 @@ mod tests { assert_eq!(array.ndim(), 1); assert_eq!(array.dim(0), 5); assert_eq!(array.shape(), &[5]); - assert_eq!(array.dtype(), Dtype::Int32); + assert_eq!(array.dtype(), Kind::Int32); } #[test] @@ -445,7 +315,7 @@ mod tests { assert_eq!(array.dim(-1), 3); // negative index assert_eq!(array.dim(-2), 2); // negative index assert_eq!(array.shape(), &[2, 3]); - assert_eq!(array.dtype(), Dtype::Int32); + assert_eq!(array.dtype(), Kind::Int32); } // // TODO: fatal runtime error: Rust cannot catch foreign exceptions diff --git a/src/array/wrapper/ops.rs b/src/array/wrapper/ops.rs new file mode 100644 index 000000000..41f50b4cd --- /dev/null +++ b/src/array/wrapper/ops.rs @@ -0,0 +1,36 @@ +use crate::{ + array::{kind::Kind, wrapper::Array}, + stream::StreamOrDevice, +}; + +impl Array { + pub fn zeros(shape: &[usize], kind: Kind, stream: StreamOrDevice) -> Array { + let shape = shape.iter().map(|x| *x as i32).collect::>(); + let ctx = stream.as_ptr(); + + unsafe { + Array::from_ptr(mlx_sys::mlx_zeros( + shape.as_ptr(), + shape.len(), + kind.into(), + ctx, + )) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_zeros() { + let mut array = Array::zeros(&[2, 3], Kind::Float32, StreamOrDevice::default()); + assert_eq!(array.shape(), &[2, 3]); + assert_eq!(array.dtype(), Kind::Float32); + + array.eval(); + let data: &[f32] = array.as_slice().unwrap(); + assert_eq!(data, &[0.0; 6]); + } +} diff --git a/src/dtype.rs b/src/dtype.rs deleted file mode 100644 index 99f630299..000000000 --- a/src/dtype.rs +++ /dev/null @@ -1,20 +0,0 @@ -/// Array element type -#[derive( - Debug, Clone, Copy, PartialEq, Eq, num_enum::IntoPrimitive, num_enum::TryFromPrimitive, -)] -#[repr(u32)] -pub enum Dtype { - Bool = mlx_sys::mlx_array_dtype__MLX_BOOL, - Uint8 = mlx_sys::mlx_array_dtype__MLX_UINT8, - Uint16 = mlx_sys::mlx_array_dtype__MLX_UINT16, - Uint32 = mlx_sys::mlx_array_dtype__MLX_UINT32, - Uint64 = mlx_sys::mlx_array_dtype__MLX_UINT64, - Int8 = mlx_sys::mlx_array_dtype__MLX_INT8, - Int16 = mlx_sys::mlx_array_dtype__MLX_INT16, - Int32 = mlx_sys::mlx_array_dtype__MLX_INT32, - Int64 = mlx_sys::mlx_array_dtype__MLX_INT64, - Float16 = mlx_sys::mlx_array_dtype__MLX_FLOAT16, - Float32 = mlx_sys::mlx_array_dtype__MLX_FLOAT32, - Bfloat16 = mlx_sys::mlx_array_dtype__MLX_BFLOAT16, - Complex64 = mlx_sys::mlx_array_dtype__MLX_COMPLEX64, -} diff --git a/src/lib.rs b/src/lib.rs index 1471f885a..83fd8f3c6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,7 +2,6 @@ pub mod array; pub mod device; -pub mod dtype; pub mod stream; mod utils; diff --git a/src/stream.rs b/src/stream.rs index 79e032ca2..ea44b5260 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -18,23 +18,27 @@ impl StreamOrDevice { pub fn new_with_device(device: &Device) -> StreamOrDevice { StreamOrDevice { - stream: Stream::default_stream(device), + stream: Stream::default_stream_for_device(device), } } /// The `[Stream::default_stream()] on the [Device::cpu()] pub fn cpu() -> StreamOrDevice { StreamOrDevice { - stream: Stream::default_stream(&Device::cpu()), + stream: Stream::default_stream_for_device(&Device::cpu()), } } /// The `[Stream::default_stream()] on the [Device::gpu()] pub fn gpu() -> StreamOrDevice { StreamOrDevice { - stream: Stream::default_stream(&Device::gpu()), + stream: Stream::default_stream_for_device(&Device::gpu()), } } + + pub(crate) fn as_ptr(&self) -> mlx_sys::mlx_stream { + self.stream.c_stream + } } impl Default for StreamOrDevice { @@ -63,7 +67,7 @@ pub struct Stream { } impl Stream { - fn new_with_mlx_mlx_stream(stream: mlx_sys::mlx_stream) -> Stream { + fn new_with_mlx_stream(stream: mlx_sys::mlx_stream) -> Stream { Stream { c_stream: stream } } @@ -79,9 +83,9 @@ impl Stream { Stream { c_stream } } - pub fn default_stream(device: &Device) -> Stream { + pub fn default_stream_for_device(device: &Device) -> Stream { let default_stream = unsafe { mlx_sys::mlx_default_stream(device.c_device) }; - Stream::new_with_mlx_mlx_stream(default_stream) + Stream::new_with_mlx_stream(default_stream) } }