Skip to content

Commit 5d68c8a

Browse files
committed
refactor binary ops to use a generic function and a closure. add tensor error types
1 parent 152a150 commit 5d68c8a

File tree

1 file changed

+76
-153
lines changed

1 file changed

+76
-153
lines changed

tensor/src/lib.rs

Lines changed: 76 additions & 153 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,34 @@ pub struct Tensor {
88
data: Vec<f32>,
99
}
1010

11+
/// Errors produced by tensor operations, grouped by the category of
/// operation that failed.
///
/// `Clone`/`PartialEq`/`Eq` are derived so callers (and tests) can
/// compare errors directly instead of matching on `Debug` strings.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TensorError {
    /// Two tensor shapes could not be broadcast together.
    BroadcastError(String),
    /// A tensor could not be constructed, e.g. data length does not
    /// match the product of the shape's dimensions.
    CreationError(String),
    /// A movement op received invalid arguments.
    /// NOTE(review): this variant is never constructed by the visible
    /// movement ops (they return `CreationError`) — confirm intent.
    MovementError(String),
    /// A dimension or element index was out of range.
    IndexError(String),
}

impl std::fmt::Display for TensorError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Prefix each message with its category so a printed error
        // tells the reader which stage of the computation failed.
        match self {
            TensorError::CreationError(msg) => write!(f, "Creation error: {}", msg),
            TensorError::BroadcastError(msg) => write!(f, "Broadcast error: {}", msg),
            TensorError::MovementError(msg) => write!(f, "Movement error: {}", msg),
            TensorError::IndexError(msg) => write!(f, "Index error: {}", msg),
        }
    }
}

// Marker impl: `Display` + `Debug` above satisfy the trait's defaults.
impl std::error::Error for TensorError {}
31+
1132
impl Tensor {
12-
pub fn new(shape: Vec<usize>, data: Vec<f32>) -> Result<Self, &'static str> {
33+
pub fn new(shape: Vec<usize>, data: Vec<f32>) -> Result<Tensor, TensorError> {
1334
let length: usize = shape.iter().product();
1435
if data.len() != length {
15-
return Err("Data does not fit within shape");
36+
return Err(TensorError::CreationError(
37+
"Data does not fit within shape".to_string(),
38+
));
1639
}
1740
let strides: Vec<usize> = Self::calculate_strides(&shape);
1841
Ok(Tensor {
@@ -54,26 +77,32 @@ impl Tensor {
5477

5578
/* MOVEMENT OPS */
5679

57-
pub fn reshape(&mut self, shape: Vec<usize>) -> Result<(), &'static str> {
80+
pub fn reshape(&mut self, shape: Vec<usize>) -> Result<(), TensorError> {
5881
let new_length: usize = shape.iter().product();
5982
let current_length: usize = self.shape.iter().product();
6083
if new_length != current_length {
61-
return Err("The new shape does not align with the size of the data.");
84+
return Err(TensorError::CreationError(
85+
"The new shape does not align with the size of the data.".to_string(),
86+
));
6287
}
6388
self.strides = Self::calculate_strides(&shape);
6489
self.shape = shape.to_vec();
6590
Ok(())
6691
}
6792

68-
pub fn permute(&mut self, order: Vec<usize>) -> Result<(), &'static str> {
93+
pub fn permute(&mut self, order: Vec<usize>) -> Result<(), TensorError> {
6994
if order.len() != self.shape.len() {
70-
return Err("The permutation does not align with the current shape.");
95+
return Err(TensorError::CreationError(
96+
"The permutation does not align with the current shape.".to_string(),
97+
));
7198
}
7299

73100
let mut sorted_order: Vec<usize> = order.to_vec();
74101
sorted_order.sort();
75102
if sorted_order != (0..self.shape.len()).collect::<Vec<_>>() {
76-
return Err("Index out of range for shape.");
103+
return Err(TensorError::CreationError(
104+
"Index out of range for shape.".to_string(),
105+
));
77106
}
78107

79108
let new_shape: Vec<usize> = order.iter().map(|&i| self.shape[i]).collect();
@@ -89,9 +118,11 @@ impl Tensor {
89118
self.strides = vec![1];
90119
}
91120

92-
pub fn transpose(&mut self) -> Result<(), &'static str> {
121+
pub fn transpose(&mut self) -> Result<(), TensorError> {
93122
if self.shape.len() != 2 {
94-
return Err("transpose only supports 2D tensors currently.");
123+
return Err(TensorError::CreationError(
124+
"transpose only supports 2D tensors currently.".to_string(),
125+
));
95126
}
96127

97128
let (m, n) = (self.shape[0], self.shape[1]);
@@ -116,12 +147,14 @@ impl Tensor {
116147
Tensor::new(vec![1], vec![sum]).unwrap()
117148
}
118149

119-
pub fn sum_dim(&self, dim: usize) -> Result<Tensor, &'static str> {
150+
pub fn sum_dim(&self, dim: usize) -> Result<Tensor, TensorError> {
120151
let self_data = self.data();
121152
let self_shape = self.shape();
122153
let self_strides = self.strides();
123154
if self_shape.len() < dim {
124-
return Err("Dimension out of range for the tensor");
155+
return Err(TensorError::IndexError(
156+
"Dimension out of range for the tensor".to_string(),
157+
));
125158
}
126159

127160
let mut result_shape = self_shape.clone();
@@ -161,9 +194,11 @@ impl Tensor {
161194
&sum / self.shape().iter().product::<usize>() as f32
162195
}
163196

164-
pub fn mean_dim(&self, dim: usize) -> Result<Tensor, &'static str> {
197+
pub fn mean_dim(&self, dim: usize) -> Result<Tensor, TensorError> {
165198
if self.shape().len() < dim {
166-
return Err("Dimension out of range for the tensor");
199+
return Err(TensorError::IndexError(
200+
"Dimension out of range for the tensor".to_string(),
201+
));
167202
}
168203
let sum: Tensor = self.sum_dim(dim).unwrap();
169204
Ok(&sum / self.shape()[dim] as f32)
@@ -204,19 +239,24 @@ impl Tensor {
204239

205240
/* BINARY OPS */
206241

207-
pub fn add(&self, other: &Tensor) -> Result<Tensor, &'static str> {
242+
fn binary_op<F>(&self, other: &Tensor, op: F) -> Result<Tensor, TensorError>
243+
where
244+
F: Fn(f32, f32) -> f32,
245+
{
208246
let self_shape = self.shape();
209247
let other_shape = other.shape();
210248
if !is_broadcastable(self_shape, other_shape) {
211-
return Err("The tensor shapes are not compatible for addition.");
249+
return Err(TensorError::BroadcastError(
250+
"Shapes are not compatible for the operation".to_string(),
251+
));
212252
}
213253

214254
if self_shape == other_shape {
215255
let result_data: Vec<f32> = self
216256
.data
217257
.iter()
218258
.zip(other.data.iter())
219-
.map(|(a, b)| a + b)
259+
.map(|(a, b)| op(*a, *b))
220260
.collect();
221261
return Tensor::new(self_shape.clone(), result_data);
222262
}
@@ -226,29 +266,29 @@ impl Tensor {
226266

227267
let self_data = self.data();
228268
let other_data = other.data();
229-
230269
let result_size: usize = bc_shape.iter().product();
231270
let mut result_data: Vec<f32> = Vec::with_capacity(result_size);
232271

233272
for i in 0..result_size {
234273
let multi_idx = unravel_index(i, &bc_shape);
235-
236274
let mut self_offset = 0;
275+
let mut other_offset = 0;
276+
237277
for (dim_i, &stride) in self_bc_strides.iter().enumerate() {
238278
self_offset += multi_idx[dim_i] * stride;
239279
}
240-
241-
let mut other_offset = 0;
242280
for (dim_i, &stride) in other_bc_strides.iter().enumerate() {
243281
other_offset += multi_idx[dim_i] * stride;
244282
}
245-
246-
let val = self_data[self_offset] + other_data[other_offset];
247-
result_data.push(val);
283+
result_data.push(op(self_data[self_offset], other_data[other_offset]));
248284
}
249285
Tensor::new(bc_shape, result_data)
250286
}
251287

288+
/// Element-wise addition of two tensors, with broadcasting.
///
/// # Errors
///
/// Returns `TensorError::BroadcastError` when the two shapes are not
/// broadcast-compatible.
pub fn add(&self, other: &Tensor) -> Result<Tensor, TensorError> {
    // Delegate to the shared broadcasting machinery, using the `Add`
    // trait's method directly as the element-wise combinator.
    self.binary_op(other, <f32 as std::ops::Add>::add)
}
291+
252292
pub fn add_inplace(&mut self, other: &Tensor) {
253293
let self_shape = self.shape();
254294
let other_shape = other.shape();
@@ -263,49 +303,8 @@ impl Tensor {
263303
});
264304
}
265305

266-
pub fn sub(&self, other: &Tensor) -> Result<Tensor, &'static str> {
267-
let self_shape = self.shape();
268-
let other_shape = other.shape();
269-
if !is_broadcastable(self_shape, other_shape) {
270-
return Err("The tensor shapes are not compatible for subtraction.");
271-
}
272-
273-
if self_shape == other_shape {
274-
let result_data: Vec<f32> = self
275-
.data
276-
.iter()
277-
.zip(other.data.iter())
278-
.map(|(a, b)| a - b)
279-
.collect();
280-
return Tensor::new(self_shape.clone(), result_data);
281-
}
282-
283-
let (bc_shape, self_bc_strides, other_bc_strides) =
284-
compute_broadcast_shape_and_strides(self_shape, other_shape);
285-
286-
let self_data = self.data();
287-
let other_data = other.data();
288-
289-
let result_size: usize = bc_shape.iter().product();
290-
let mut result_data: Vec<f32> = Vec::with_capacity(result_size);
291-
292-
for i in 0..result_size {
293-
let multi_idx = unravel_index(i, &bc_shape);
294-
295-
let mut self_offset = 0;
296-
for (dim_i, &stride) in self_bc_strides.iter().enumerate() {
297-
self_offset += multi_idx[dim_i] * stride;
298-
}
299-
300-
let mut other_offset = 0;
301-
for (dim_i, &stride) in other_bc_strides.iter().enumerate() {
302-
other_offset += multi_idx[dim_i] * stride;
303-
}
304-
305-
let val = self_data[self_offset] - other_data[other_offset];
306-
result_data.push(val);
307-
}
308-
Tensor::new(bc_shape, result_data)
306+
/// Element-wise subtraction of two tensors, with broadcasting.
///
/// # Errors
///
/// Returns `TensorError::BroadcastError` when the two shapes are not
/// broadcast-compatible.
pub fn sub(&self, other: &Tensor) -> Result<Tensor, TensorError> {
    // Delegate to the shared broadcasting machinery, using the `Sub`
    // trait's method directly as the element-wise combinator.
    self.binary_op(other, <f32 as std::ops::Sub>::sub)
}
310309

311310
pub fn sub_inplace(&mut self, other: &Tensor) {
@@ -322,48 +321,8 @@ impl Tensor {
322321
});
323322
}
324323

325-
pub fn mul(&self, other: &Tensor) -> Result<Tensor, &'static str> {
326-
let self_shape = self.shape();
327-
let other_shape = other.shape();
328-
if !is_broadcastable(self_shape, other_shape) {
329-
return Err("The tensor shapes are not compatible for multiplication.");
330-
}
331-
332-
if self_shape == other_shape {
333-
let result_data: Vec<f32> = self
334-
.data
335-
.iter()
336-
.zip(other.data.iter())
337-
.map(|(a, b)| a * b)
338-
.collect();
339-
return Tensor::new(self_shape.clone(), result_data);
340-
}
341-
342-
let (bc_shape, self_bc_strides, other_bc_strides) =
343-
compute_broadcast_shape_and_strides(self_shape, other_shape);
344-
345-
let result_size = bc_shape.iter().product();
346-
let mut result_data: Vec<f32> = Vec::with_capacity(result_size);
347-
let self_data = self.data();
348-
let other_data = other.data();
349-
350-
for i in 0..result_size {
351-
let multi_idx = unravel_index(i, &bc_shape);
352-
353-
let mut self_offset = 0;
354-
for (dim_i, &stride) in self_bc_strides.iter().enumerate() {
355-
self_offset += multi_idx[dim_i] * stride;
356-
}
357-
358-
let mut other_offset = 0;
359-
for (dim_i, &stride) in other_bc_strides.iter().enumerate() {
360-
other_offset += multi_idx[dim_i] * stride;
361-
}
362-
363-
let val = self_data[self_offset] * other_data[other_offset];
364-
result_data.push(val);
365-
}
366-
Tensor::new(bc_shape, result_data)
324+
/// Element-wise multiplication of two tensors, with broadcasting.
///
/// # Errors
///
/// Returns `TensorError::BroadcastError` when the two shapes are not
/// broadcast-compatible.
pub fn mul(&self, other: &Tensor) -> Result<Tensor, TensorError> {
    // Delegate to the shared broadcasting machinery, using the `Mul`
    // trait's method directly as the element-wise combinator.
    self.binary_op(other, <f32 as std::ops::Mul>::mul)
}
368327

369328
pub fn mul_inplace(&mut self, other: &Tensor) {
@@ -380,48 +339,8 @@ impl Tensor {
380339
});
381340
}
382341

383-
pub fn div(&self, other: &Tensor) -> Result<Tensor, &'static str> {
384-
let self_shape = self.shape();
385-
let other_shape = other.shape();
386-
if !is_broadcastable(self_shape, other_shape) {
387-
return Err("The tensor shapes are not compatible for division.");
388-
}
389-
390-
if self_shape == other_shape {
391-
let result_data: Vec<f32> = self
392-
.data
393-
.iter()
394-
.zip(other.data.iter())
395-
.map(|(a, b)| a / b)
396-
.collect();
397-
return Tensor::new(self_shape.clone(), result_data);
398-
}
399-
400-
let (bc_shape, self_bc_strides, other_bc_strides) =
401-
compute_broadcast_shape_and_strides(self_shape, other_shape);
402-
403-
let result_size = bc_shape.iter().product();
404-
let mut result_data: Vec<f32> = Vec::with_capacity(result_size);
405-
let self_data = self.data();
406-
let other_data = other.data();
407-
408-
for i in 0..result_size {
409-
let multi_idx = unravel_index(i, &bc_shape);
410-
411-
let mut self_offset = 0;
412-
for (dim_i, &stride) in self_bc_strides.iter().enumerate() {
413-
self_offset += multi_idx[dim_i] * stride;
414-
}
415-
416-
let mut other_offset = 0;
417-
for (dim_i, &stride) in other_bc_strides.iter().enumerate() {
418-
other_offset += multi_idx[dim_i] * stride;
419-
}
420-
421-
let val = self_data[self_offset] / other_data[other_offset];
422-
result_data.push(val);
423-
}
424-
Tensor::new(bc_shape, result_data)
342+
/// Element-wise division of two tensors, with broadcasting.
/// Division by zero follows IEEE-754 `f32` semantics (inf/NaN), as in
/// the original implementation.
///
/// # Errors
///
/// Returns `TensorError::BroadcastError` when the two shapes are not
/// broadcast-compatible.
pub fn div(&self, other: &Tensor) -> Result<Tensor, TensorError> {
    // Delegate to the shared broadcasting machinery, using the `Div`
    // trait's method directly as the element-wise combinator.
    self.binary_op(other, <f32 as std::ops::Div>::div)
}
426345

427346
pub fn div_inplace(&mut self, other: &Tensor) {
@@ -438,17 +357,21 @@ impl Tensor {
438357
});
439358
}
440359

441-
pub fn matmul(&self, other: &Tensor) -> Result<Tensor, &'static str> {
360+
pub fn matmul(&self, other: &Tensor) -> Result<Tensor, TensorError> {
442361
let lhs_shape: &Vec<usize> = self.shape();
443362
let rhs_shape: &Vec<usize> = other.shape();
444363
if lhs_shape.len() != 2 || rhs_shape.len() != 2 {
445-
return Err("matmul requires 2D tensors");
364+
return Err(TensorError::BroadcastError(
365+
"matmul requires 2D tensors".to_string(),
366+
));
446367
}
447368

448369
let (rows_left, cols_left) = (lhs_shape[0], lhs_shape[1]);
449370
let (rows_right, cols_right) = (rhs_shape[0], rhs_shape[1]);
450371
if cols_left != rows_right {
451-
return Err("Incompatible shapes for matrix multiplication");
372+
return Err(TensorError::BroadcastError(
373+
"Incompatible shapes for matrix multiplication".to_string(),
374+
));
452375
}
453376

454377
let lhs_data: &Vec<f32> = self.data();

0 commit comments

Comments
 (0)