Skip to content

Commit 5f3fce7

Browse files
committed
refactor inplace binary ops to use a generic helper function
1 parent 5d68c8a commit 5f3fce7

File tree

1 file changed

+20
-41
lines changed

1 file changed

+20
-41
lines changed

tensor/src/lib.rs

Lines changed: 20 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -285,76 +285,55 @@ impl Tensor {
285285
Tensor::new(bc_shape, result_data)
286286
}
287287

288-
pub fn add(&self, other: &Tensor) -> Result<Tensor, TensorError> {
289-
self.binary_op(other, |a, b| a + b)
290-
}
291-
292-
pub fn add_inplace(&mut self, other: &Tensor) {
288+
fn binary_op_inplace<F>(&mut self, other: &Tensor, op: F)
289+
where
290+
F: Fn(&mut f32, f32),
291+
{
293292
let self_shape = self.shape();
294293
let other_shape = other.shape();
294+
295295
if self_shape != other_shape {
296-
panic!("The tensor shape not compatible for inplace addition")
296+
panic!("Shapes not compatible for in-place operation");
297297
}
298+
298299
self.data
299300
.iter_mut()
300-
.zip(other.data().iter())
301+
.zip(other.data.iter())
301302
.for_each(|(a, &b)| {
302-
*a += b;
303+
op(a, b);
303304
});
304305
}
305306

307+
pub fn add(&self, other: &Tensor) -> Result<Tensor, TensorError> {
308+
self.binary_op(other, |a, b| a + b)
309+
}
310+
311+
pub fn add_inplace(&mut self, other: &Tensor) {
312+
self.binary_op_inplace(other, |a, b| *a += b);
313+
}
314+
306315
pub fn sub(&self, other: &Tensor) -> Result<Tensor, TensorError> {
307316
self.binary_op(other, |a, b| a - b)
308317
}
309318

310319
pub fn sub_inplace(&mut self, other: &Tensor) {
311-
let self_shape = self.shape();
312-
let other_shape = other.shape();
313-
if self_shape != other_shape {
314-
panic!("The tensor shape not compatible for inplace subtraction")
315-
}
316-
self.data
317-
.iter_mut()
318-
.zip(other.data().iter())
319-
.for_each(|(a, &b)| {
320-
*a -= b;
321-
});
320+
self.binary_op_inplace(other, |a, b| *a -= b);
322321
}
323322

324323
pub fn mul(&self, other: &Tensor) -> Result<Tensor, TensorError> {
325324
self.binary_op(other, |a, b| a * b)
326325
}
327326

328327
pub fn mul_inplace(&mut self, other: &Tensor) {
329-
let self_shape = self.shape();
330-
let other_shape = other.shape();
331-
if self_shape != other_shape {
332-
panic!("The tensor shape not compatible for inplace multiplication")
333-
}
334-
self.data
335-
.iter_mut()
336-
.zip(other.data().iter())
337-
.for_each(|(a, &b)| {
338-
*a *= b;
339-
});
328+
self.binary_op_inplace(other, |a, b| *a *= b);
340329
}
341330

342331
pub fn div(&self, other: &Tensor) -> Result<Tensor, TensorError> {
343332
self.binary_op(other, |a, b| a / b)
344333
}
345334

346335
pub fn div_inplace(&mut self, other: &Tensor) {
347-
let self_shape = self.shape();
348-
let other_shape = other.shape();
349-
if self_shape != other_shape {
350-
panic!("The tensor shape not compatible for inplace division")
351-
}
352-
self.data
353-
.iter_mut()
354-
.zip(other.data().iter())
355-
.for_each(|(a, &b)| {
356-
*a /= b;
357-
});
336+
self.binary_op_inplace(other, |a, b| *a /= b);
358337
}
359338

360339
pub fn matmul(&self, other: &Tensor) -> Result<Tensor, TensorError> {

0 commit comments

Comments
 (0)