File tree Expand file tree Collapse file tree 1 file changed +19
-22
lines changed
Expand file tree Collapse file tree 1 file changed +19
-22
lines changed Original file line number Diff line number Diff line change @@ -206,35 +206,32 @@ impl Tensor {
206206
207207 /* UNARY OPS */
208208
209+ fn unary_op < F > ( & self , op : F ) -> Tensor
210+ where
211+ F : Fn ( f32 ) -> f32 ,
212+ {
213+ let result_data: Vec < f32 > = self . data ( ) . iter ( ) . map ( |x| op ( * x) ) . collect ( ) ;
214+ return Tensor :: new ( self . shape ( ) . clone ( ) , result_data) . unwrap ( ) ;
215+ }
216+
209217 pub fn exp ( & self ) -> Tensor {
210- let result_data: Vec < f32 > = self . data ( ) . iter ( ) . map ( |& x| x. exp ( ) ) . collect ( ) ;
211- Tensor :: new ( self . shape ( ) . clone ( ) , result_data) . unwrap ( )
218+ self . unary_op ( |x| x. exp ( ) )
212219 }
213220
214221 pub fn log ( & self ) -> Tensor {
215- let result_data: Vec < f32 > = self
216- . data ( )
217- . iter ( )
218- . map ( |& x| {
219- if x == 0.0 {
220- f32:: NEG_INFINITY // log(0) -> -inf
221- } else if x < 0.0 {
222- f32:: NAN // log of negative numbers is undefined
223- } else {
224- x. ln ( )
225- }
226- } )
227- . collect ( ) ;
228- Tensor :: new ( self . shape ( ) . clone ( ) , result_data) . unwrap ( )
222+ self . unary_op ( |x| {
223+ if x == 0.0 {
224+ f32:: NEG_INFINITY // log(0) -> -inf
225+ } else if x < 0.0 {
226+ f32:: NAN // log of negative numbers is undefined
227+ } else {
228+ x. ln ( )
229+ }
230+ } )
229231 }
230232
231233 pub fn relu ( & self ) -> Tensor {
232- let result_data: Vec < f32 > = self
233- . data ( )
234- . iter ( )
235- . map ( |& x| if x > 0.0_f32 { x } else { 0.0_f32 } )
236- . collect ( ) ;
237- Tensor :: new ( self . shape ( ) . clone ( ) , result_data) . unwrap ( )
234+ self . unary_op ( |x| if x > 0.0_f32 { x } else { 0.0_f32 } )
238235 }
239236
240237 /* BINARY OPS */
You can’t perform that action at this time.
0 commit comments