Skip to content

Commit a995944

Browse files
committed
add trait implementations for binary assign operations for tensor and f32
1 parent f59f1f1 commit a995944

File tree

2 files changed

+52
-4
lines changed

2 files changed

+52
-4
lines changed

tensor/src/lib.rs

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -711,6 +711,14 @@ impl AddAssign<&Tensor> for Tensor {
711711
}
712712
}
713713

714+
impl AddAssign<f32> for Tensor {
715+
fn add_assign(&mut self, rhs: f32) {
716+
self.data.iter_mut().for_each(|a| {
717+
*a += rhs;
718+
});
719+
}
720+
}
721+
714722
impl Sub<&Tensor> for &Tensor {
715723
type Output = Tensor;
716724

@@ -754,6 +762,14 @@ impl SubAssign<&Tensor> for Tensor {
754762
}
755763
}
756764

765+
impl SubAssign<f32> for Tensor {
766+
fn sub_assign(&mut self, rhs: f32) {
767+
self.data.iter_mut().for_each(|a| {
768+
*a -= rhs;
769+
});
770+
}
771+
}
772+
757773
impl Mul<&Tensor> for &Tensor {
758774
type Output = Tensor;
759775

@@ -797,6 +813,14 @@ impl MulAssign<&Tensor> for Tensor {
797813
}
798814
}
799815

816+
impl MulAssign<f32> for Tensor {
817+
fn mul_assign(&mut self, rhs: f32) {
818+
self.data.iter_mut().for_each(|a| {
819+
*a *= rhs;
820+
});
821+
}
822+
}
823+
800824
impl Div<&Tensor> for &Tensor {
801825
type Output = Tensor;
802826

@@ -830,3 +854,11 @@ impl DivAssign<&Tensor> for Tensor {
830854
});
831855
}
832856
}
857+
858+
impl DivAssign<f32> for Tensor {
859+
fn div_assign(&mut self, rhs: f32) {
860+
self.data.iter_mut().for_each(|a| {
861+
*a /= rhs;
862+
});
863+
}
864+
}

tensor/tests/tensor_core_test.rs

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -387,10 +387,14 @@ fn tensor_add_inplace_operator() {
387387

388388
let mut a = Tensor::new(a_shape, a_data).unwrap();
389389
let b = Tensor::new(b_shape, b_data).unwrap();
390-
a += &b;
391390

391+
a += &b;
392392
let expected_data: Vec<f32> = vec![2.0, 4.0, 6.0, 8.0, 10.0, 12.0];
393393
assert_eq!(expected_data, *a.data());
394+
395+
a += 4.0;
396+
let expected_data: Vec<f32> = vec![6.0, 8.0, 10.0, 12.0, 14.0, 16.0];
397+
assert_eq!(expected_data, *a.data());
394398
}
395399

396400
#[test]
@@ -487,10 +491,14 @@ fn tensor_sub_inplace_operator() {
487491

488492
let mut a = Tensor::new(a_shape, a_data).unwrap();
489493
let b = Tensor::new(b_shape, b_data).unwrap();
490-
a -= &b;
491494

495+
a -= &b;
492496
let expected_data: Vec<f32> = vec![-1.0, 0.0, 1.0, 2.0, 3.0, 4.0];
493497
assert_eq!(expected_data, *a.data());
498+
499+
a -= 2.0;
500+
let expected_data: Vec<f32> = vec![-3.0, -2.0, -1.0, 0.0, 1.0, 2.0];
501+
assert_eq!(expected_data, *a.data());
494502
}
495503

496504
#[test]
@@ -581,10 +589,14 @@ fn tensor_mul_inplace_operator() {
581589

582590
let mut a = Tensor::new(a_shape, a_data).unwrap();
583591
let b = Tensor::new(b_shape, b_data).unwrap();
584-
a *= &b;
585592

593+
a *= &b;
586594
let expected_data: Vec<f32> = vec![2.0, 4.0, 6.0, 8.0, 10.0, 12.0];
587595
assert_eq!(expected_data, *a.data());
596+
597+
a *= 2.0;
598+
let expected_data: Vec<f32> = vec![4.0, 8.0, 12.0, 16.0, 20.0, 24.0];
599+
assert_eq!(expected_data, *a.data());
588600
}
589601

590602
#[test]
@@ -672,10 +684,14 @@ fn tensor_div_inplace_operator() {
672684

673685
let mut a = Tensor::new(a_shape, a_data).unwrap();
674686
let b = Tensor::new(b_shape, b_data).unwrap();
675-
a /= &b;
676687

688+
a /= &b;
677689
let expected_data: Vec<f32> = vec![0.5, 1.0, 1.5, 2.0, 2.5, 3.0];
678690
assert_eq!(expected_data, *a.data());
691+
692+
a /= 0.5;
693+
let expected_data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
694+
assert_eq!(expected_data, *a.data());
679695
}
680696

681697
#[test]

0 commit comments

Comments
 (0)