Skip to content

Commit f59f1f1

Browse files
committed
add trait implementations for binary assign operations for tensors
1 parent 7d9da09 commit f59f1f1

File tree

2 files changed

+125
-9
lines changed

2 files changed

+125
-9
lines changed

tensor/src/lib.rs

Lines changed: 61 additions & 5 deletions
Original file line number · Diff line number · Diff line change
@@ -1,5 +1,5 @@
11
use core::{f32, fmt};
2-
use std::ops::{Add, Div, Index, Mul, Sub};
2+
use std::ops::{Add, AddAssign, Div, DivAssign, Index, Mul, MulAssign, Sub, SubAssign};
33

44
#[derive(Debug, Clone)]
55
pub struct Tensor {
@@ -244,7 +244,7 @@ impl Tensor {
244244
let self_shape = self.shape();
245245
let other_shape = other.shape();
246246
if self_shape != other_shape {
247-
return;
247+
panic!("The tensor shape not compatible for inplace addition")
248248
}
249249
self.data
250250
.iter_mut()
@@ -303,7 +303,7 @@ impl Tensor {
303303
let self_shape = self.shape();
304304
let other_shape = other.shape();
305305
if self_shape != other_shape {
306-
return;
306+
panic!("The tensor shape not compatible for inplace subtraction")
307307
}
308308
self.data
309309
.iter_mut()
@@ -361,7 +361,7 @@ impl Tensor {
361361
let self_shape = self.shape();
362362
let other_shape = other.shape();
363363
if self_shape != other_shape {
364-
return;
364+
panic!("The tensor shape not compatible for inplace multiplication")
365365
}
366366
self.data
367367
.iter_mut()
@@ -419,7 +419,7 @@ impl Tensor {
419419
let self_shape = self.shape();
420420
let other_shape = other.shape();
421421
if self_shape != other_shape {
422-
return;
422+
panic!("The tensor shape not compatible for inplace division")
423423
}
424424
self.data
425425
.iter_mut()
@@ -697,6 +697,20 @@ impl Add<&Tensor> for f32 {
697697
}
698698
}
699699

700+
impl AddAssign<&Tensor> for Tensor {
    /// In-place element-wise addition: `self += &rhs`.
    ///
    /// # Panics
    /// Panics if the two tensors do not have identical shapes
    /// (no broadcasting is performed for the assign operators).
    fn add_assign(&mut self, rhs: &Tensor) {
        if *self.shape() != *rhs.shape() {
            panic!("The tensor shape not compatible for inplace addition")
        }
        // Walk both buffers in lockstep, accumulating into `self`.
        for (lhs_elem, rhs_elem) in self.data.iter_mut().zip(rhs.data().iter()) {
            *lhs_elem += rhs_elem;
        }
    }
}
713+
700714
impl Sub<&Tensor> for &Tensor {
701715
type Output = Tensor;
702716

@@ -726,6 +740,20 @@ impl Sub<&Tensor> for f32 {
726740
}
727741
}
728742

743+
impl SubAssign<&Tensor> for Tensor {
    /// In-place element-wise subtraction: `self -= &rhs`.
    ///
    /// # Panics
    /// Panics if the two tensors do not have identical shapes
    /// (no broadcasting is performed for the assign operators).
    fn sub_assign(&mut self, rhs: &Tensor) {
        if *self.shape() != *rhs.shape() {
            panic!("The tensor shape not compatible for inplace subtraction")
        }
        // Walk both buffers in lockstep, subtracting into `self`.
        for (lhs_elem, rhs_elem) in self.data.iter_mut().zip(rhs.data().iter()) {
            *lhs_elem -= rhs_elem;
        }
    }
}
756+
729757
impl Mul<&Tensor> for &Tensor {
730758
type Output = Tensor;
731759

@@ -755,6 +783,20 @@ impl Mul<&Tensor> for f32 {
755783
}
756784
}
757785

786+
impl MulAssign<&Tensor> for Tensor {
    /// In-place element-wise multiplication: `self *= &rhs`.
    ///
    /// # Panics
    /// Panics if the two tensors do not have identical shapes
    /// (no broadcasting is performed for the assign operators).
    fn mul_assign(&mut self, rhs: &Tensor) {
        if *self.shape() != *rhs.shape() {
            // Fixed copy-paste bug: message previously said "subtraction";
            // wording now matches the inherent mul_inplace method.
            panic!("The tensor shape not compatible for inplace multiplication")
        }
        self.data
            .iter_mut()
            .zip(rhs.data().iter())
            .for_each(|(a, b)| {
                *a *= b;
            });
    }
}
799+
758800
impl Div<&Tensor> for &Tensor {
759801
type Output = Tensor;
760802

@@ -774,3 +816,17 @@ impl Div<f32> for &Tensor {
774816
Tensor::new(self.shape().clone(), result_data).unwrap()
775817
}
776818
}
819+
820+
impl DivAssign<&Tensor> for Tensor {
    /// In-place element-wise division: `self /= &rhs`.
    ///
    /// # Panics
    /// Panics if the two tensors do not have identical shapes
    /// (no broadcasting is performed for the assign operators).
    fn div_assign(&mut self, rhs: &Tensor) {
        if *self.shape() != *rhs.shape() {
            // Fixed copy-paste bug: message previously said "subtraction";
            // wording now matches the inherent div_inplace method.
            panic!("The tensor shape not compatible for inplace division")
        }
        self.data
            .iter_mut()
            .zip(rhs.data().iter())
            .for_each(|(a, b)| {
                *a /= b;
            });
    }
}

tensor/tests/tensor_core_test.rs

Lines changed: 64 additions & 4 deletions
Original file line number · Diff line number · Diff line change
@@ -365,7 +365,7 @@ fn tensor_broadcasted_addition_operator() {
365365
}
366366

367367
#[test]
368-
fn tensor_add_inplace() {
368+
fn tensor_add_inplace_method() {
369369
let a_shape = vec![2, 3];
370370
let b_shape = vec![2, 3];
371371
let a_data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
@@ -378,6 +378,21 @@ fn tensor_add_inplace() {
378378
assert_eq!(expected_data, *a.data());
379379
}
380380

381+
#[test]
fn tensor_add_inplace_operator() {
    // `+=` with a borrowed RHS must add element-wise when shapes match.
    let lhs_data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
    let rhs_data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];

    let mut lhs = Tensor::new(vec![2, 3], lhs_data).unwrap();
    let rhs = Tensor::new(vec![2, 3], rhs_data).unwrap();
    lhs += &rhs;

    let expected: Vec<f32> = vec![2.0, 4.0, 6.0, 8.0, 10.0, 12.0];
    assert_eq!(expected, *lhs.data());
}
395+
381396
#[test]
382397
fn tensor_subtraction_method() {
383398
let shape = vec![4, 2];
@@ -450,7 +465,7 @@ fn tensor_broadcasted_subtraction_operator() {
450465
}
451466

452467
#[test]
453-
fn tensor_sub_inplace() {
468+
fn tensor_sub_inplace_method() {
454469
let a_shape = vec![2, 3];
455470
let b_shape = vec![2, 3];
456471
let a_data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
@@ -463,6 +478,21 @@ fn tensor_sub_inplace() {
463478
assert_eq!(expected_data, *a.data());
464479
}
465480

481+
#[test]
fn tensor_sub_inplace_operator() {
    // `-=` with a borrowed RHS must subtract element-wise when shapes match.
    let lhs_data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
    let rhs_data: Vec<f32> = vec![2.0, 2.0, 2.0, 2.0, 2.0, 2.0];

    let mut lhs = Tensor::new(vec![2, 3], lhs_data).unwrap();
    let rhs = Tensor::new(vec![2, 3], rhs_data).unwrap();
    lhs -= &rhs;

    let expected: Vec<f32> = vec![-1.0, 0.0, 1.0, 2.0, 3.0, 4.0];
    assert_eq!(expected, *lhs.data());
}
495+
466496
#[test]
467497
fn tensor_mul_method() {
468498
let a_shape = vec![1, 3];
@@ -528,7 +558,7 @@ fn tensor_broadcasted_mul_method() {
528558
}
529559

530560
#[test]
531-
fn tensor_mul_inplace() {
561+
fn tensor_mul_inplace_method() {
532562
let a_shape = vec![2, 3];
533563
let b_shape = vec![2, 3];
534564
let a_data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
@@ -542,6 +572,21 @@ fn tensor_mul_inplace() {
542572
assert_eq!(expected_data, *a.data());
543573
}
544574

575+
#[test]
fn tensor_mul_inplace_operator() {
    // `*=` with a borrowed RHS must multiply element-wise when shapes match.
    let lhs_data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
    let rhs_data: Vec<f32> = vec![2.0, 2.0, 2.0, 2.0, 2.0, 2.0];

    let mut lhs = Tensor::new(vec![2, 3], lhs_data).unwrap();
    let rhs = Tensor::new(vec![2, 3], rhs_data).unwrap();
    lhs *= &rhs;

    let expected: Vec<f32> = vec![2.0, 4.0, 6.0, 8.0, 10.0, 12.0];
    assert_eq!(expected, *lhs.data());
}
589+
545590
#[test]
546591
fn tensor_div_method() {
547592
let a_shape = vec![1, 3];
@@ -604,7 +649,7 @@ fn tensor_broadcasted_div_method() {
604649
}
605650

606651
#[test]
607-
fn tensor_div_inplace() {
652+
fn tensor_div_inplace_method() {
608653
let a_shape = vec![2, 3];
609654
let b_shape = vec![2, 3];
610655
let a_data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
@@ -618,6 +663,21 @@ fn tensor_div_inplace() {
618663
assert_eq!(expected_data, *a.data());
619664
}
620665

666+
#[test]
fn tensor_div_inplace_operator() {
    // `/=` with a borrowed RHS must divide element-wise when shapes match.
    let lhs_data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
    let rhs_data: Vec<f32> = vec![2.0, 2.0, 2.0, 2.0, 2.0, 2.0];

    let mut lhs = Tensor::new(vec![2, 3], lhs_data).unwrap();
    let rhs = Tensor::new(vec![2, 3], rhs_data).unwrap();
    lhs /= &rhs;

    let expected: Vec<f32> = vec![0.5, 1.0, 1.5, 2.0, 2.5, 3.0];
    assert_eq!(expected, *lhs.data());
}
680+
621681
#[test]
622682
fn tensor_matmul() {
623683
// A is 2x3:

0 commit comments

Comments
 (0)