diff --git a/mlx-rs/src/array/safetensors.rs b/mlx-rs/src/array/safetensors.rs index bd148102c..2d1b210a1 100644 --- a/mlx-rs/src/array/safetensors.rs +++ b/mlx-rs/src/array/safetensors.rs @@ -90,6 +90,10 @@ impl<'a> TryFrom<&'a Array> for TensorView<'a> { let bits: &[u16] = transmute(data); cast_slice(bits) }, + Dtype::Float64 => { + let data = value.as_slice::(); + cast_slice(data) + }, Dtype::Complex64 => return Err(ConversionError::MlxDtype(Dtype::Complex64)), } };