Skip to content

Commit 877a1a7

Browse files
authored
Refactor Tensor methods to use &[usize] instead of &Vec<usize> to reduce unnecessary overhead (#19)
* Refactor some parameters and return types of `Tensor` to use slices instead of `Vec`s to reduce unnecessary overhead
* Refactor remaining parameters to use slices instead of `Vec`s to reduce unnecessary overhead
* Update README to reflect the changes made in the refactor
1 parent ff788d4 commit 877a1a7

File tree

10 files changed

+94
-84
lines changed

10 files changed

+94
-84
lines changed

tensor/README.md

Lines changed: 23 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -35,61 +35,60 @@ use tensor::Tensor;
3535
let data: Vec<f32> = (0..length).map(|v| v as f32 + 10.0).collect();
3636

3737
// Create a Tensor, ensuring data.len() matches the product of shape
38-
let a = Tensor::new(shape.clone(), data).unwrap();
38+
let a = Tensor::new(shape, data).unwrap();
3939
```
4040
2. **Zeros and Ones**
4141
```rust
4242
// 4x2 filled with zeros
43-
let zeros_tensor = Tensor::zeros(vec![4, 2]);
43+
let zeros_tensor = Tensor::zeros(&[4, 2]);
4444

4545
// 1x9x2x5 filled with ones
46-
let ones_tensor = Tensor::ones(vec![1, 9, 2, 5]);
46+
let ones_tensor = Tensor::ones(&[1, 9, 2, 5]);
4747
```
4848

4949
### 2. Indexing and Accessing Data
5050

5151
- **Indexing** via `Index<&[usize]>`:
5252
```rust
53-
let t = Tensor::ones(vec![2, 3]);
54-
// t[vec![row, col]]
55-
assert_eq!(t[vec![0, 2]], 1.0_f32);
53+
let t = Tensor::ones(&[2, 3]);
54+
// t[&[row, col]]
55+
assert_eq!(t[&[0, 2]], 1.0_f32);
5656
```
5757
- **Check `.shape()` and `.strides()`**:
5858
```rust
59-
println!("Shape = {:?}", t.shape()); // e.g. [2, 3]
60-
println!("Strides = {:?}", t.strides()); // e.g. [3, 1]
59+
println!("Shape = {:?}", t.shape()); // [2, 3]
60+
println!("Strides = {:?}", t.strides()); // [3, 1]
6161
```
6262
- **Access raw data**:
6363
```rust
64-
let data_ref: &Vec<f32> = t.data();
65-
// or get mutable reference with t.data_mut()
64+
let data_ref: &[f32] = t.data();
6665
```
6766

6867
### 3. Movement Ops: Reshaping, Permuting, Flattening, Transposing
6968

7069
1. **Reshape**
7170
```rust
72-
let mut a = Tensor::ones(vec![4, 2]);
73-
a.reshape(vec![2, 2, 2]).unwrap();
71+
let mut a = Tensor::ones(&[4, 2]);
72+
a.reshape(&[2, 2, 2]).unwrap();
7473
// shape is now [2, 2, 2]
7574
```
7675
2. **Permute** (change dimension ordering)
7776
```rust
78-
let mut a = Tensor::ones(vec![1, 4, 2]);
77+
let mut a = Tensor::ones(&[1, 4, 2]);
7978
// reorder dimensions to [1->4, 4->2, 2->1]
80-
a.permute(vec![1, 2, 0]).unwrap();
79+
a.permute(&[1, 2, 0]).unwrap();
8180
// shape is now [4, 2, 1]
82-
// strides updated accordingly
81+
// strides are now [2, 1, 8] as the data was not changed
8382
```
8483
3. **Flatten**
8584
```rust
86-
let mut a = Tensor::ones(vec![7, 6]);
85+
let mut a = Tensor::ones(&[7, 6]);
8786
a.flatten();
8887
// shape becomes [42], strides is [1]
8988
```
9089
4. **Transpose**
9190
```rust
92-
let mut a = Tensor::new(vec![2, 3], vec![1., 2., 3., 4., 5., 6.]).unwrap();
91+
let mut a = Tensor::new(&[2, 3], vec![1., 2., 3., 4., 5., 6.]).unwrap();
9392
a.transpose().unwrap();
9493
// shape is now [3, 2], data is reordered
9594
```
@@ -110,15 +109,14 @@ They also have matching **trait operators**:
110109
#### a) Simple Elementwise
111110

112111
```rust
113-
let a = Tensor::ones(vec![2, 3]);
114-
let b = Tensor::ones(vec![2, 3]);
112+
let a = Tensor::ones(&[2, 3]);
113+
let b = Tensor::ones(&[2, 3]);
115114

116115
// Out-of-place
117116
let c = a.add(&b).unwrap(); // or &a + &b
118-
// c has all 2.0 values
119117

120118
// In-place
121-
let mut d = Tensor::zeros(vec![2, 3]);
119+
let mut d = Tensor::zeros(&[2, 3]);
122120
d.add_inplace(&c);
123121
```
124122

@@ -164,13 +162,13 @@ let c = a.matmul(&b).unwrap();
164162
2. **Sum Along a Dimension**
165163
```rust
166164
let reduced = a.sum_dim(1).unwrap();
167-
// For shape [2, 3], sum_dim(1) -> shape [2], data is sum across columns
168-
// e.g. [6., 15.]
165+
// For shape [2, 3], sum_dim(1) -> shape [2],
166+
// data is sum across columns [6., 15.]
169167
```
170168
3. **Mean** and **Mean Along Dimension**
171169
```rust
172-
let m = a.mean(); // overall mean
173-
let m_dim = a.mean_dim(0); // per-row or per-col mean
170+
let m = a.mean();
171+
let m_dim = a.mean_dim(0);
174172
```
175173

176174
### 6. Unary Ops

tensor/src/core/mod.rs

Lines changed: 28 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -13,42 +13,54 @@ pub struct Tensor {
1313
}
1414

1515
impl Tensor {
16-
pub fn new(shape: Vec<usize>, data: Vec<f32>) -> Result<Self, TensorError> {
17-
let length: usize = shape.iter().product();
16+
pub fn new<S>(shape: S, data: Vec<f32>) -> Result<Self, TensorError>
17+
where
18+
S: Into<Vec<usize>>,
19+
{
20+
let shape_vec = shape.into();
21+
let length: usize = shape_vec.iter().product();
1822
if data.len() != length {
1923
return Err(TensorError::CreationError(
2024
"Data does not fit within shape".to_string(),
2125
));
2226
}
23-
let strides: Vec<usize> = calculate_strides(&shape);
27+
let strides: Vec<usize> = calculate_strides(&shape_vec);
2428
Ok(Tensor {
25-
shape: shape.to_vec(),
29+
shape: shape_vec,
2630
strides,
2731
data,
2832
})
2933
}
3034

31-
pub fn zeros(shape: Vec<usize>) -> Self {
32-
let num_elements: usize = shape.iter().product();
33-
let strides: Vec<usize> = calculate_strides(&shape);
35+
pub fn zeros<S>(shape: S) -> Self
36+
where
37+
S: Into<Vec<usize>>,
38+
{
39+
let shape_vec = shape.into();
40+
let num_elements: usize = shape_vec.iter().product();
41+
let strides: Vec<usize> = calculate_strides(&shape_vec);
3442
Tensor {
35-
shape: shape.to_vec(),
43+
shape: shape_vec,
3644
strides,
3745
data: vec![0.0; num_elements],
3846
}
3947
}
4048

41-
pub fn ones(shape: Vec<usize>) -> Self {
42-
let num_elements: usize = shape.iter().product();
43-
let strides: Vec<usize> = calculate_strides(&shape);
49+
pub fn ones<S>(shape: S) -> Self
50+
where
51+
S: Into<Vec<usize>>,
52+
{
53+
let shape_vec = shape.into();
54+
let num_elements: usize = shape_vec.iter().product();
55+
let strides: Vec<usize> = calculate_strides(&shape_vec);
4456
Tensor {
45-
shape: shape.to_vec(),
57+
shape: shape_vec,
4658
strides,
4759
data: vec![1.0; num_elements],
4860
}
4961
}
5062

51-
pub fn get(&self, indices: Vec<usize>) -> Option<&f32> {
63+
pub fn get(&self, indices: &[usize]) -> Option<&f32> {
5264
if indices.len() != self.shape.len() {
5365
return None;
5466
}
@@ -63,15 +75,15 @@ impl Tensor {
6375
self.data.get(idx)
6476
}
6577

66-
pub fn shape(&self) -> &Vec<usize> {
78+
pub fn shape(&self) -> &[usize] {
6779
&self.shape
6880
}
6981

70-
pub fn strides(&self) -> &Vec<usize> {
82+
pub fn strides(&self) -> &[usize] {
7183
&self.strides
7284
}
7385

74-
pub fn data(&self) -> &Vec<f32> {
86+
pub fn data(&self) -> &[f32] {
7587
&self.data
7688
}
7789

tensor/src/core/ops/binary.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ impl Tensor {
2323
.zip(other.data.iter())
2424
.map(|(a, b)| op(*a, *b))
2525
.collect();
26-
return Tensor::new(self_shape.clone(), result_data);
26+
return Tensor::new(self_shape, result_data);
2727
}
2828

2929
let (bc_shape, self_bc_strides, other_bc_strides) =
@@ -102,8 +102,8 @@ impl Tensor {
102102
}
103103

104104
pub fn matmul(&self, other: &Tensor) -> Result<Tensor, TensorError> {
105-
let lhs_shape: &Vec<usize> = self.shape();
106-
let rhs_shape: &Vec<usize> = other.shape();
105+
let lhs_shape = self.shape();
106+
let rhs_shape = other.shape();
107107
if lhs_shape.len() != 2 || rhs_shape.len() != 2 {
108108
return Err(TensorError::BroadcastError(
109109
"matmul requires 2D tensors".to_string(),
@@ -118,8 +118,8 @@ impl Tensor {
118118
));
119119
}
120120

121-
let lhs_data: &Vec<f32> = self.data();
122-
let rhs_data: &Vec<f32> = other.data();
121+
let lhs_data = self.data();
122+
let rhs_data = other.data();
123123
let mut result_data: Vec<f32> = vec![0.0_f32; rows_left * cols_right];
124124
for i in 0..rows_left {
125125
for k in 0..cols_left {

tensor/src/core/ops/broadcast.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
pub fn is_broadcastable(a: &Vec<usize>, b: &Vec<usize>) -> bool {
1+
pub fn is_broadcastable(a_shape: &[usize], b_shape: &[usize]) -> bool {
22
// This is based on NumPy's rules: https://numpy.org/doc/stable/user/basics.broadcasting.html
3-
for (i, j) in a.iter().rev().zip(b.iter().rev()) {
3+
for (i, j) in a_shape.iter().rev().zip(b_shape.iter().rev()) {
44
if *i == 1 || *j == 1 {
55
continue;
66
}
@@ -12,8 +12,8 @@ pub fn is_broadcastable(a: &Vec<usize>, b: &Vec<usize>) -> bool {
1212
}
1313

1414
pub fn compute_broadcast_shape_and_strides(
15-
a_shape: &Vec<usize>,
16-
b_shape: &Vec<usize>,
15+
a_shape: &[usize],
16+
b_shape: &[usize],
1717
) -> (Vec<usize>, Vec<usize>, Vec<usize>) {
1818
let ndims = a_shape.len().max(b_shape.len());
1919
let mut a_bc_strides = vec![1; ndims];

tensor/src/core/ops/movement.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,20 @@ use crate::core::{calculate_strides, TensorError};
22
use crate::Tensor;
33

44
impl Tensor {
5-
pub fn reshape(&mut self, shape: Vec<usize>) -> Result<(), TensorError> {
5+
pub fn reshape(&mut self, shape: &[usize]) -> Result<(), TensorError> {
66
let new_length: usize = shape.iter().product();
77
let current_length: usize = self.shape.iter().product();
88
if new_length != current_length {
99
return Err(TensorError::CreationError(
1010
"The new shape does not align with the size of the data.".to_string(),
1111
));
1212
}
13-
self.strides = calculate_strides(&shape);
13+
self.strides = calculate_strides(shape);
1414
self.shape = shape.to_vec();
1515
Ok(())
1616
}
1717

18-
pub fn permute(&mut self, order: Vec<usize>) -> Result<(), TensorError> {
18+
pub fn permute(&mut self, order: &[usize]) -> Result<(), TensorError> {
1919
if order.len() != self.shape.len() {
2020
return Err(TensorError::CreationError(
2121
"The permutation does not align with the current shape.".to_string(),

tensor/src/core/ops/reduce.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ impl Tensor {
1818
));
1919
}
2020

21-
let mut result_shape = self_shape.clone();
21+
let mut result_shape = self_shape.to_vec();
2222
let dim_size = result_shape.remove(dim);
2323

2424
let result_size: usize = result_shape.iter().product();

tensor/src/core/ops/unary.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ impl Tensor {
66
F: Fn(f32) -> f32,
77
{
88
let result_data: Vec<f32> = self.data().iter().map(|x| op(*x)).collect();
9-
return Tensor::new(self.shape().clone(), result_data).unwrap();
9+
return Tensor::new(self.shape(), result_data).unwrap();
1010
}
1111

1212
pub fn exp(&self) -> Tensor {

tensor/src/core/traits.rs

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,9 @@ impl fmt::Display for Tensor {
2929
}
3030
}
3131

32-
impl Index<Vec<usize>> for Tensor {
32+
impl Index<&[usize]> for Tensor {
3333
type Output = f32;
34-
fn index(&self, indices: Vec<usize>) -> &Self::Output {
34+
fn index(&self, indices: &[usize]) -> &Self::Output {
3535
self.get(indices).expect("Index out of bounds")
3636
}
3737
}
@@ -52,7 +52,7 @@ impl Add<f32> for &Tensor {
5252

5353
fn add(self, rhs: f32) -> Self::Output {
5454
let result_data: Vec<f32> = self.data().iter().map(|&x| x + rhs).collect();
55-
Tensor::new(self.shape().clone(), result_data).unwrap()
55+
Tensor::new(self.shape(), result_data).unwrap()
5656
}
5757
}
5858

@@ -61,7 +61,7 @@ impl Add<&Tensor> for f32 {
6161

6262
fn add(self, rhs: &Tensor) -> Self::Output {
6363
let result_data: Vec<f32> = rhs.data().iter().map(|&x| x + self).collect();
64-
Tensor::new(rhs.shape().clone(), result_data).unwrap()
64+
Tensor::new(rhs.shape(), result_data).unwrap()
6565
}
6666
}
6767

@@ -103,7 +103,7 @@ impl Sub<f32> for &Tensor {
103103

104104
fn sub(self, rhs: f32) -> Self::Output {
105105
let result_data: Vec<f32> = self.data().iter().map(|&x| x - rhs).collect();
106-
Tensor::new(self.shape().clone(), result_data).unwrap()
106+
Tensor::new(self.shape(), result_data).unwrap()
107107
}
108108
}
109109

@@ -112,7 +112,7 @@ impl Sub<&Tensor> for f32 {
112112

113113
fn sub(self, rhs: &Tensor) -> Self::Output {
114114
let result_data: Vec<f32> = rhs.data().iter().map(|&x| x - self).collect();
115-
Tensor::new(rhs.shape().clone(), result_data).unwrap()
115+
Tensor::new(rhs.shape(), result_data).unwrap()
116116
}
117117
}
118118

@@ -154,7 +154,7 @@ impl Mul<f32> for &Tensor {
154154

155155
fn mul(self, rhs: f32) -> Self::Output {
156156
let result_data: Vec<f32> = self.data().iter().map(|&x| x * rhs).collect();
157-
Tensor::new(self.shape().clone(), result_data).unwrap()
157+
Tensor::new(self.shape(), result_data).unwrap()
158158
}
159159
}
160160

@@ -163,7 +163,7 @@ impl Mul<&Tensor> for f32 {
163163

164164
fn mul(self, rhs: &Tensor) -> Self::Output {
165165
let result_data: Vec<f32> = rhs.data().iter().map(|&x| x * self).collect();
166-
Tensor::new(rhs.shape().clone(), result_data).unwrap()
166+
Tensor::new(rhs.shape(), result_data).unwrap()
167167
}
168168
}
169169

@@ -205,7 +205,7 @@ impl Div<f32> for &Tensor {
205205

206206
fn div(self, rhs: f32) -> Self::Output {
207207
let result_data: Vec<f32> = self.data().iter().map(|&x| x / rhs).collect();
208-
Tensor::new(self.shape().clone(), result_data).unwrap()
208+
Tensor::new(self.shape(), result_data).unwrap()
209209
}
210210
}
211211

tensor/src/core/utils.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use std::fmt;
22

3-
pub fn calculate_strides(shape: &Vec<usize>) -> Vec<usize> {
3+
pub fn calculate_strides(shape: &[usize]) -> Vec<usize> {
44
let length: usize = shape.len();
55
let mut strides = vec![1; length];
66
strides.iter_mut().enumerate().for_each(|(i, stride)| {

0 commit comments

Comments
 (0)