-
Notifications
You must be signed in to change notification settings - Fork 34
Expand file tree
/
Copy pathops.rs
More file actions
30 lines (25 loc) · 784 Bytes
/
ops.rs
File metadata and controls
30 lines (25 loc) · 784 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
use crate::array::shape::Shape;
use crate::array::{kind, wrapper, MLXArray};
use crate::stream::StreamOrDevice;
impl<E: kind::Element, const D: usize> MLXArray<E, D> {
    /// Builds a `D`-dimensional array of element type `E` whose elements are zero.
    ///
    /// `shape` is anything convertible into a `Shape<D>`; for example, the
    /// unit test passes an array literal such as `[2, 3]`. The resulting
    /// dimensions and the element kind tag `E::KIND` are passed to
    /// `wrapper::Array::zeros`. `stream` is forwarded to it unchanged; it
    /// presumably selects the stream/device the op runs on (confirm against
    /// `StreamOrDevice`).
    ///
    /// NOTE(review): the unit test calls `eval()` before reading the data,
    /// which suggests the returned array may be computed lazily. Confirm this
    /// against the wrapper's documentation.
    pub fn zeros<S: Into<Shape<D>>>(shape: S, stream: StreamOrDevice) -> Self {
        // Only the dimensions of the converted shape are needed by the wrapper.
        let target: Shape<D> = shape.into();

        Self {
            tensor: wrapper::Array::zeros(&target.dims, E::KIND, stream),
            phantom: std::marker::PhantomData,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::stream::StreamOrDevice;

    /// A 2x3 `f32` array from `MLXArray::zeros` must read back as six `0.0` values.
    #[test]
    fn test_zeros() {
        let stream = StreamOrDevice::default();
        let mut zeros_2x3: MLXArray<f32, 2> = MLXArray::zeros([2, 3], stream);
        zeros_2x3.eval();

        let values = zeros_2x3.as_slice().unwrap();
        // Same check as comparing against `[0.0; 6]`: exact length, then every element zero.
        assert_eq!(values.len(), 6);
        assert!(values.iter().all(|&v| v == 0.0));
    }
}