Skip to content

Commit 152a150

Browse files
authored
Merge pull request #13 from PaytonWebber/feature/tensor-relu
add relu unary operation
2 parents b7e64aa + f3323d5 commit 152a150

File tree

2 files changed

+20
-0
lines changed

2 files changed

+20
-0
lines changed

tensor/src/lib.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,15 @@ impl Tensor {
193193
Tensor::new(self.shape().clone(), result_data).unwrap()
194194
}
195195

196+
pub fn relu(&self) -> Tensor {
197+
let result_data: Vec<f32> = self
198+
.data()
199+
.iter()
200+
.map(|&x| if x > 0.0_f32 { x } else { 0.0_f32 })
201+
.collect();
202+
Tensor::new(self.shape().clone(), result_data).unwrap()
203+
}
204+
196205
/* BINARY OPS */
197206

198207
pub fn add(&self, other: &Tensor) -> Result<Tensor, &'static str> {

tensor/tests/tensor_core_test.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,17 @@ fn tensor_log() {
291291
assert_eq!(expected_result, *result.data());
292292
}
293293

294+
#[test]
295+
fn tensor_relu() {
296+
let shape = vec![2, 3];
297+
let data: Vec<f32> = vec![1.0, -2.0, 3.0, 4.0, -5.0, 6.0];
298+
let a = Tensor::new(shape, data).unwrap();
299+
300+
let result = a.relu();
301+
let expected_data: Vec<f32> = vec![1.0, 0.0, 3.0, 4.0, 0.0, 6.0];
302+
assert_eq!(expected_data, *result.data());
303+
}
304+
294305
/* BINARY OPS */
295306

296307
#[test]

0 commit comments

Comments
 (0)