11use core:: { f32, fmt} ;
2- use std:: ops:: { Add , Div , Index , Mul , Sub } ;
2+ use std:: ops:: { Add , AddAssign , Div , DivAssign , Index , Mul , MulAssign , Sub , SubAssign } ;
33
44#[ derive( Debug , Clone ) ]
55pub struct Tensor {
@@ -240,6 +240,20 @@ impl Tensor {
240240 Tensor :: new ( bc_shape, result_data)
241241 }
242242
243+ pub fn add_inplace ( & mut self , other : & Tensor ) {
244+ let self_shape = self . shape ( ) ;
245+ let other_shape = other. shape ( ) ;
246+ if self_shape != other_shape {
247+ panic ! ( "The tensor shape not compatible for inplace addition" )
248+ }
249+ self . data
250+ . iter_mut ( )
251+ . zip ( other. data ( ) . iter ( ) )
252+ . for_each ( |( a, & b) | {
253+ * a += b;
254+ } ) ;
255+ }
256+
243257 pub fn sub ( & self , other : & Tensor ) -> Result < Tensor , & ' static str > {
244258 let self_shape = self . shape ( ) ;
245259 let other_shape = other. shape ( ) ;
@@ -285,6 +299,20 @@ impl Tensor {
285299 Tensor :: new ( bc_shape, result_data)
286300 }
287301
302+ pub fn sub_inplace ( & mut self , other : & Tensor ) {
303+ let self_shape = self . shape ( ) ;
304+ let other_shape = other. shape ( ) ;
305+ if self_shape != other_shape {
306+ panic ! ( "The tensor shape not compatible for inplace subtraction" )
307+ }
308+ self . data
309+ . iter_mut ( )
310+ . zip ( other. data ( ) . iter ( ) )
311+ . for_each ( |( a, & b) | {
312+ * a -= b;
313+ } ) ;
314+ }
315+
288316 pub fn mul ( & self , other : & Tensor ) -> Result < Tensor , & ' static str > {
289317 let self_shape = self . shape ( ) ;
290318 let other_shape = other. shape ( ) ;
@@ -329,6 +357,20 @@ impl Tensor {
329357 Tensor :: new ( bc_shape, result_data)
330358 }
331359
360+ pub fn mul_inplace ( & mut self , other : & Tensor ) {
361+ let self_shape = self . shape ( ) ;
362+ let other_shape = other. shape ( ) ;
363+ if self_shape != other_shape {
364+ panic ! ( "The tensor shape not compatible for inplace multiplication" )
365+ }
366+ self . data
367+ . iter_mut ( )
368+ . zip ( other. data ( ) . iter ( ) )
369+ . for_each ( |( a, & b) | {
370+ * a *= b;
371+ } ) ;
372+ }
373+
332374 pub fn div ( & self , other : & Tensor ) -> Result < Tensor , & ' static str > {
333375 let self_shape = self . shape ( ) ;
334376 let other_shape = other. shape ( ) ;
@@ -373,6 +415,20 @@ impl Tensor {
373415 Tensor :: new ( bc_shape, result_data)
374416 }
375417
418+ pub fn div_inplace ( & mut self , other : & Tensor ) {
419+ let self_shape = self . shape ( ) ;
420+ let other_shape = other. shape ( ) ;
421+ if self_shape != other_shape {
422+ panic ! ( "The tensor shape not compatible for inplace division" )
423+ }
424+ self . data
425+ . iter_mut ( )
426+ . zip ( other. data ( ) . iter ( ) )
427+ . for_each ( |( a, & b) | {
428+ * a /= b;
429+ } ) ;
430+ }
431+
376432 pub fn matmul ( & self , other : & Tensor ) -> Result < Tensor , & ' static str > {
377433 let lhs_shape: & Vec < usize > = self . shape ( ) ;
378434 let rhs_shape: & Vec < usize > = other. shape ( ) ;
@@ -447,7 +503,7 @@ fn unravel_index(mut i: usize, shape: &[usize]) -> Vec<usize> {
447503
448504pub fn is_broadcastable ( a : & Vec < usize > , b : & Vec < usize > ) -> bool {
449505 // This is based on NumPy's rules: https://numpy.org/doc/stable/user/basics.broadcasting.html
450- for ( i, j) in a. into_iter ( ) . rev ( ) . zip ( b. into_iter ( ) . rev ( ) ) {
506+ for ( i, j) in a. iter ( ) . rev ( ) . zip ( b. iter ( ) . rev ( ) ) {
451507 if * i == 1 || * j == 1 {
452508 continue ;
453509 }
@@ -480,26 +536,26 @@ pub fn compute_broadcast_shape_and_strides(
480536 if dim_a != dim_b {
481537 a_bc_strides[ ndims - i - 1 ] = match dim_a {
482538 1 => 0 ,
483- _ => a_dims[ ndims - i..] . into_iter ( ) . product ( ) ,
539+ _ => a_dims[ ndims - i..] . iter ( ) . product ( ) ,
484540 } ;
485541 b_bc_strides[ ndims - i - 1 ] = match dim_b {
486542 1 => 0 ,
487- _ => b_dims[ ndims - i..] . into_iter ( ) . product ( ) ,
543+ _ => b_dims[ ndims - i..] . iter ( ) . product ( ) ,
488544 } ;
489545 }
490546 } else {
491547 if dim_a != dim_b {
492548 a_bc_strides[ ndims - i - 1 ] = match dim_a {
493549 1 => 0 ,
494- _ => a_dims[ ndims - i..] . into_iter ( ) . product ( ) ,
550+ _ => a_dims[ ndims - i..] . iter ( ) . product ( ) ,
495551 } ;
496552 b_bc_strides[ ndims - i - 1 ] = match dim_b {
497553 1 => 0 ,
498- _ => b_dims[ ndims - i..] . into_iter ( ) . product ( ) ,
554+ _ => b_dims[ ndims - i..] . iter ( ) . product ( ) ,
499555 } ;
500556 } else {
501- a_bc_strides[ ndims - i - 1 ] = a_dims[ ndims - i..] . into_iter ( ) . product ( ) ;
502- b_bc_strides[ ndims - i - 1 ] = b_dims[ ndims - i..] . into_iter ( ) . product ( ) ;
557+ a_bc_strides[ ndims - i - 1 ] = a_dims[ ndims - i..] . iter ( ) . product ( ) ;
558+ b_bc_strides[ ndims - i - 1 ] = b_dims[ ndims - i..] . iter ( ) . product ( ) ;
503559 }
504560 }
505561 bc_shape[ ndims - i - 1 ] = dim_a. max ( dim_b) ;
@@ -641,6 +697,28 @@ impl Add<&Tensor> for f32 {
641697 }
642698}
643699
700+ impl AddAssign < & Tensor > for Tensor {
701+ fn add_assign ( & mut self , rhs : & Tensor ) {
702+ if * self . shape ( ) != * rhs. shape ( ) {
703+ panic ! ( "The tensor shape not compatible for inplace addition" )
704+ }
705+ self . data
706+ . iter_mut ( )
707+ . zip ( rhs. data ( ) . iter ( ) )
708+ . for_each ( |( a, b) | {
709+ * a += b;
710+ } ) ;
711+ }
712+ }
713+
714+ impl AddAssign < f32 > for Tensor {
715+ fn add_assign ( & mut self , rhs : f32 ) {
716+ self . data . iter_mut ( ) . for_each ( |a| {
717+ * a += rhs;
718+ } ) ;
719+ }
720+ }
721+
644722impl Sub < & Tensor > for & Tensor {
645723 type Output = Tensor ;
646724
@@ -670,6 +748,28 @@ impl Sub<&Tensor> for f32 {
670748 }
671749}
672750
751+ impl SubAssign < & Tensor > for Tensor {
752+ fn sub_assign ( & mut self , rhs : & Tensor ) {
753+ if * self . shape ( ) != * rhs. shape ( ) {
754+ panic ! ( "The tensor shape not compatible for inplace subtraction" )
755+ }
756+ self . data
757+ . iter_mut ( )
758+ . zip ( rhs. data ( ) . iter ( ) )
759+ . for_each ( |( a, b) | {
760+ * a -= b;
761+ } ) ;
762+ }
763+ }
764+
765+ impl SubAssign < f32 > for Tensor {
766+ fn sub_assign ( & mut self , rhs : f32 ) {
767+ self . data . iter_mut ( ) . for_each ( |a| {
768+ * a -= rhs;
769+ } ) ;
770+ }
771+ }
772+
673773impl Mul < & Tensor > for & Tensor {
674774 type Output = Tensor ;
675775
@@ -699,6 +799,28 @@ impl Mul<&Tensor> for f32 {
699799 }
700800}
701801
802+ impl MulAssign < & Tensor > for Tensor {
803+ fn mul_assign ( & mut self , rhs : & Tensor ) {
804+ if * self . shape ( ) != * rhs. shape ( ) {
805+ panic ! ( "The tensor shape not compatible for inplace subtraction" )
806+ }
807+ self . data
808+ . iter_mut ( )
809+ . zip ( rhs. data ( ) . iter ( ) )
810+ . for_each ( |( a, b) | {
811+ * a *= b;
812+ } ) ;
813+ }
814+ }
815+
816+ impl MulAssign < f32 > for Tensor {
817+ fn mul_assign ( & mut self , rhs : f32 ) {
818+ self . data . iter_mut ( ) . for_each ( |a| {
819+ * a *= rhs;
820+ } ) ;
821+ }
822+ }
823+
702824impl Div < & Tensor > for & Tensor {
703825 type Output = Tensor ;
704826
@@ -718,3 +840,25 @@ impl Div<f32> for &Tensor {
718840 Tensor :: new ( self . shape ( ) . clone ( ) , result_data) . unwrap ( )
719841 }
720842}
843+
844+ impl DivAssign < & Tensor > for Tensor {
845+ fn div_assign ( & mut self , rhs : & Tensor ) {
846+ if * self . shape ( ) != * rhs. shape ( ) {
847+ panic ! ( "The tensor shape not compatible for inplace subtraction" )
848+ }
849+ self . data
850+ . iter_mut ( )
851+ . zip ( rhs. data ( ) . iter ( ) )
852+ . for_each ( |( a, b) | {
853+ * a /= b;
854+ } ) ;
855+ }
856+ }
857+
858+ impl DivAssign < f32 > for Tensor {
859+ fn div_assign ( & mut self , rhs : f32 ) {
860+ self . data . iter_mut ( ) . for_each ( |a| {
861+ * a /= rhs;
862+ } ) ;
863+ }
864+ }
0 commit comments