Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@ license = "MIT"
readme = "README.md"

[dependencies]
mlx-sys = { version = "0.10.0-alpha.0", path = "mlx-sys" }
derive-new = "0.6.0"
half = "2"
mlx-sys = { version = "0.10.0-alpha.0", path = "mlx-sys" }
num-complex = "0.4"
num_enum = "0.7.2"

Expand Down
226 changes: 226 additions & 0 deletions src/array/kind.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
use crate::array::wrapper::Array;
use crate::sealed::Sealed;
use num_complex::{Complex, Complex32};

/// Array element type
#[derive(
Debug, Clone, Copy, PartialEq, Eq, num_enum::IntoPrimitive, num_enum::TryFromPrimitive,
)]
#[repr(u32)]
pub enum Kind {
Bool = mlx_sys::mlx_array_dtype__MLX_BOOL,
Uint8 = mlx_sys::mlx_array_dtype__MLX_UINT8,
Uint16 = mlx_sys::mlx_array_dtype__MLX_UINT16,
Uint32 = mlx_sys::mlx_array_dtype__MLX_UINT32,
Uint64 = mlx_sys::mlx_array_dtype__MLX_UINT64,
Int8 = mlx_sys::mlx_array_dtype__MLX_INT8,
Int16 = mlx_sys::mlx_array_dtype__MLX_INT16,
Int32 = mlx_sys::mlx_array_dtype__MLX_INT32,
Int64 = mlx_sys::mlx_array_dtype__MLX_INT64,
Float16 = mlx_sys::mlx_array_dtype__MLX_FLOAT16,
Float32 = mlx_sys::mlx_array_dtype__MLX_FLOAT32,
Bfloat16 = mlx_sys::mlx_array_dtype__MLX_BFLOAT16,
Complex64 = mlx_sys::mlx_array_dtype__MLX_COMPLEX64,
}

/// Kinds for tensor elements
///
/// # Safety
/// The specified Kind must be for a type that has the same length as Self.
pub unsafe trait Element: Clone {
const KIND: Kind;
const ZERO: Self;

fn array_item(array: &Array) -> Self;
fn array_data(array: &Array) -> *const Self;
}

impl Sealed for bool {}
unsafe impl Element for bool {
const KIND: Kind = Kind::Bool;
const ZERO: Self = false;

fn array_item(array: &Array) -> Self {
unsafe { mlx_sys::mlx_array_item_bool(array.c_array) }
}

fn array_data(array: &Array) -> *const Self {
unsafe { mlx_sys::mlx_array_data_bool(array.c_array) }
}
}

impl Sealed for u8 {}
unsafe impl Element for u8 {
const KIND: Kind = Kind::Uint8;
const ZERO: Self = 0;

fn array_item(array: &Array) -> Self {
unsafe { mlx_sys::mlx_array_item_uint8(array.c_array) }
}

fn array_data(array: &Array) -> *const Self {
unsafe { mlx_sys::mlx_array_data_uint8(array.c_array) }
}
}

impl Sealed for u16 {}
unsafe impl Element for u16 {
const KIND: Kind = Kind::Uint16;
const ZERO: Self = 0;

fn array_item(array: &Array) -> Self {
unsafe { mlx_sys::mlx_array_item_uint16(array.c_array) }
}

fn array_data(array: &Array) -> *const Self {
unsafe { mlx_sys::mlx_array_data_uint16(array.c_array) }
}
}

impl Sealed for u32 {}
unsafe impl Element for u32 {
const KIND: Kind = Kind::Uint32;
const ZERO: Self = 0;

fn array_item(array: &Array) -> Self {
unsafe { mlx_sys::mlx_array_item_uint32(array.c_array) }
}

fn array_data(array: &Array) -> *const Self {
unsafe { mlx_sys::mlx_array_data_uint32(array.c_array) }
}
}

impl Sealed for u64 {}
unsafe impl Element for u64 {
const KIND: Kind = Kind::Uint64;
const ZERO: Self = 0;

fn array_item(array: &Array) -> Self {
unsafe { mlx_sys::mlx_array_item_uint64(array.c_array) }
}

fn array_data(array: &Array) -> *const Self {
unsafe { mlx_sys::mlx_array_data_uint64(array.c_array) }
}
}

impl Sealed for i8 {}
unsafe impl Element for i8 {
const KIND: Kind = Kind::Int8;
const ZERO: Self = 0;

fn array_item(array: &Array) -> Self {
unsafe { mlx_sys::mlx_array_item_int8(array.c_array) }
}

fn array_data(array: &Array) -> *const Self {
unsafe { mlx_sys::mlx_array_data_int8(array.c_array) }
}
}

impl Sealed for i16 {}
unsafe impl Element for i16 {
const KIND: Kind = Kind::Int16;
const ZERO: Self = 0;

fn array_item(array: &Array) -> Self {
unsafe { mlx_sys::mlx_array_item_int16(array.c_array) }
}

fn array_data(array: &Array) -> *const Self {
unsafe { mlx_sys::mlx_array_data_int16(array.c_array) }
}
}

impl Sealed for i32 {}
unsafe impl Element for i32 {
const KIND: Kind = Kind::Int32;
const ZERO: Self = 0;

fn array_item(array: &Array) -> Self {
unsafe { mlx_sys::mlx_array_item_int32(array.c_array) }
}

fn array_data(array: &Array) -> *const Self {
unsafe { mlx_sys::mlx_array_data_int32(array.c_array) }
}
}

impl Sealed for i64 {}
unsafe impl Element for i64 {
const KIND: Kind = Kind::Int64;
const ZERO: Self = 0;

fn array_item(array: &Array) -> Self {
unsafe { mlx_sys::mlx_array_item_int64(array.c_array) }
}

fn array_data(array: &Array) -> *const Self {
unsafe { mlx_sys::mlx_array_data_int64(array.c_array) }
}
}

impl Sealed for f32 {}
unsafe impl Element for f32 {
const KIND: Kind = Kind::Float32;
const ZERO: Self = 0.;

fn array_item(array: &Array) -> Self {
unsafe { mlx_sys::mlx_array_item_float32(array.c_array) }
}

fn array_data(array: &Array) -> *const Self {
unsafe { mlx_sys::mlx_array_data_float32(array.c_array) }
}
}

impl Sealed for Complex32 {}
unsafe impl Element for Complex32 {
const KIND: Kind = Kind::Complex64;
const ZERO: Self = Complex::new(0., 0.);

fn array_item(array: &Array) -> Self {
bindgen_complex_to_complex(unsafe { mlx_sys::mlx_array_item_complex64(array.c_array) })
}

fn array_data(array: &Array) -> *const Self {
unsafe { mlx_sys::mlx_array_data_complex64(array.c_array) as *const Self }
}
}

impl Sealed for half::f16 {}
unsafe impl Element for half::f16 {
const KIND: Kind = Kind::Float16;
const ZERO: Self = half::f16::ZERO;

fn array_item(array: &Array) -> Self {
Self::from_bits(unsafe { mlx_sys::mlx_array_item_float16(array.c_array).0 })
}

fn array_data(array: &Array) -> *const Self {
unsafe { mlx_sys::mlx_array_data_float16(array.c_array) as *const Self }
}
}

impl Sealed for half::bf16 {}
unsafe impl Element for half::bf16 {
const KIND: Kind = Kind::Bfloat16;
const ZERO: Self = half::bf16::ZERO;

fn array_item(array: &Array) -> Self {
Self::from_bits(unsafe { mlx_sys::mlx_array_item_bfloat16(array.c_array) })
}

fn array_data(array: &Array) -> *const Self {
unsafe { mlx_sys::mlx_array_data_bfloat16(array.c_array) as *const Self }
}
}

#[inline]
fn bindgen_complex_to_complex<T>(item: mlx_sys::__BindgenComplex<T>) -> Complex<T> {
Complex {
re: item.re,
im: item.im,
}
}
37 changes: 37 additions & 0 deletions src/array/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
use crate::array::shape::Shape;

mod kind;
pub mod ops;
mod shape;
mod wrapper;

pub struct MLXArray<E: kind::Element, const D: usize> {
pub tensor: wrapper::Array,
phantom: std::marker::PhantomData<E>,
}

impl<E: kind::Element, const D: usize> MLXArray<E, D> {
pub fn eval(&mut self) {
self.tensor.eval();
}

pub fn shape(&self) -> Shape<D> {
Shape::from(self.tensor.shape())
}

pub fn as_slice(&self) -> Option<&[E]> {
self.tensor.as_slice()
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::stream::StreamOrDevice;

#[test]
fn test_shape() {
let array: MLXArray<f32, 2> = MLXArray::zeros([2, 3], StreamOrDevice::default());
assert_eq!(array.shape().dims, [2, 3]);
}
}
30 changes: 30 additions & 0 deletions src/array/ops.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
use crate::array::shape::Shape;
use crate::array::{kind, wrapper, MLXArray};
use crate::stream::StreamOrDevice;

impl<E: kind::Element, const D: usize> MLXArray<E, D> {
pub fn zeros<S: Into<Shape<D>>>(shape: S, stream: StreamOrDevice) -> Self {
let shape = shape.into();
let tensor = wrapper::Array::zeros(&shape.dims, E::KIND, stream);

Self {
tensor,
phantom: std::marker::PhantomData,
}
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::stream::StreamOrDevice;

#[test]
fn test_zeros() {
let mut array: MLXArray<f32, 2> = MLXArray::zeros([2, 3], StreamOrDevice::default());
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

checkout the type validation on this one.

array.eval();
let data = array.as_slice().unwrap();

assert_eq!(data, &[0.0; 6]);
}
}
53 changes: 53 additions & 0 deletions src/array/shape.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
use derive_new::new;

#[derive(new, Debug, Clone, PartialEq, Eq)]
pub struct Shape<const D: usize> {
/// The dimensions of the tensor.
pub dims: [usize; D],
}

impl<const D: usize> From<[usize; D]> for Shape<D> {
fn from(dims: [usize; D]) -> Self {
Shape::new(dims)
}
}

impl<const D: usize> From<Vec<i64>> for Shape<D> {
fn from(shape: Vec<i64>) -> Self {
let mut dims = [1; D];
for (i, dim) in shape.into_iter().enumerate() {
dims[i] = dim as usize;
}
Self::new(dims)
}
}

impl<const D: usize> From<Vec<u64>> for Shape<D> {
fn from(shape: Vec<u64>) -> Self {
let mut dims = [1; D];
for (i, dim) in shape.into_iter().enumerate() {
dims[i] = dim as usize;
}
Self::new(dims)
}
}

impl<const D: usize> From<Vec<usize>> for Shape<D> {
fn from(shape: Vec<usize>) -> Self {
let mut dims = [1; D];
for (i, dim) in shape.into_iter().enumerate() {
dims[i] = dim;
}
Self::new(dims)
}
}

impl<const D: usize> From<&Vec<usize>> for Shape<D> {
fn from(shape: &Vec<usize>) -> Self {
let mut dims = [1; D];
for (i, dim) in shape.iter().enumerate() {
dims[i] = *dim;
}
Self::new(dims)
}
}
Loading