@@ -285,76 +285,55 @@ impl Tensor {
285285 Tensor :: new ( bc_shape, result_data)
286286 }
287287
288- pub fn add ( & self , other : & Tensor ) -> Result < Tensor , TensorError > {
289- self . binary_op ( other, |a, b| a + b)
290- }
291-
292- pub fn add_inplace ( & mut self , other : & Tensor ) {
288+ fn binary_op_inplace < F > ( & mut self , other : & Tensor , op : F )
289+ where
290+ F : Fn ( & mut f32 , f32 ) ,
291+ {
293292 let self_shape = self . shape ( ) ;
294293 let other_shape = other. shape ( ) ;
294+
295295 if self_shape != other_shape {
296- panic ! ( "The tensor shape not compatible for inplace addition" )
296+ panic ! ( "Shapes not compatible for in-place operation" ) ;
297297 }
298+
298299 self . data
299300 . iter_mut ( )
300- . zip ( other. data ( ) . iter ( ) )
301+ . zip ( other. data . iter ( ) )
301302 . for_each ( |( a, & b) | {
302- * a += b ;
303+ op ( a , b ) ;
303304 } ) ;
304305 }
305306
307+ pub fn add ( & self , other : & Tensor ) -> Result < Tensor , TensorError > {
308+ self . binary_op ( other, |a, b| a + b)
309+ }
310+
311+ pub fn add_inplace ( & mut self , other : & Tensor ) {
312+ self . binary_op_inplace ( other, |a, b| * a += b) ;
313+ }
314+
306315 pub fn sub ( & self , other : & Tensor ) -> Result < Tensor , TensorError > {
307316 self . binary_op ( other, |a, b| a - b)
308317 }
309318
310319 pub fn sub_inplace ( & mut self , other : & Tensor ) {
311- let self_shape = self . shape ( ) ;
312- let other_shape = other. shape ( ) ;
313- if self_shape != other_shape {
314- panic ! ( "The tensor shape not compatible for inplace subtraction" )
315- }
316- self . data
317- . iter_mut ( )
318- . zip ( other. data ( ) . iter ( ) )
319- . for_each ( |( a, & b) | {
320- * a -= b;
321- } ) ;
320+ self . binary_op_inplace ( other, |a, b| * a -= b) ;
322321 }
323322
324323 pub fn mul ( & self , other : & Tensor ) -> Result < Tensor , TensorError > {
325324 self . binary_op ( other, |a, b| a * b)
326325 }
327326
328327 pub fn mul_inplace ( & mut self , other : & Tensor ) {
329- let self_shape = self . shape ( ) ;
330- let other_shape = other. shape ( ) ;
331- if self_shape != other_shape {
332- panic ! ( "The tensor shape not compatible for inplace multiplication" )
333- }
334- self . data
335- . iter_mut ( )
336- . zip ( other. data ( ) . iter ( ) )
337- . for_each ( |( a, & b) | {
338- * a *= b;
339- } ) ;
328+ self . binary_op_inplace ( other, |a, b| * a *= b) ;
340329 }
341330
342331 pub fn div ( & self , other : & Tensor ) -> Result < Tensor , TensorError > {
343332 self . binary_op ( other, |a, b| a / b)
344333 }
345334
346335 pub fn div_inplace ( & mut self , other : & Tensor ) {
347- let self_shape = self . shape ( ) ;
348- let other_shape = other. shape ( ) ;
349- if self_shape != other_shape {
350- panic ! ( "The tensor shape not compatible for inplace division" )
351- }
352- self . data
353- . iter_mut ( )
354- . zip ( other. data ( ) . iter ( ) )
355- . for_each ( |( a, & b) | {
356- * a /= b;
357- } ) ;
336+ self . binary_op_inplace ( other, |a, b| * a /= b) ;
358337 }
359338
360339 pub fn matmul ( & self , other : & Tensor ) -> Result < Tensor , TensorError > {
0 commit comments