Skip to content

Commit b7e64aa

Browse files
authored
Merge pull request #12 from PaytonWebber/feature/tensor-inplace-ops
Feature/tensor inplace ops
2 parents 73f35b9 + a995944 commit b7e64aa

File tree

2 files changed

+286
-8
lines changed

2 files changed

+286
-8
lines changed

tensor/src/lib.rs

Lines changed: 152 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use core::{f32, fmt};
2-
use std::ops::{Add, Div, Index, Mul, Sub};
2+
use std::ops::{Add, AddAssign, Div, DivAssign, Index, Mul, MulAssign, Sub, SubAssign};
33

44
#[derive(Debug, Clone)]
55
pub struct Tensor {
@@ -240,6 +240,20 @@ impl Tensor {
240240
Tensor::new(bc_shape, result_data)
241241
}
242242

243+
/// Element-wise in-place addition: `self[i] += other[i]`.
///
/// Unlike `add`, no broadcasting is performed — the two tensors must
/// have exactly the same shape, since `self`'s buffer cannot be resized.
///
/// # Panics
/// Panics if `self` and `other` do not have identical shapes.
pub fn add_inplace(&mut self, other: &Tensor) {
    if self.shape() != other.shape() {
        // Fixed message grammar: "The tensor shape not compatible" -> below.
        panic!("Tensor shapes are not compatible for in-place addition")
    }
    self.data
        .iter_mut()
        .zip(other.data().iter())
        .for_each(|(a, &b)| *a += b);
}
256+
243257
pub fn sub(&self, other: &Tensor) -> Result<Tensor, &'static str> {
244258
let self_shape = self.shape();
245259
let other_shape = other.shape();
@@ -285,6 +299,20 @@ impl Tensor {
285299
Tensor::new(bc_shape, result_data)
286300
}
287301

302+
/// Element-wise in-place subtraction: `self[i] -= other[i]`.
///
/// Unlike `sub`, no broadcasting is performed — the two tensors must
/// have exactly the same shape, since `self`'s buffer cannot be resized.
///
/// # Panics
/// Panics if `self` and `other` do not have identical shapes.
pub fn sub_inplace(&mut self, other: &Tensor) {
    if self.shape() != other.shape() {
        // Fixed message grammar: "The tensor shape not compatible" -> below.
        panic!("Tensor shapes are not compatible for in-place subtraction")
    }
    self.data
        .iter_mut()
        .zip(other.data().iter())
        .for_each(|(a, &b)| *a -= b);
}
315+
288316
pub fn mul(&self, other: &Tensor) -> Result<Tensor, &'static str> {
289317
let self_shape = self.shape();
290318
let other_shape = other.shape();
@@ -329,6 +357,20 @@ impl Tensor {
329357
Tensor::new(bc_shape, result_data)
330358
}
331359

360+
/// Element-wise in-place multiplication: `self[i] *= other[i]`.
///
/// Unlike `mul`, no broadcasting is performed — the two tensors must
/// have exactly the same shape, since `self`'s buffer cannot be resized.
///
/// # Panics
/// Panics if `self` and `other` do not have identical shapes.
pub fn mul_inplace(&mut self, other: &Tensor) {
    if self.shape() != other.shape() {
        // Fixed message grammar: "The tensor shape not compatible" -> below.
        panic!("Tensor shapes are not compatible for in-place multiplication")
    }
    self.data
        .iter_mut()
        .zip(other.data().iter())
        .for_each(|(a, &b)| *a *= b);
}
373+
332374
pub fn div(&self, other: &Tensor) -> Result<Tensor, &'static str> {
333375
let self_shape = self.shape();
334376
let other_shape = other.shape();
@@ -373,6 +415,20 @@ impl Tensor {
373415
Tensor::new(bc_shape, result_data)
374416
}
375417

418+
/// Element-wise in-place division: `self[i] /= other[i]`.
///
/// Unlike `div`, no broadcasting is performed — the two tensors must
/// have exactly the same shape, since `self`'s buffer cannot be resized.
/// Division by a zero element follows IEEE-754 `f32` semantics
/// (produces `inf`/`NaN` rather than panicking).
///
/// # Panics
/// Panics if `self` and `other` do not have identical shapes.
pub fn div_inplace(&mut self, other: &Tensor) {
    if self.shape() != other.shape() {
        // Fixed message grammar: "The tensor shape not compatible" -> below.
        panic!("Tensor shapes are not compatible for in-place division")
    }
    self.data
        .iter_mut()
        .zip(other.data().iter())
        .for_each(|(a, &b)| *a /= b);
}
431+
376432
pub fn matmul(&self, other: &Tensor) -> Result<Tensor, &'static str> {
377433
let lhs_shape: &Vec<usize> = self.shape();
378434
let rhs_shape: &Vec<usize> = other.shape();
@@ -447,7 +503,7 @@ fn unravel_index(mut i: usize, shape: &[usize]) -> Vec<usize> {
447503

448504
pub fn is_broadcastable(a: &Vec<usize>, b: &Vec<usize>) -> bool {
449505
// This is based on NumPy's rules: https://numpy.org/doc/stable/user/basics.broadcasting.html
450-
for (i, j) in a.into_iter().rev().zip(b.into_iter().rev()) {
506+
for (i, j) in a.iter().rev().zip(b.iter().rev()) {
451507
if *i == 1 || *j == 1 {
452508
continue;
453509
}
@@ -480,26 +536,26 @@ pub fn compute_broadcast_shape_and_strides(
480536
if dim_a != dim_b {
481537
a_bc_strides[ndims - i - 1] = match dim_a {
482538
1 => 0,
483-
_ => a_dims[ndims - i..].into_iter().product(),
539+
_ => a_dims[ndims - i..].iter().product(),
484540
};
485541
b_bc_strides[ndims - i - 1] = match dim_b {
486542
1 => 0,
487-
_ => b_dims[ndims - i..].into_iter().product(),
543+
_ => b_dims[ndims - i..].iter().product(),
488544
};
489545
}
490546
} else {
491547
if dim_a != dim_b {
492548
a_bc_strides[ndims - i - 1] = match dim_a {
493549
1 => 0,
494-
_ => a_dims[ndims - i..].into_iter().product(),
550+
_ => a_dims[ndims - i..].iter().product(),
495551
};
496552
b_bc_strides[ndims - i - 1] = match dim_b {
497553
1 => 0,
498-
_ => b_dims[ndims - i..].into_iter().product(),
554+
_ => b_dims[ndims - i..].iter().product(),
499555
};
500556
} else {
501-
a_bc_strides[ndims - i - 1] = a_dims[ndims - i..].into_iter().product();
502-
b_bc_strides[ndims - i - 1] = b_dims[ndims - i..].into_iter().product();
557+
a_bc_strides[ndims - i - 1] = a_dims[ndims - i..].iter().product();
558+
b_bc_strides[ndims - i - 1] = b_dims[ndims - i..].iter().product();
503559
}
504560
}
505561
bc_shape[ndims - i - 1] = dim_a.max(dim_b);
@@ -641,6 +697,28 @@ impl Add<&Tensor> for f32 {
641697
}
642698
}
643699

700+
/// `tensor += &other` — element-wise in-place addition.
///
/// Mirrors `Tensor::add_inplace`: shapes must match exactly (no
/// broadcasting, since `self`'s buffer cannot be resized in place).
///
/// # Panics
/// Panics if the two tensors do not have identical shapes.
impl AddAssign<&Tensor> for Tensor {
    fn add_assign(&mut self, rhs: &Tensor) {
        if *self.shape() != *rhs.shape() {
            // Fixed message grammar: "The tensor shape not compatible" -> below.
            panic!("Tensor shapes are not compatible for in-place addition")
        }
        self.data
            .iter_mut()
            .zip(rhs.data().iter())
            .for_each(|(a, &b)| *a += b);
    }
}
713+
714+
impl AddAssign<f32> for Tensor {
715+
fn add_assign(&mut self, rhs: f32) {
716+
self.data.iter_mut().for_each(|a| {
717+
*a += rhs;
718+
});
719+
}
720+
}
721+
644722
impl Sub<&Tensor> for &Tensor {
645723
type Output = Tensor;
646724

@@ -670,6 +748,28 @@ impl Sub<&Tensor> for f32 {
670748
}
671749
}
672750

751+
/// `tensor -= &other` — element-wise in-place subtraction.
///
/// Mirrors `Tensor::sub_inplace`: shapes must match exactly (no
/// broadcasting, since `self`'s buffer cannot be resized in place).
///
/// # Panics
/// Panics if the two tensors do not have identical shapes.
impl SubAssign<&Tensor> for Tensor {
    fn sub_assign(&mut self, rhs: &Tensor) {
        if *self.shape() != *rhs.shape() {
            // Fixed message grammar: "The tensor shape not compatible" -> below.
            panic!("Tensor shapes are not compatible for in-place subtraction")
        }
        self.data
            .iter_mut()
            .zip(rhs.data().iter())
            .for_each(|(a, &b)| *a -= b);
    }
}
764+
765+
impl SubAssign<f32> for Tensor {
766+
fn sub_assign(&mut self, rhs: f32) {
767+
self.data.iter_mut().for_each(|a| {
768+
*a -= rhs;
769+
});
770+
}
771+
}
772+
673773
impl Mul<&Tensor> for &Tensor {
674774
type Output = Tensor;
675775

@@ -699,6 +799,28 @@ impl Mul<&Tensor> for f32 {
699799
}
700800
}
701801

802+
/// `tensor *= &other` — element-wise in-place multiplication.
///
/// Mirrors `Tensor::mul_inplace`: shapes must match exactly (no
/// broadcasting, since `self`'s buffer cannot be resized in place).
///
/// # Panics
/// Panics if the two tensors do not have identical shapes.
impl MulAssign<&Tensor> for Tensor {
    fn mul_assign(&mut self, rhs: &Tensor) {
        if *self.shape() != *rhs.shape() {
            // BUG FIX: message previously said "subtraction" (copy-paste error);
            // it now names the actual operation.
            panic!("Tensor shapes are not compatible for in-place multiplication")
        }
        self.data
            .iter_mut()
            .zip(rhs.data().iter())
            .for_each(|(a, &b)| *a *= b);
    }
}
815+
816+
impl MulAssign<f32> for Tensor {
817+
fn mul_assign(&mut self, rhs: f32) {
818+
self.data.iter_mut().for_each(|a| {
819+
*a *= rhs;
820+
});
821+
}
822+
}
823+
702824
impl Div<&Tensor> for &Tensor {
703825
type Output = Tensor;
704826

@@ -718,3 +840,25 @@ impl Div<f32> for &Tensor {
718840
Tensor::new(self.shape().clone(), result_data).unwrap()
719841
}
720842
}
843+
844+
/// `tensor /= &other` — element-wise in-place division.
///
/// Mirrors `Tensor::div_inplace`: shapes must match exactly (no
/// broadcasting, since `self`'s buffer cannot be resized in place).
/// Division by a zero element follows IEEE-754 `f32` semantics
/// (produces `inf`/`NaN` rather than panicking).
///
/// # Panics
/// Panics if the two tensors do not have identical shapes.
impl DivAssign<&Tensor> for Tensor {
    fn div_assign(&mut self, rhs: &Tensor) {
        if *self.shape() != *rhs.shape() {
            // BUG FIX: message previously said "subtraction" (copy-paste error);
            // it now names the actual operation.
            panic!("Tensor shapes are not compatible for in-place division")
        }
        self.data
            .iter_mut()
            .zip(rhs.data().iter())
            .for_each(|(a, &b)| *a /= b);
    }
}
857+
858+
impl DivAssign<f32> for Tensor {
859+
fn div_assign(&mut self, rhs: f32) {
860+
self.data.iter_mut().for_each(|a| {
861+
*a /= rhs;
862+
});
863+
}
864+
}

0 commit comments

Comments
 (0)