Skip to content

Commit 33b14b3

Browse files
committed
feat(deep_causality_tensor): Updated EinSumOp for better ergonomics.
Signed-off-by: Marvin Hansen <[email protected]>
1 parent b59358b commit 33b14b3

File tree

8 files changed

+82
-58
lines changed

8 files changed

+82
-58
lines changed

deep_causality_tensor/README.md

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -122,11 +122,7 @@ fn main() {
122122
let vec2_tensor = CausalTensor::new(vec2_data, vec2_shape).unwrap();
123123
124124
// Execute the Einstein summation for dot product
125-
let result_dot_prod = CausalTensor::ein_sum(&EinSumOp::dot_prod(
126-
vec1_tensor.clone(),
127-
vec2_tensor.clone(),
128-
))
129-
.unwrap();
125+
let result_dot_prod = CausalTensor::ein_sum(&EinSumOp::dot_prod(vec1_tensor, vec2_tensor)).unwrap();
130126
println!("Result of Dot Product:\n{:?}", result_dot_prod);
131127
}
132128
```

deep_causality_tensor/examples/ein_sum_causal_tensor.rs

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,7 @@ fn main() {
1818
println!("LHS Tensor:\n{:?}", lhs_tensor);
1919
println!("RHS Tensor:\n{:?}", rhs_tensor);
2020

21-
let result_mat_mul =
22-
CausalTensor::ein_sum(&EinSumOp::mat_mul(lhs_tensor.clone(), rhs_tensor.clone())).unwrap();
21+
let result_mat_mul = CausalTensor::ein_sum(&EinSumOp::mat_mul(lhs_tensor, rhs_tensor)).unwrap();
2322
println!("Result of Matrix Multiplication:\n{:?}", result_mat_mul);
2423
let expected_mat_mul = CausalTensor::new(vec![19.0, 22.0, 43.0, 50.0], vec![2, 2]).unwrap();
2524
assert_eq!(result_mat_mul, expected_mat_mul);
@@ -37,11 +36,8 @@ fn main() {
3736
println!("Vector 1:\n{:?}", vec1_tensor);
3837
println!("Vector 2:\n{:?}", vec2_tensor);
3938

40-
let result_dot_prod = CausalTensor::ein_sum(&EinSumOp::dot_prod(
41-
vec1_tensor.clone(),
42-
vec2_tensor.clone(),
43-
))
44-
.unwrap();
39+
let result_dot_prod =
40+
CausalTensor::ein_sum(&EinSumOp::dot_prod(vec1_tensor, vec2_tensor)).unwrap();
4541
println!("Result of Dot Product:\n{:?}", result_dot_prod);
4642
let expected_dot_prod = CausalTensor::new(vec![32.0], vec![]).unwrap();
4743
assert_eq!(result_dot_prod, expected_dot_prod);
@@ -53,7 +49,7 @@ fn main() {
5349
let trace_tensor = CausalTensor::new(trace_data, trace_shape).unwrap();
5450

5551
println!("Tensor for Trace:\n{:?}", trace_tensor);
56-
let result_trace = CausalTensor::ein_sum(&EinSumOp::trace(trace_tensor.clone(), 0, 1)).unwrap();
52+
let result_trace = CausalTensor::ein_sum(&EinSumOp::trace(trace_tensor, 0, 1)).unwrap();
5753
println!("Result of Trace (axes 0, 1):\n{:?}", result_trace);
5854
let expected_trace = CausalTensor::new(vec![5.0], vec![]).unwrap();
5955
assert_eq!(result_trace, expected_trace);
@@ -72,8 +68,8 @@ fn main() {
7268
println!("RHS Tensor for Element-wise Product:\n{:?}", ew_rhs_tensor);
7369

7470
let result_ew_prod = CausalTensor::ein_sum(&EinSumOp::element_wise_product(
75-
ew_lhs_tensor.clone(),
76-
ew_rhs_tensor.clone(),
71+
ew_lhs_tensor,
72+
ew_rhs_tensor,
7773
))
7874
.unwrap();
7975
println!("Result of Element-wise Product:\n{:?}", result_ew_prod);
@@ -100,11 +96,9 @@ fn main() {
10096
println!("LHS Tensor for Batch MatMul:\n{:?}", bmm_lhs_tensor);
10197
println!("RHS Tensor for Batch MatMul:\n{:?}", bmm_rhs_tensor);
10298

103-
let result_bmm = CausalTensor::ein_sum(&EinSumOp::batch_mat_mul(
104-
bmm_lhs_tensor.clone(),
105-
bmm_rhs_tensor.clone(),
106-
))
107-
.unwrap();
99+
let result_bmm =
100+
CausalTensor::ein_sum(&EinSumOp::batch_mat_mul(bmm_lhs_tensor, bmm_rhs_tensor)).unwrap();
101+
108102
println!("Result of Batch Matrix Multiplication:\n{:?}", result_bmm);
109103
let expected_bmm = CausalTensor::new(
110104
vec![

deep_causality_tensor/src/types/causal_tensor/from/mod.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,10 @@ impl<'a, T: Clone> From<&'a T> for CausalTensor<T> {
1919
.expect("Failed to create scalar CausalTensor from &T")
2020
}
2121
}
22+
23+
impl<'a, T: Clone> From<&'a CausalTensor<T>> for CausalTensor<T> {
24+
/// Creates a new `CausalTensor` by cloning an existing `CausalTensor` reference.
25+
fn from(item: &'a CausalTensor<T>) -> Self {
26+
item.clone()
27+
}
28+
}

deep_causality_tensor/src/types/causal_tensor/op_tensor_ein_sum/ein_sum_op.rs

Lines changed: 44 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -79,14 +79,14 @@ impl<T> EinSumOp<T> {
7979
/// # Returns
8080
///
8181
/// An `EinSumAST<T>` node representing the contraction operation.
82-
pub fn contraction(
83-
lhs: CausalTensor<T>,
84-
rhs: CausalTensor<T>,
82+
pub fn contraction<L: Into<CausalTensor<T>>, R: Into<CausalTensor<T>>>(
83+
lhs: L,
84+
rhs: R,
8585
lhs_axes: Vec<usize>,
8686
rhs_axes: Vec<usize>,
8787
) -> EinSumAST<T> {
88-
let lhs_leaf = EinSumOp::tensor_source(lhs);
89-
let rhs_leaf = EinSumOp::tensor_source(rhs);
88+
let lhs_leaf = EinSumOp::tensor_source(lhs.into());
89+
let rhs_leaf = EinSumOp::tensor_source(rhs.into());
9090
EinSumAST::with_children(
9191
EinSumOp::Contraction { lhs_axes, rhs_axes },
9292
vec![lhs_leaf, rhs_leaf],
@@ -105,8 +105,8 @@ impl<T> EinSumOp<T> {
105105
/// # Returns
106106
///
107107
/// An `EinSumAST<T>` node representing the reduction operation.
108-
pub fn reduction(operand: CausalTensor<T>, axes: Vec<usize>) -> EinSumAST<T> {
109-
let operand_leaf = EinSumOp::tensor_source(operand);
108+
pub fn reduction<O: Into<CausalTensor<T>>>(operand: O, axes: Vec<usize>) -> EinSumAST<T> {
109+
let operand_leaf = EinSumOp::tensor_source(operand.into());
110110
EinSumAST::with_children(EinSumOp::Reduction { axes }, vec![operand_leaf])
111111
}
112112

@@ -122,9 +122,12 @@ impl<T> EinSumOp<T> {
122122
/// # Returns
123123
///
124124
/// An `EinSumAST<T>` node representing the matrix multiplication operation.
125-
pub fn mat_mul(lhs: CausalTensor<T>, rhs: CausalTensor<T>) -> EinSumAST<T> {
126-
let lhs_leaf = EinSumOp::tensor_source(lhs);
127-
let rhs_leaf = EinSumOp::tensor_source(rhs);
125+
pub fn mat_mul<L: Into<CausalTensor<T>>, R: Into<CausalTensor<T>>>(
126+
lhs: L,
127+
rhs: R,
128+
) -> EinSumAST<T> {
129+
let lhs_leaf = EinSumOp::tensor_source(lhs.into());
130+
let rhs_leaf = EinSumOp::tensor_source(rhs.into());
128131
EinSumAST::with_children(EinSumOp::MatMul, vec![lhs_leaf, rhs_leaf])
129132
}
130133

@@ -140,9 +143,12 @@ impl<T> EinSumOp<T> {
140143
/// # Returns
141144
///
142145
/// An `EinSumAST<T>` node representing the dot product operation.
143-
pub fn dot_prod(lhs: CausalTensor<T>, rhs: CausalTensor<T>) -> EinSumAST<T> {
144-
let lhs_leaf = EinSumOp::tensor_source(lhs);
145-
let rhs_leaf = EinSumOp::tensor_source(rhs);
146+
pub fn dot_prod<L: Into<CausalTensor<T>>, R: Into<CausalTensor<T>>>(
147+
lhs: L,
148+
rhs: R,
149+
) -> EinSumAST<T> {
150+
let lhs_leaf = EinSumOp::tensor_source(lhs.into());
151+
let rhs_leaf = EinSumOp::tensor_source(rhs.into());
146152
EinSumAST::with_children(EinSumOp::DotProd, vec![lhs_leaf, rhs_leaf])
147153
}
148154

@@ -159,8 +165,8 @@ impl<T> EinSumOp<T> {
159165
/// # Returns
160166
///
161167
/// An `EinSumAST<T>` node representing the trace operation.
162-
pub fn trace(operand: CausalTensor<T>, axes1: usize, axes2: usize) -> EinSumAST<T> {
163-
let operand_leaf = EinSumOp::tensor_source(operand);
168+
pub fn trace<O: Into<CausalTensor<T>>>(operand: O, axes1: usize, axes2: usize) -> EinSumAST<T> {
169+
let operand_leaf = EinSumOp::tensor_source(operand.into());
164170
EinSumAST::with_children(EinSumOp::Trace { axes1, axes2 }, vec![operand_leaf])
165171
}
166172

@@ -176,9 +182,12 @@ impl<T> EinSumOp<T> {
176182
/// # Returns
177183
///
178184
/// An `EinSumAST<T>` node representing the tensor product operation.
179-
pub fn tensor_product(lhs: CausalTensor<T>, rhs: CausalTensor<T>) -> EinSumAST<T> {
180-
let lhs_leaf = EinSumOp::tensor_source(lhs);
181-
let rhs_leaf = EinSumOp::tensor_source(rhs);
185+
pub fn tensor_product<L: Into<CausalTensor<T>>, R: Into<CausalTensor<T>>>(
186+
lhs: L,
187+
rhs: R,
188+
) -> EinSumAST<T> {
189+
let lhs_leaf = EinSumOp::tensor_source(lhs.into());
190+
let rhs_leaf = EinSumOp::tensor_source(rhs.into());
182191
EinSumAST::with_children(EinSumOp::TensorProduct, vec![lhs_leaf, rhs_leaf])
183192
}
184193

@@ -194,9 +203,12 @@ impl<T> EinSumOp<T> {
194203
/// # Returns
195204
///
196205
/// An `EinSumAST<T>` node representing the element-wise product operation.
197-
pub fn element_wise_product(lhs: CausalTensor<T>, rhs: CausalTensor<T>) -> EinSumAST<T> {
198-
let lhs_leaf = EinSumOp::tensor_source(lhs);
199-
let rhs_leaf = EinSumOp::tensor_source(rhs);
206+
pub fn element_wise_product<L: Into<CausalTensor<T>>, R: Into<CausalTensor<T>>>(
207+
lhs: L,
208+
rhs: R,
209+
) -> EinSumAST<T> {
210+
let lhs_leaf = EinSumOp::tensor_source(lhs.into());
211+
let rhs_leaf = EinSumOp::tensor_source(rhs.into());
200212
EinSumAST::with_children(EinSumOp::ElementWiseProduct, vec![lhs_leaf, rhs_leaf])
201213
}
202214

@@ -212,8 +224,8 @@ impl<T> EinSumOp<T> {
212224
/// # Returns
213225
///
214226
/// An `EinSumAST<T>` node representing the transpose operation.
215-
pub fn transpose(operand: CausalTensor<T>, new_order: Vec<usize>) -> EinSumAST<T> {
216-
let operand_leaf = EinSumOp::tensor_source(operand);
227+
pub fn transpose<O: Into<CausalTensor<T>>>(operand: O, new_order: Vec<usize>) -> EinSumAST<T> {
228+
let operand_leaf = EinSumOp::tensor_source(operand.into());
217229
EinSumAST::with_children(EinSumOp::Transpose { new_order }, vec![operand_leaf])
218230
}
219231

@@ -230,12 +242,12 @@ impl<T> EinSumOp<T> {
230242
/// # Returns
231243
///
232244
/// An `EinSumAST<T>` node representing the diagonal extraction operation.
233-
pub fn diagonal_extraction(
234-
operand: CausalTensor<T>,
245+
pub fn diagonal_extraction<O: Into<CausalTensor<T>>>(
246+
operand: O,
235247
axes1: usize,
236248
axes2: usize,
237249
) -> EinSumAST<T> {
238-
let operand_leaf = EinSumOp::tensor_source(operand);
250+
let operand_leaf = EinSumOp::tensor_source(operand.into());
239251
EinSumAST::with_children(
240252
EinSumOp::DiagonalExtraction { axes1, axes2 },
241253
vec![operand_leaf],
@@ -254,9 +266,12 @@ impl<T> EinSumOp<T> {
254266
/// # Returns
255267
///
256268
/// An `EinSumAST<T>` node representing the batch matrix multiplication operation.
257-
pub fn batch_mat_mul(lhs: CausalTensor<T>, rhs: CausalTensor<T>) -> EinSumAST<T> {
258-
let lhs_leaf = EinSumOp::tensor_source(lhs);
259-
let rhs_leaf = EinSumOp::tensor_source(rhs);
269+
pub fn batch_mat_mul<L: Into<CausalTensor<T>>, R: Into<CausalTensor<T>>>(
270+
lhs: L,
271+
rhs: R,
272+
) -> EinSumAST<T> {
273+
let lhs_leaf = EinSumOp::tensor_source(lhs.into());
274+
let rhs_leaf = EinSumOp::tensor_source(rhs.into());
260275
EinSumAST::with_children(EinSumOp::BatchMatMul, vec![lhs_leaf, rhs_leaf])
261276
}
262277
}

deep_causality_tensor/tests/types/causal_tensor/op_tensor_ein_sum_tests.rs

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,19 @@ fn test_ein_sum_mat_mul() {
2121
let rhs = utils_tests::matrix_tensor(vec![5.0, 6.0, 7.0, 8.0], 2, 2);
2222
let expected = utils_tests::matrix_tensor(vec![19.0, 22.0, 43.0, 50.0], 2, 2);
2323

24-
let ast = EinSumOp::mat_mul(lhs, rhs);
24+
let ast = EinSumOp::<f64>::mat_mul(lhs, rhs);
25+
let result = CausalTensor::ein_sum(&ast).unwrap();
26+
assert_eq!(result, expected);
27+
}
28+
29+
#[test]
30+
fn test_ein_sum_mat_mul_with_references() {
31+
let lhs_owned = utils_tests::matrix_tensor(vec![1.0, 2.0, 3.0, 4.0], 2, 2);
32+
let rhs_owned = utils_tests::matrix_tensor(vec![5.0, 6.0, 7.0, 8.0], 2, 2);
33+
let expected = utils_tests::matrix_tensor(vec![19.0, 22.0, 43.0, 50.0], 2, 2);
34+
35+
// Pass references to the EinSumOp::mat_mul method
36+
let ast = EinSumOp::<f64>::mat_mul(&lhs_owned, &rhs_owned);
2537
let result = CausalTensor::ein_sum(&ast).unwrap();
2638
assert_eq!(result, expected);
2739
}
@@ -59,7 +71,7 @@ fn test_ein_sum_trace() {
5971
let operand = utils_tests::matrix_tensor(vec![1.0, 2.0, 3.0, 4.0], 2, 2);
6072
let expected = utils_tests::scalar_tensor(5.0);
6173

62-
let ast = EinSumOp::trace(operand, 0, 1);
74+
let ast = EinSumOp::<f64>::trace(operand, 0, 1);
6375
let result = CausalTensor::ein_sum(&ast).unwrap();
6476
assert_eq!(result, expected);
6577
}
@@ -142,7 +154,7 @@ fn test_ein_sum_error_propagation() {
142154
let lhs = utils_tests::vector_tensor(vec![1.0, 2.0]);
143155
let rhs = utils_tests::matrix_tensor(vec![5.0, 6.0, 7.0, 8.0], 2, 2);
144156

145-
let ast = EinSumOp::mat_mul(lhs, rhs);
157+
let ast = EinSumOp::<f64>::mat_mul(lhs, rhs);
146158
let err = CausalTensor::ein_sum(&ast).unwrap_err();
147159
assert!(matches!(
148160
err,
@@ -154,7 +166,7 @@ fn test_ein_sum_error_propagation() {
154166

155167
// Test Trace with invalid axes, expecting error from trace
156168
let operand = utils_tests::matrix_tensor(vec![1.0; 4], 2, 2);
157-
let ast = EinSumOp::trace(operand, 0, 0);
169+
let ast = EinSumOp::<f64>::trace(operand, 0, 0);
158170
let err = CausalTensor::ein_sum(&ast).unwrap_err();
159171
assert!(matches!(
160172
err,

deep_causality_tensor/tests/types/causal_tensor/op_tensor_reduction_tests.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ fn test_sum_axes_2d() {
2626
fn test_sum_axes_full_reduction() {
2727
let tensor = CausalTensor::new(vec![1, 2, 3, 4, 5, 6], vec![2, 3]).unwrap();
2828
let sum_all = tensor.sum_axes(&[]).unwrap();
29-
assert_eq!(sum_all.shape(), &[]);
29+
assert_eq!(sum_all.shape(), &[] as &[usize]);
3030
assert_eq!(sum_all.as_slice(), &[21]);
3131
}
3232

@@ -60,7 +60,7 @@ fn test_sum_axes_empty_tensor() {
6060

6161
// Full reduction of an empty tensor should result in a scalar 0
6262
let sum_all = tensor.sum_axes(&[]).unwrap();
63-
assert_eq!(sum_all.shape(), &[]);
63+
assert_eq!(sum_all.shape(), &[] as &[usize]);
6464
assert_eq!(sum_all.as_slice(), &[0]);
6565
}
6666

@@ -92,7 +92,7 @@ fn test_mean_axes_2d() {
9292
fn test_mean_axes_full_reduction() {
9393
let tensor = CausalTensor::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
9494
let mean_all = tensor.mean_axes(&[]).unwrap();
95-
assert_eq!(mean_all.shape(), &[]);
95+
assert_eq!(mean_all.shape(), &[] as &[usize]);
9696
assert_eq!(mean_all.as_slice(), &[3.5]);
9797
}
9898

deep_causality_tensor/tests/types/causal_tensor/op_tensor_shape_tests.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,6 @@ fn test_permute_axes_scalar() {
165165
let tensor = CausalTensor::new(vec![42], vec![]).unwrap();
166166
// Permuting a 0-dim tensor
167167
let permuted = tensor.permute_axes(&[]).unwrap();
168-
assert_eq!(permuted.shape(), &[]);
168+
assert_eq!(permuted.shape(), &[] as &[usize]);
169169
assert_eq!(permuted, tensor);
170170
}

deep_causality_tensor/tests/types/causal_tensor/op_tensor_tensor_tests.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ fn test_binary_op_scalar_tensors() {
183183
let t1 = CausalTensor::new(vec![10.0], vec![]).unwrap();
184184
let t2 = CausalTensor::new(vec![2.0], vec![]).unwrap();
185185
let expected_data = vec![20.0];
186-
let expected_shape = vec![];
186+
let expected_shape: Vec<usize> = vec![];
187187

188188
let result = (&t1 * &t2).unwrap();
189189
assert_eq!(result.as_slice(), expected_data.as_slice());

0 commit comments

Comments
 (0)