Skip to content

Commit 10ff62d

Browse files
authored
refactor unary ops to use generic function (#15)
1 parent c45aaf4 commit 10ff62d

File tree

1 file changed

+19
-22
lines changed

1 file changed

+19
-22
lines changed

tensor/src/lib.rs

Lines changed: 19 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -206,35 +206,32 @@ impl Tensor {
206206

207207
/* UNARY OPS */
208208

209+
fn unary_op<F>(&self, op: F) -> Tensor
210+
where
211+
F: Fn(f32) -> f32,
212+
{
213+
let result_data: Vec<f32> = self.data().iter().map(|x| op(*x)).collect();
214+
return Tensor::new(self.shape().clone(), result_data).unwrap();
215+
}
216+
209217
pub fn exp(&self) -> Tensor {
210-
let result_data: Vec<f32> = self.data().iter().map(|&x| x.exp()).collect();
211-
Tensor::new(self.shape().clone(), result_data).unwrap()
218+
self.unary_op(|x| x.exp())
212219
}
213220

214221
pub fn log(&self) -> Tensor {
215-
let result_data: Vec<f32> = self
216-
.data()
217-
.iter()
218-
.map(|&x| {
219-
if x == 0.0 {
220-
f32::NEG_INFINITY // log(0) -> -inf
221-
} else if x < 0.0 {
222-
f32::NAN // log of negative numbers is undefined
223-
} else {
224-
x.ln()
225-
}
226-
})
227-
.collect();
228-
Tensor::new(self.shape().clone(), result_data).unwrap()
222+
self.unary_op(|x| {
223+
if x == 0.0 {
224+
f32::NEG_INFINITY // log(0) -> -inf
225+
} else if x < 0.0 {
226+
f32::NAN // log of negative numbers is undefined
227+
} else {
228+
x.ln()
229+
}
230+
})
229231
}
230232

231233
pub fn relu(&self) -> Tensor {
232-
let result_data: Vec<f32> = self
233-
.data()
234-
.iter()
235-
.map(|&x| if x > 0.0_f32 { x } else { 0.0_f32 })
236-
.collect();
237-
Tensor::new(self.shape().clone(), result_data).unwrap()
234+
self.unary_op(|x| if x > 0.0_f32 { x } else { 0.0_f32 })
238235
}
239236

240237
/* BINARY OPS */

0 commit comments

Comments
 (0)