Skip to content

Commit 705cf59

Browse files
authored
Add generic numeric type so that tensors can be created with f32 and f64 (#21)
1 parent 877a1a7 commit 705cf59

File tree

10 files changed

+310
-226
lines changed

10 files changed

+310
-226
lines changed

tensor/examples/generic_test.rs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
use tensor::{Tensor, TensorF32, TensorF64};
2+
3+
fn main() {
4+
// Test f32 tensors
5+
let a = TensorF32::ones(vec![2, 2]);
6+
let b = TensorF32::zeros(vec![2, 2]);
7+
println!("f32 tensor:\n{}", &a + &b);
8+
9+
// Test f64 tensors
10+
let a_f64 = TensorF64::ones(vec![2, 2]);
11+
let b_f64 = TensorF64::zeros(vec![2, 2]);
12+
println!("f64 tensor:\n{}", &a_f64 + &b_f64);
13+
14+
// Test with explicit generic syntax
15+
let a_explicit: Tensor<f32> = Tensor::ones(vec![2, 2]);
16+
let b_explicit: Tensor<f32> = Tensor::new(vec![2, 2], vec![2.5, 3.0, 1.5, 4.0]).unwrap();
17+
println!("Explicit f32 tensor:\n{}", &a_explicit + &b_explicit);
18+
}

tensor/src/core/mod.rs

Lines changed: 90 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,97 @@ mod traits;
44
mod utils;
55

66
use errors::TensorError;
7+
use std::fmt;
78
use utils::calculate_strides;
89

9-
pub struct Tensor {
10+
pub trait Numeric:
11+
Copy
12+
+ fmt::Display
13+
+ fmt::Debug
14+
+ std::ops::Add<Output = Self>
15+
+ std::ops::Sub<Output = Self>
16+
+ std::ops::Mul<Output = Self>
17+
+ std::ops::Div<Output = Self>
18+
+ std::ops::AddAssign
19+
+ std::ops::SubAssign
20+
+ std::ops::MulAssign
21+
+ std::ops::DivAssign
22+
+ PartialEq
23+
+ PartialOrd
24+
+ 'static
25+
{
26+
fn zero() -> Self;
27+
fn one() -> Self;
28+
fn exp(self) -> Self;
29+
fn ln(self) -> Self;
30+
fn neg_infinity() -> Self;
31+
fn nan() -> Self;
32+
fn is_sign_negative(self) -> bool;
33+
fn from_usize(val: usize) -> Self;
34+
}
35+
36+
impl Numeric for f32 {
37+
fn zero() -> Self {
38+
0.0
39+
}
40+
fn one() -> Self {
41+
1.0
42+
}
43+
fn exp(self) -> Self {
44+
self.exp()
45+
}
46+
fn ln(self) -> Self {
47+
self.ln()
48+
}
49+
fn neg_infinity() -> Self {
50+
f32::NEG_INFINITY
51+
}
52+
fn nan() -> Self {
53+
f32::NAN
54+
}
55+
fn is_sign_negative(self) -> bool {
56+
self.is_sign_negative()
57+
}
58+
fn from_usize(val: usize) -> Self {
59+
val as f32
60+
}
61+
}
62+
63+
impl Numeric for f64 {
64+
fn zero() -> Self {
65+
0.0
66+
}
67+
fn one() -> Self {
68+
1.0
69+
}
70+
fn exp(self) -> Self {
71+
self.exp()
72+
}
73+
fn ln(self) -> Self {
74+
self.ln()
75+
}
76+
fn neg_infinity() -> Self {
77+
f64::NEG_INFINITY
78+
}
79+
fn nan() -> Self {
80+
f64::NAN
81+
}
82+
fn is_sign_negative(self) -> bool {
83+
self.is_sign_negative()
84+
}
85+
fn from_usize(val: usize) -> Self {
86+
val as f64
87+
}
88+
}
89+
90+
pub struct Tensor<T: Numeric = f32> {
1091
shape: Vec<usize>,
1192
strides: Vec<usize>,
12-
data: Vec<f32>,
93+
data: Vec<T>,
1394
}
1495

15-
impl Tensor {
16-
pub fn new<S>(shape: S, data: Vec<f32>) -> Result<Self, TensorError>
96+
impl<T: Numeric> Tensor<T> {
97+
pub fn new<S>(shape: S, data: Vec<T>) -> Result<Self, TensorError>
1798
where
1899
S: Into<Vec<usize>>,
19100
{
@@ -42,7 +123,7 @@ impl Tensor {
42123
Tensor {
43124
shape: shape_vec,
44125
strides,
45-
data: vec![0.0; num_elements],
126+
data: vec![T::zero(); num_elements],
46127
}
47128
}
48129

@@ -56,11 +137,11 @@ impl Tensor {
56137
Tensor {
57138
shape: shape_vec,
58139
strides,
59-
data: vec![1.0; num_elements],
140+
data: vec![T::one(); num_elements],
60141
}
61142
}
62143

63-
pub fn get(&self, indices: &[usize]) -> Option<&f32> {
144+
pub fn get(&self, indices: &[usize]) -> Option<&T> {
64145
if indices.len() != self.shape.len() {
65146
return None;
66147
}
@@ -83,11 +164,11 @@ impl Tensor {
83164
&self.strides
84165
}
85166

86-
pub fn data(&self) -> &[f32] {
167+
pub fn data(&self) -> &[T] {
87168
&self.data
88169
}
89170

90-
pub fn data_mut(&mut self) -> &mut Vec<f32> {
171+
pub fn data_mut(&mut self) -> &mut Vec<T> {
91172
&mut self.data
92173
}
93174
}

tensor/src/core/ops/binary.rs

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
use super::broadcast::{compute_broadcast_shape_and_strides, is_broadcastable};
22
use crate::core::utils::unravel_index;
3-
use crate::core::TensorError;
3+
use crate::core::{TensorError, Numeric};
44
use crate::Tensor;
55

6-
impl Tensor {
7-
fn binary_op<F>(&self, other: &Tensor, op: F) -> Result<Tensor, TensorError>
6+
impl<T: Numeric> Tensor<T> {
7+
fn binary_op<F>(&self, other: &Tensor<T>, op: F) -> Result<Tensor<T>, TensorError>
88
where
9-
F: Fn(f32, f32) -> f32,
9+
F: Fn(T, T) -> T,
1010
{
1111
let self_shape = self.shape();
1212
let other_shape = other.shape();
@@ -17,7 +17,7 @@ impl Tensor {
1717
}
1818

1919
if self_shape == other_shape {
20-
let result_data: Vec<f32> = self
20+
let result_data: Vec<T> = self
2121
.data
2222
.iter()
2323
.zip(other.data.iter())
@@ -32,7 +32,7 @@ impl Tensor {
3232
let self_data = self.data();
3333
let other_data = other.data();
3434
let result_size: usize = bc_shape.iter().product();
35-
let mut result_data: Vec<f32> = Vec::with_capacity(result_size);
35+
let mut result_data: Vec<T> = Vec::with_capacity(result_size);
3636

3737
for i in 0..result_size {
3838
let multi_idx = unravel_index(i, &bc_shape);
@@ -50,9 +50,9 @@ impl Tensor {
5050
Tensor::new(bc_shape, result_data)
5151
}
5252

53-
fn binary_op_inplace<F>(&mut self, other: &Tensor, op: F)
53+
fn binary_op_inplace<F>(&mut self, other: &Tensor<T>, op: F)
5454
where
55-
F: Fn(&mut f32, f32),
55+
F: Fn(&mut T, T),
5656
{
5757
let self_shape = self.shape();
5858
let other_shape = other.shape();
@@ -69,39 +69,39 @@ impl Tensor {
6969
});
7070
}
7171

72-
pub fn add(&self, other: &Tensor) -> Result<Tensor, TensorError> {
72+
pub fn add(&self, other: &Tensor<T>) -> Result<Tensor<T>, TensorError> {
7373
self.binary_op(other, |a, b| a + b)
7474
}
7575

76-
pub fn add_inplace(&mut self, other: &Tensor) {
76+
pub fn add_inplace(&mut self, other: &Tensor<T>) {
7777
self.binary_op_inplace(other, |a, b| *a += b);
7878
}
7979

80-
pub fn sub(&self, other: &Tensor) -> Result<Tensor, TensorError> {
80+
pub fn sub(&self, other: &Tensor<T>) -> Result<Tensor<T>, TensorError> {
8181
self.binary_op(other, |a, b| a - b)
8282
}
8383

84-
pub fn sub_inplace(&mut self, other: &Tensor) {
84+
pub fn sub_inplace(&mut self, other: &Tensor<T>) {
8585
self.binary_op_inplace(other, |a, b| *a -= b);
8686
}
8787

88-
pub fn mul(&self, other: &Tensor) -> Result<Tensor, TensorError> {
88+
pub fn mul(&self, other: &Tensor<T>) -> Result<Tensor<T>, TensorError> {
8989
self.binary_op(other, |a, b| a * b)
9090
}
9191

92-
pub fn mul_inplace(&mut self, other: &Tensor) {
92+
pub fn mul_inplace(&mut self, other: &Tensor<T>) {
9393
self.binary_op_inplace(other, |a, b| *a *= b);
9494
}
9595

96-
pub fn div(&self, other: &Tensor) -> Result<Tensor, TensorError> {
96+
pub fn div(&self, other: &Tensor<T>) -> Result<Tensor<T>, TensorError> {
9797
self.binary_op(other, |a, b| a / b)
9898
}
9999

100-
pub fn div_inplace(&mut self, other: &Tensor) {
100+
pub fn div_inplace(&mut self, other: &Tensor<T>) {
101101
self.binary_op_inplace(other, |a, b| *a /= b);
102102
}
103103

104-
pub fn matmul(&self, other: &Tensor) -> Result<Tensor, TensorError> {
104+
pub fn matmul(&self, other: &Tensor<T>) -> Result<Tensor<T>, TensorError> {
105105
let lhs_shape = self.shape();
106106
let rhs_shape = other.shape();
107107
if lhs_shape.len() != 2 || rhs_shape.len() != 2 {
@@ -120,7 +120,7 @@ impl Tensor {
120120

121121
let lhs_data = self.data();
122122
let rhs_data = other.data();
123-
let mut result_data: Vec<f32> = vec![0.0_f32; rows_left * cols_right];
123+
let mut result_data: Vec<T> = vec![T::zero(); rows_left * cols_right];
124124
for i in 0..rows_left {
125125
for k in 0..cols_left {
126126
for j in 0..cols_right {
@@ -129,6 +129,6 @@ impl Tensor {
129129
}
130130
}
131131
}
132-
Ok(Tensor::new(vec![rows_left, cols_right], result_data).unwrap())
132+
Tensor::new(vec![rows_left, cols_right], result_data)
133133
}
134134
}

tensor/src/core/ops/movement.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
use crate::core::{calculate_strides, TensorError};
1+
use crate::core::{calculate_strides, TensorError, Numeric};
22
use crate::Tensor;
33

4-
impl Tensor {
4+
impl<T: Numeric> Tensor<T> {
55
pub fn reshape(&mut self, shape: &[usize]) -> Result<(), TensorError> {
66
let new_length: usize = shape.iter().product();
77
let current_length: usize = self.shape.iter().product();
@@ -51,7 +51,7 @@ impl Tensor {
5151
}
5252

5353
let (m, n) = (self.shape[0], self.shape[1]);
54-
let mut new_data = vec![0.0_f32; self.data.len()];
54+
let mut new_data = vec![T::zero(); self.data.len()];
5555
for i in 0..m {
5656
for j in 0..n {
5757
let old_idx = i * n + j;

tensor/src/core/ops/reduce.rs

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
use crate::core::utils::unravel_index;
2-
use crate::core::TensorError;
2+
use crate::core::{TensorError, Numeric};
33
use crate::Tensor;
44

5-
impl Tensor {
6-
pub fn sum(&self) -> Tensor {
7-
let sum: f32 = self.data().iter().sum();
5+
impl<T: Numeric> Tensor<T> {
6+
pub fn sum(&self) -> Tensor<T> {
7+
let sum: T = self.data().iter().fold(T::zero(), |acc, &x| acc + x);
88
Tensor::new(vec![1], vec![sum]).unwrap()
99
}
1010

11-
pub fn sum_dim(&self, dim: usize) -> Result<Tensor, TensorError> {
11+
pub fn sum_dim(&self, dim: usize) -> Result<Tensor<T>, TensorError> {
1212
let self_data = self.data();
1313
let self_shape = self.shape();
1414
let self_strides = self.strides();
@@ -22,7 +22,7 @@ impl Tensor {
2222
let dim_size = result_shape.remove(dim);
2323

2424
let result_size: usize = result_shape.iter().product();
25-
let mut result_data = vec![0.0_f32; result_size];
25+
let mut result_data = vec![T::zero(); result_size];
2626

2727
for i in 0..result_size {
2828
let result_multi_idx = unravel_index(i, &result_shape);
@@ -36,7 +36,7 @@ impl Tensor {
3636
j += 1;
3737
}
3838
}
39-
let mut sum = 0.0_f32;
39+
let mut sum = T::zero();
4040
for k in 0..dim_size {
4141
full_multi_idx[dim] = k;
4242
let mut offset = 0;
@@ -50,18 +50,18 @@ impl Tensor {
5050
Tensor::new(result_shape, result_data)
5151
}
5252

53-
pub fn mean(&self) -> Tensor {
54-
let sum: Tensor = self.sum();
55-
&sum / self.shape().iter().product::<usize>() as f32
53+
pub fn mean(&self) -> Tensor<T> {
54+
let sum: Tensor<T> = self.sum();
55+
&sum / T::from_usize(self.shape().iter().product::<usize>())
5656
}
5757

58-
pub fn mean_dim(&self, dim: usize) -> Result<Tensor, TensorError> {
58+
pub fn mean_dim(&self, dim: usize) -> Result<Tensor<T>, TensorError> {
5959
if self.shape().len() < dim {
6060
return Err(TensorError::IndexError(
6161
"Dimension out of range for the tensor".to_string(),
6262
));
6363
}
64-
let sum: Tensor = self.sum_dim(dim).unwrap();
65-
Ok(&sum / self.shape()[dim] as f32)
64+
let sum: Tensor<T> = self.sum_dim(dim).unwrap();
65+
Ok(&sum / T::from_usize(self.shape()[dim]))
6666
}
6767
}

tensor/src/core/ops/unary.rs

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,31 @@
1-
use crate::Tensor;
1+
use crate::{Tensor, core::Numeric};
22

3-
impl Tensor {
4-
fn unary_op<F>(&self, op: F) -> Tensor
3+
impl<T: Numeric> Tensor<T> {
4+
fn unary_op<F>(&self, op: F) -> Tensor<T>
55
where
6-
F: Fn(f32) -> f32,
6+
F: Fn(T) -> T,
77
{
8-
let result_data: Vec<f32> = self.data().iter().map(|x| op(*x)).collect();
9-
return Tensor::new(self.shape(), result_data).unwrap();
8+
let result_data: Vec<T> = self.data().iter().map(|x| op(*x)).collect();
9+
Tensor::new(self.shape(), result_data).unwrap()
1010
}
1111

12-
pub fn exp(&self) -> Tensor {
12+
pub fn exp(&self) -> Tensor<T> {
1313
self.unary_op(|x| x.exp())
1414
}
1515

16-
pub fn log(&self) -> Tensor {
16+
pub fn log(&self) -> Tensor<T> {
1717
self.unary_op(|x| {
18-
if x == 0.0 {
19-
f32::NEG_INFINITY // log(0) -> -inf
20-
} else if x < 0.0 {
21-
f32::NAN // log of negative numbers is undefined
18+
if x == T::zero() {
19+
T::neg_infinity()
20+
} else if x.is_sign_negative() {
21+
T::nan()
2222
} else {
2323
x.ln()
2424
}
2525
})
2626
}
2727

28-
pub fn relu(&self) -> Tensor {
29-
self.unary_op(|x| if x > 0.0_f32 { x } else { 0.0_f32 })
28+
pub fn relu(&self) -> Tensor<T> {
29+
self.unary_op(|x| if x > T::zero() { x } else { T::zero() })
3030
}
3131
}

0 commit comments

Comments
 (0)