11use super :: broadcast:: { compute_broadcast_shape_and_strides, is_broadcastable} ;
22use crate :: core:: utils:: unravel_index;
3- use crate :: core:: TensorError ;
3+ use crate :: core:: { TensorError , Numeric } ;
44use crate :: Tensor ;
55
6- impl Tensor {
7- fn binary_op < F > ( & self , other : & Tensor , op : F ) -> Result < Tensor , TensorError >
6+ impl < T : Numeric > Tensor < T > {
7+ fn binary_op < F > ( & self , other : & Tensor < T > , op : F ) -> Result < Tensor < T > , TensorError >
88 where
9- F : Fn ( f32 , f32 ) -> f32 ,
9+ F : Fn ( T , T ) -> T ,
1010 {
1111 let self_shape = self . shape ( ) ;
1212 let other_shape = other. shape ( ) ;
@@ -17,7 +17,7 @@ impl Tensor {
1717 }
1818
1919 if self_shape == other_shape {
20- let result_data: Vec < f32 > = self
20+ let result_data: Vec < T > = self
2121 . data
2222 . iter ( )
2323 . zip ( other. data . iter ( ) )
@@ -32,7 +32,7 @@ impl Tensor {
3232 let self_data = self . data ( ) ;
3333 let other_data = other. data ( ) ;
3434 let result_size: usize = bc_shape. iter ( ) . product ( ) ;
35- let mut result_data: Vec < f32 > = Vec :: with_capacity ( result_size) ;
35+ let mut result_data: Vec < T > = Vec :: with_capacity ( result_size) ;
3636
3737 for i in 0 ..result_size {
3838 let multi_idx = unravel_index ( i, & bc_shape) ;
@@ -50,9 +50,9 @@ impl Tensor {
5050 Tensor :: new ( bc_shape, result_data)
5151 }
5252
53- fn binary_op_inplace < F > ( & mut self , other : & Tensor , op : F )
53+ fn binary_op_inplace < F > ( & mut self , other : & Tensor < T > , op : F )
5454 where
55- F : Fn ( & mut f32 , f32 ) ,
55+ F : Fn ( & mut T , T ) ,
5656 {
5757 let self_shape = self . shape ( ) ;
5858 let other_shape = other. shape ( ) ;
@@ -69,39 +69,39 @@ impl Tensor {
6969 } ) ;
7070 }
7171
72- pub fn add ( & self , other : & Tensor ) -> Result < Tensor , TensorError > {
72+ pub fn add ( & self , other : & Tensor < T > ) -> Result < Tensor < T > , TensorError > {
7373 self . binary_op ( other, |a, b| a + b)
7474 }
7575
76- pub fn add_inplace ( & mut self , other : & Tensor ) {
76+ pub fn add_inplace ( & mut self , other : & Tensor < T > ) {
7777 self . binary_op_inplace ( other, |a, b| * a += b) ;
7878 }
7979
80- pub fn sub ( & self , other : & Tensor ) -> Result < Tensor , TensorError > {
80+ pub fn sub ( & self , other : & Tensor < T > ) -> Result < Tensor < T > , TensorError > {
8181 self . binary_op ( other, |a, b| a - b)
8282 }
8383
84- pub fn sub_inplace ( & mut self , other : & Tensor ) {
84+ pub fn sub_inplace ( & mut self , other : & Tensor < T > ) {
8585 self . binary_op_inplace ( other, |a, b| * a -= b) ;
8686 }
8787
88- pub fn mul ( & self , other : & Tensor ) -> Result < Tensor , TensorError > {
88+ pub fn mul ( & self , other : & Tensor < T > ) -> Result < Tensor < T > , TensorError > {
8989 self . binary_op ( other, |a, b| a * b)
9090 }
9191
92- pub fn mul_inplace ( & mut self , other : & Tensor ) {
92+ pub fn mul_inplace ( & mut self , other : & Tensor < T > ) {
9393 self . binary_op_inplace ( other, |a, b| * a *= b) ;
9494 }
9595
96- pub fn div ( & self , other : & Tensor ) -> Result < Tensor , TensorError > {
96+ pub fn div ( & self , other : & Tensor < T > ) -> Result < Tensor < T > , TensorError > {
9797 self . binary_op ( other, |a, b| a / b)
9898 }
9999
100- pub fn div_inplace ( & mut self , other : & Tensor ) {
100+ pub fn div_inplace ( & mut self , other : & Tensor < T > ) {
101101 self . binary_op_inplace ( other, |a, b| * a /= b) ;
102102 }
103103
104- pub fn matmul ( & self , other : & Tensor ) -> Result < Tensor , TensorError > {
104+ pub fn matmul ( & self , other : & Tensor < T > ) -> Result < Tensor < T > , TensorError > {
105105 let lhs_shape = self . shape ( ) ;
106106 let rhs_shape = other. shape ( ) ;
107107 if lhs_shape. len ( ) != 2 || rhs_shape. len ( ) != 2 {
@@ -120,7 +120,7 @@ impl Tensor {
120120
121121 let lhs_data = self . data ( ) ;
122122 let rhs_data = other. data ( ) ;
123- let mut result_data: Vec < f32 > = vec ! [ 0.0_f32 ; rows_left * cols_right] ;
123+ let mut result_data: Vec < T > = vec ! [ T :: zero ( ) ; rows_left * cols_right] ;
124124 for i in 0 ..rows_left {
125125 for k in 0 ..cols_left {
126126 for j in 0 ..cols_right {
@@ -129,6 +129,6 @@ impl Tensor {
129129 }
130130 }
131131 }
132- Ok ( Tensor :: new ( vec ! [ rows_left, cols_right] , result_data) . unwrap ( ) )
132+ Tensor :: new ( vec ! [ rows_left, cols_right] , result_data)
133133 }
134134}
0 commit comments