@@ -365,7 +365,7 @@ fn tensor_broadcasted_addition_operator() {
365365}
366366
367367#[ test]
368- fn tensor_add_inplace ( ) {
368+ fn tensor_add_inplace_method ( ) {
369369 let a_shape = vec ! [ 2 , 3 ] ;
370370 let b_shape = vec ! [ 2 , 3 ] ;
371371 let a_data: Vec < f32 > = vec ! [ 1.0 , 2.0 , 3.0 , 4.0 , 5.0 , 6.0 ] ;
@@ -378,6 +378,21 @@ fn tensor_add_inplace() {
378378 assert_eq ! ( expected_data, * a. data( ) ) ;
379379}
380380
381+ #[ test]
382+ fn tensor_add_inplace_operator ( ) {
383+ let a_shape = vec ! [ 2 , 3 ] ;
384+ let b_shape = vec ! [ 2 , 3 ] ;
385+ let a_data: Vec < f32 > = vec ! [ 1.0 , 2.0 , 3.0 , 4.0 , 5.0 , 6.0 ] ;
386+ let b_data: Vec < f32 > = vec ! [ 1.0 , 2.0 , 3.0 , 4.0 , 5.0 , 6.0 ] ;
387+
388+ let mut a = Tensor :: new ( a_shape, a_data) . unwrap ( ) ;
389+ let b = Tensor :: new ( b_shape, b_data) . unwrap ( ) ;
390+ a += & b;
391+
392+ let expected_data: Vec < f32 > = vec ! [ 2.0 , 4.0 , 6.0 , 8.0 , 10.0 , 12.0 ] ;
393+ assert_eq ! ( expected_data, * a. data( ) ) ;
394+ }
395+
381396#[ test]
382397fn tensor_subtraction_method ( ) {
383398 let shape = vec ! [ 4 , 2 ] ;
@@ -450,7 +465,7 @@ fn tensor_broadcasted_subtraction_operator() {
450465}
451466
452467#[ test]
453- fn tensor_sub_inplace ( ) {
468+ fn tensor_sub_inplace_method ( ) {
454469 let a_shape = vec ! [ 2 , 3 ] ;
455470 let b_shape = vec ! [ 2 , 3 ] ;
456471 let a_data: Vec < f32 > = vec ! [ 1.0 , 2.0 , 3.0 , 4.0 , 5.0 , 6.0 ] ;
@@ -463,6 +478,21 @@ fn tensor_sub_inplace() {
463478 assert_eq ! ( expected_data, * a. data( ) ) ;
464479}
465480
481+ #[ test]
482+ fn tensor_sub_inplace_operator ( ) {
483+ let a_shape = vec ! [ 2 , 3 ] ;
484+ let b_shape = vec ! [ 2 , 3 ] ;
485+ let a_data: Vec < f32 > = vec ! [ 1.0 , 2.0 , 3.0 , 4.0 , 5.0 , 6.0 ] ;
486+ let b_data: Vec < f32 > = vec ! [ 2.0 , 2.0 , 2.0 , 2.0 , 2.0 , 2.0 ] ;
487+
488+ let mut a = Tensor :: new ( a_shape, a_data) . unwrap ( ) ;
489+ let b = Tensor :: new ( b_shape, b_data) . unwrap ( ) ;
490+ a -= & b;
491+
492+ let expected_data: Vec < f32 > = vec ! [ -1.0 , 0.0 , 1.0 , 2.0 , 3.0 , 4.0 ] ;
493+ assert_eq ! ( expected_data, * a. data( ) ) ;
494+ }
495+
466496#[ test]
467497fn tensor_mul_method ( ) {
468498 let a_shape = vec ! [ 1 , 3 ] ;
@@ -528,7 +558,7 @@ fn tensor_broadcasted_mul_method() {
528558}
529559
530560#[ test]
531- fn tensor_mul_inplace ( ) {
561+ fn tensor_mul_inplace_method ( ) {
532562 let a_shape = vec ! [ 2 , 3 ] ;
533563 let b_shape = vec ! [ 2 , 3 ] ;
534564 let a_data: Vec < f32 > = vec ! [ 1.0 , 2.0 , 3.0 , 4.0 , 5.0 , 6.0 ] ;
@@ -542,6 +572,21 @@ fn tensor_mul_inplace() {
542572 assert_eq ! ( expected_data, * a. data( ) ) ;
543573}
544574
575+ #[ test]
576+ fn tensor_mul_inplace_operator ( ) {
577+ let a_shape = vec ! [ 2 , 3 ] ;
578+ let b_shape = vec ! [ 2 , 3 ] ;
579+ let a_data: Vec < f32 > = vec ! [ 1.0 , 2.0 , 3.0 , 4.0 , 5.0 , 6.0 ] ;
580+ let b_data: Vec < f32 > = vec ! [ 2.0 , 2.0 , 2.0 , 2.0 , 2.0 , 2.0 ] ;
581+
582+ let mut a = Tensor :: new ( a_shape, a_data) . unwrap ( ) ;
583+ let b = Tensor :: new ( b_shape, b_data) . unwrap ( ) ;
584+ a *= & b;
585+
586+ let expected_data: Vec < f32 > = vec ! [ 2.0 , 4.0 , 6.0 , 8.0 , 10.0 , 12.0 ] ;
587+ assert_eq ! ( expected_data, * a. data( ) ) ;
588+ }
589+
545590#[ test]
546591fn tensor_div_method ( ) {
547592 let a_shape = vec ! [ 1 , 3 ] ;
@@ -604,7 +649,7 @@ fn tensor_broadcasted_div_method() {
604649}
605650
606651#[ test]
607- fn tensor_div_inplace ( ) {
652+ fn tensor_div_inplace_method ( ) {
608653 let a_shape = vec ! [ 2 , 3 ] ;
609654 let b_shape = vec ! [ 2 , 3 ] ;
610655 let a_data: Vec < f32 > = vec ! [ 1.0 , 2.0 , 3.0 , 4.0 , 5.0 , 6.0 ] ;
@@ -618,6 +663,21 @@ fn tensor_div_inplace() {
618663 assert_eq ! ( expected_data, * a. data( ) ) ;
619664}
620665
666+ #[ test]
667+ fn tensor_div_inplace_operator ( ) {
668+ let a_shape = vec ! [ 2 , 3 ] ;
669+ let b_shape = vec ! [ 2 , 3 ] ;
670+ let a_data: Vec < f32 > = vec ! [ 1.0 , 2.0 , 3.0 , 4.0 , 5.0 , 6.0 ] ;
671+ let b_data: Vec < f32 > = vec ! [ 2.0 , 2.0 , 2.0 , 2.0 , 2.0 , 2.0 ] ;
672+
673+ let mut a = Tensor :: new ( a_shape, a_data) . unwrap ( ) ;
674+ let b = Tensor :: new ( b_shape, b_data) . unwrap ( ) ;
675+ a /= & b;
676+
677+ let expected_data: Vec < f32 > = vec ! [ 0.5 , 1.0 , 1.5 , 2.0 , 2.5 , 3.0 ] ;
678+ assert_eq ! ( expected_data, * a. data( ) ) ;
679+ }
680+
621681#[ test]
622682fn tensor_matmul ( ) {
623683 // A is 2x3:
0 commit comments