@@ -8,11 +8,34 @@ pub struct Tensor {
88 data : Vec < f32 > ,
99}
1010
11+ #[ derive( Debug ) ]
12+ pub enum TensorError {
13+ BroadcastError ( String ) ,
14+ CreationError ( String ) ,
15+ MovementError ( String ) ,
16+ IndexError ( String ) ,
17+ }
18+
19+ impl std:: fmt:: Display for TensorError {
20+ fn fmt ( & self , f : & mut std:: fmt:: Formatter < ' _ > ) -> std:: fmt:: Result {
21+ match self {
22+ TensorError :: CreationError ( msg) => write ! ( f, "Creation error: {}" , msg) ,
23+ TensorError :: BroadcastError ( msg) => write ! ( f, "Broadcast error: {}" , msg) ,
24+ TensorError :: MovementError ( msg) => write ! ( f, "Movement error: {}" , msg) ,
25+ TensorError :: IndexError ( msg) => write ! ( f, "Index error: {}" , msg) ,
26+ }
27+ }
28+ }
29+
30+ impl std:: error:: Error for TensorError { }
31+
1132impl Tensor {
12- pub fn new ( shape : Vec < usize > , data : Vec < f32 > ) -> Result < Self , & ' static str > {
33+ pub fn new ( shape : Vec < usize > , data : Vec < f32 > ) -> Result < Tensor , TensorError > {
1334 let length: usize = shape. iter ( ) . product ( ) ;
1435 if data. len ( ) != length {
15- return Err ( "Data does not fit within shape" ) ;
36+ return Err ( TensorError :: CreationError (
37+ "Data does not fit within shape" . to_string ( ) ,
38+ ) ) ;
1639 }
1740 let strides: Vec < usize > = Self :: calculate_strides ( & shape) ;
1841 Ok ( Tensor {
@@ -54,26 +77,32 @@ impl Tensor {
5477
5578 /* MOVEMENT OPS */
5679
57- pub fn reshape ( & mut self , shape : Vec < usize > ) -> Result < ( ) , & ' static str > {
80+ pub fn reshape ( & mut self , shape : Vec < usize > ) -> Result < ( ) , TensorError > {
5881 let new_length: usize = shape. iter ( ) . product ( ) ;
5982 let current_length: usize = self . shape . iter ( ) . product ( ) ;
6083 if new_length != current_length {
61- return Err ( "The new shape does not align with the size of the data." ) ;
84+ return Err ( TensorError :: CreationError (
85+ "The new shape does not align with the size of the data." . to_string ( ) ,
86+ ) ) ;
6287 }
6388 self . strides = Self :: calculate_strides ( & shape) ;
6489 self . shape = shape. to_vec ( ) ;
6590 Ok ( ( ) )
6691 }
6792
68- pub fn permute ( & mut self , order : Vec < usize > ) -> Result < ( ) , & ' static str > {
93+ pub fn permute ( & mut self , order : Vec < usize > ) -> Result < ( ) , TensorError > {
6994 if order. len ( ) != self . shape . len ( ) {
70- return Err ( "The permutation does not align with the current shape." ) ;
95+ return Err ( TensorError :: CreationError (
96+ "The permutation does not align with the current shape." . to_string ( ) ,
97+ ) ) ;
7198 }
7299
73100 let mut sorted_order: Vec < usize > = order. to_vec ( ) ;
74101 sorted_order. sort ( ) ;
75102 if sorted_order != ( 0 ..self . shape . len ( ) ) . collect :: < Vec < _ > > ( ) {
76- return Err ( "Index out of range for shape." ) ;
103+ return Err ( TensorError :: CreationError (
104+ "Index out of range for shape." . to_string ( ) ,
105+ ) ) ;
77106 }
78107
79108 let new_shape: Vec < usize > = order. iter ( ) . map ( |& i| self . shape [ i] ) . collect ( ) ;
@@ -89,9 +118,11 @@ impl Tensor {
89118 self . strides = vec ! [ 1 ] ;
90119 }
91120
92- pub fn transpose ( & mut self ) -> Result < ( ) , & ' static str > {
121+ pub fn transpose ( & mut self ) -> Result < ( ) , TensorError > {
93122 if self . shape . len ( ) != 2 {
94- return Err ( "transpose only supports 2D tensors currently." ) ;
123+ return Err ( TensorError :: CreationError (
124+ "transpose only supports 2D tensors currently." . to_string ( ) ,
125+ ) ) ;
95126 }
96127
97128 let ( m, n) = ( self . shape [ 0 ] , self . shape [ 1 ] ) ;
@@ -116,12 +147,14 @@ impl Tensor {
116147 Tensor :: new ( vec ! [ 1 ] , vec ! [ sum] ) . unwrap ( )
117148 }
118149
119- pub fn sum_dim ( & self , dim : usize ) -> Result < Tensor , & ' static str > {
150+ pub fn sum_dim ( & self , dim : usize ) -> Result < Tensor , TensorError > {
120151 let self_data = self . data ( ) ;
121152 let self_shape = self . shape ( ) ;
122153 let self_strides = self . strides ( ) ;
123154 if self_shape. len ( ) < dim {
124- return Err ( "Dimension out of range for the tensor" ) ;
155+ return Err ( TensorError :: IndexError (
156+ "Dimension out of range for the tensor" . to_string ( ) ,
157+ ) ) ;
125158 }
126159
127160 let mut result_shape = self_shape. clone ( ) ;
@@ -161,9 +194,11 @@ impl Tensor {
161194 & sum / self . shape ( ) . iter ( ) . product :: < usize > ( ) as f32
162195 }
163196
164- pub fn mean_dim ( & self , dim : usize ) -> Result < Tensor , & ' static str > {
197+ pub fn mean_dim ( & self , dim : usize ) -> Result < Tensor , TensorError > {
165198 if self . shape ( ) . len ( ) < dim {
166- return Err ( "Dimension out of range for the tensor" ) ;
199+ return Err ( TensorError :: IndexError (
200+ "Dimension out of range for the tensor" . to_string ( ) ,
201+ ) ) ;
167202 }
168203 let sum: Tensor = self . sum_dim ( dim) . unwrap ( ) ;
169204 Ok ( & sum / self . shape ( ) [ dim] as f32 )
@@ -204,19 +239,24 @@ impl Tensor {
204239
205240 /* BINARY OPS */
206241
207- pub fn add ( & self , other : & Tensor ) -> Result < Tensor , & ' static str > {
242+ fn binary_op < F > ( & self , other : & Tensor , op : F ) -> Result < Tensor , TensorError >
243+ where
244+ F : Fn ( f32 , f32 ) -> f32 ,
245+ {
208246 let self_shape = self . shape ( ) ;
209247 let other_shape = other. shape ( ) ;
210248 if !is_broadcastable ( self_shape, other_shape) {
211- return Err ( "The tensor shapes are not compatible for addition." ) ;
249+ return Err ( TensorError :: BroadcastError (
250+ "Shapes are not compatible for the operation" . to_string ( ) ,
251+ ) ) ;
212252 }
213253
214254 if self_shape == other_shape {
215255 let result_data: Vec < f32 > = self
216256 . data
217257 . iter ( )
218258 . zip ( other. data . iter ( ) )
219- . map ( |( a, b) | a + b )
259+ . map ( |( a, b) | op ( * a , * b ) )
220260 . collect ( ) ;
221261 return Tensor :: new ( self_shape. clone ( ) , result_data) ;
222262 }
@@ -226,29 +266,29 @@ impl Tensor {
226266
227267 let self_data = self . data ( ) ;
228268 let other_data = other. data ( ) ;
229-
230269 let result_size: usize = bc_shape. iter ( ) . product ( ) ;
231270 let mut result_data: Vec < f32 > = Vec :: with_capacity ( result_size) ;
232271
233272 for i in 0 ..result_size {
234273 let multi_idx = unravel_index ( i, & bc_shape) ;
235-
236274 let mut self_offset = 0 ;
275+ let mut other_offset = 0 ;
276+
237277 for ( dim_i, & stride) in self_bc_strides. iter ( ) . enumerate ( ) {
238278 self_offset += multi_idx[ dim_i] * stride;
239279 }
240-
241- let mut other_offset = 0 ;
242280 for ( dim_i, & stride) in other_bc_strides. iter ( ) . enumerate ( ) {
243281 other_offset += multi_idx[ dim_i] * stride;
244282 }
245-
246- let val = self_data[ self_offset] + other_data[ other_offset] ;
247- result_data. push ( val) ;
283+ result_data. push ( op ( self_data[ self_offset] , other_data[ other_offset] ) ) ;
248284 }
249285 Tensor :: new ( bc_shape, result_data)
250286 }
251287
288+ pub fn add ( & self , other : & Tensor ) -> Result < Tensor , TensorError > {
289+ self . binary_op ( other, |a, b| a + b)
290+ }
291+
252292 pub fn add_inplace ( & mut self , other : & Tensor ) {
253293 let self_shape = self . shape ( ) ;
254294 let other_shape = other. shape ( ) ;
@@ -263,49 +303,8 @@ impl Tensor {
263303 } ) ;
264304 }
265305
266- pub fn sub ( & self , other : & Tensor ) -> Result < Tensor , & ' static str > {
267- let self_shape = self . shape ( ) ;
268- let other_shape = other. shape ( ) ;
269- if !is_broadcastable ( self_shape, other_shape) {
270- return Err ( "The tensor shapes are not compatible for subtraction." ) ;
271- }
272-
273- if self_shape == other_shape {
274- let result_data: Vec < f32 > = self
275- . data
276- . iter ( )
277- . zip ( other. data . iter ( ) )
278- . map ( |( a, b) | a - b)
279- . collect ( ) ;
280- return Tensor :: new ( self_shape. clone ( ) , result_data) ;
281- }
282-
283- let ( bc_shape, self_bc_strides, other_bc_strides) =
284- compute_broadcast_shape_and_strides ( self_shape, other_shape) ;
285-
286- let self_data = self . data ( ) ;
287- let other_data = other. data ( ) ;
288-
289- let result_size: usize = bc_shape. iter ( ) . product ( ) ;
290- let mut result_data: Vec < f32 > = Vec :: with_capacity ( result_size) ;
291-
292- for i in 0 ..result_size {
293- let multi_idx = unravel_index ( i, & bc_shape) ;
294-
295- let mut self_offset = 0 ;
296- for ( dim_i, & stride) in self_bc_strides. iter ( ) . enumerate ( ) {
297- self_offset += multi_idx[ dim_i] * stride;
298- }
299-
300- let mut other_offset = 0 ;
301- for ( dim_i, & stride) in other_bc_strides. iter ( ) . enumerate ( ) {
302- other_offset += multi_idx[ dim_i] * stride;
303- }
304-
305- let val = self_data[ self_offset] - other_data[ other_offset] ;
306- result_data. push ( val) ;
307- }
308- Tensor :: new ( bc_shape, result_data)
306+ pub fn sub ( & self , other : & Tensor ) -> Result < Tensor , TensorError > {
307+ self . binary_op ( other, |a, b| a - b)
309308 }
310309
311310 pub fn sub_inplace ( & mut self , other : & Tensor ) {
@@ -322,48 +321,8 @@ impl Tensor {
322321 } ) ;
323322 }
324323
325- pub fn mul ( & self , other : & Tensor ) -> Result < Tensor , & ' static str > {
326- let self_shape = self . shape ( ) ;
327- let other_shape = other. shape ( ) ;
328- if !is_broadcastable ( self_shape, other_shape) {
329- return Err ( "The tensor shapes are not compatible for multiplication." ) ;
330- }
331-
332- if self_shape == other_shape {
333- let result_data: Vec < f32 > = self
334- . data
335- . iter ( )
336- . zip ( other. data . iter ( ) )
337- . map ( |( a, b) | a * b)
338- . collect ( ) ;
339- return Tensor :: new ( self_shape. clone ( ) , result_data) ;
340- }
341-
342- let ( bc_shape, self_bc_strides, other_bc_strides) =
343- compute_broadcast_shape_and_strides ( self_shape, other_shape) ;
344-
345- let result_size = bc_shape. iter ( ) . product ( ) ;
346- let mut result_data: Vec < f32 > = Vec :: with_capacity ( result_size) ;
347- let self_data = self . data ( ) ;
348- let other_data = other. data ( ) ;
349-
350- for i in 0 ..result_size {
351- let multi_idx = unravel_index ( i, & bc_shape) ;
352-
353- let mut self_offset = 0 ;
354- for ( dim_i, & stride) in self_bc_strides. iter ( ) . enumerate ( ) {
355- self_offset += multi_idx[ dim_i] * stride;
356- }
357-
358- let mut other_offset = 0 ;
359- for ( dim_i, & stride) in other_bc_strides. iter ( ) . enumerate ( ) {
360- other_offset += multi_idx[ dim_i] * stride;
361- }
362-
363- let val = self_data[ self_offset] * other_data[ other_offset] ;
364- result_data. push ( val) ;
365- }
366- Tensor :: new ( bc_shape, result_data)
324+ pub fn mul ( & self , other : & Tensor ) -> Result < Tensor , TensorError > {
325+ self . binary_op ( other, |a, b| a * b)
367326 }
368327
369328 pub fn mul_inplace ( & mut self , other : & Tensor ) {
@@ -380,48 +339,8 @@ impl Tensor {
380339 } ) ;
381340 }
382341
383- pub fn div ( & self , other : & Tensor ) -> Result < Tensor , & ' static str > {
384- let self_shape = self . shape ( ) ;
385- let other_shape = other. shape ( ) ;
386- if !is_broadcastable ( self_shape, other_shape) {
387- return Err ( "The tensor shapes are not compatible for division." ) ;
388- }
389-
390- if self_shape == other_shape {
391- let result_data: Vec < f32 > = self
392- . data
393- . iter ( )
394- . zip ( other. data . iter ( ) )
395- . map ( |( a, b) | a / b)
396- . collect ( ) ;
397- return Tensor :: new ( self_shape. clone ( ) , result_data) ;
398- }
399-
400- let ( bc_shape, self_bc_strides, other_bc_strides) =
401- compute_broadcast_shape_and_strides ( self_shape, other_shape) ;
402-
403- let result_size = bc_shape. iter ( ) . product ( ) ;
404- let mut result_data: Vec < f32 > = Vec :: with_capacity ( result_size) ;
405- let self_data = self . data ( ) ;
406- let other_data = other. data ( ) ;
407-
408- for i in 0 ..result_size {
409- let multi_idx = unravel_index ( i, & bc_shape) ;
410-
411- let mut self_offset = 0 ;
412- for ( dim_i, & stride) in self_bc_strides. iter ( ) . enumerate ( ) {
413- self_offset += multi_idx[ dim_i] * stride;
414- }
415-
416- let mut other_offset = 0 ;
417- for ( dim_i, & stride) in other_bc_strides. iter ( ) . enumerate ( ) {
418- other_offset += multi_idx[ dim_i] * stride;
419- }
420-
421- let val = self_data[ self_offset] / other_data[ other_offset] ;
422- result_data. push ( val) ;
423- }
424- Tensor :: new ( bc_shape, result_data)
342+ pub fn div ( & self , other : & Tensor ) -> Result < Tensor , TensorError > {
343+ self . binary_op ( other, |a, b| a / b)
425344 }
426345
427346 pub fn div_inplace ( & mut self , other : & Tensor ) {
@@ -438,17 +357,21 @@ impl Tensor {
438357 } ) ;
439358 }
440359
441- pub fn matmul ( & self , other : & Tensor ) -> Result < Tensor , & ' static str > {
360+ pub fn matmul ( & self , other : & Tensor ) -> Result < Tensor , TensorError > {
442361 let lhs_shape: & Vec < usize > = self . shape ( ) ;
443362 let rhs_shape: & Vec < usize > = other. shape ( ) ;
444363 if lhs_shape. len ( ) != 2 || rhs_shape. len ( ) != 2 {
445- return Err ( "matmul requires 2D tensors" ) ;
364+ return Err ( TensorError :: BroadcastError (
365+ "matmul requires 2D tensors" . to_string ( ) ,
366+ ) ) ;
446367 }
447368
448369 let ( rows_left, cols_left) = ( lhs_shape[ 0 ] , lhs_shape[ 1 ] ) ;
449370 let ( rows_right, cols_right) = ( rhs_shape[ 0 ] , rhs_shape[ 1 ] ) ;
450371 if cols_left != rows_right {
451- return Err ( "Incompatible shapes for matrix multiplication" ) ;
372+ return Err ( TensorError :: BroadcastError (
373+ "Incompatible shapes for matrix multiplication" . to_string ( ) ,
374+ ) ) ;
452375 }
453376
454377 let lhs_data: & Vec < f32 > = self . data ( ) ;
0 commit comments