Skip to content

Commit 5756f9d

Browse files
committed
test(deep_causality_tensor): Improved test coverage of implementation for Einstein Sum Convention.
Signed-off-by: Marvin Hansen <[email protected]>
1 parent 47fe93b commit 5756f9d

File tree

1 file changed

+17
-0
lines changed

1 file changed

+17
-0
lines changed

deep_causality_tensor/tests/types/causal_tensor/op_tensor_ein_sum_tests.rs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,23 @@ fn test_ein_sum_mat_mul() {
2626
assert_eq!(result, expected);
2727
}
2828

29+
#[test]
30+
fn test_ein_sum_contraction() {
31+
let lhs = utils_tests::matrix_tensor(vec![1.0, 2.0, 3.0, 4.0], 2, 2);
32+
let rhs = utils_tests::matrix_tensor(vec![5.0, 6.0, 7.0, 8.0], 2, 2);
33+
let expected = utils_tests::matrix_tensor(vec![19.0, 22.0, 43.0, 50.0], 2, 2);
34+
35+
let ast = EinSumAST::with_children(
36+
EinSumOp::Contraction {
37+
lhs_axes: vec![1],
38+
rhs_axes: vec![0],
39+
},
40+
vec![EinSumOp::tensor_source(lhs), EinSumOp::tensor_source(rhs)],
41+
);
42+
let result = CausalTensor::ein_sum(&ast).unwrap();
43+
assert_eq!(result, expected);
44+
}
45+
2946
#[test]
3047
fn test_ein_sum_dot_prod() {
3148
let lhs = utils_tests::vector_tensor(vec![1.0, 2.0, 3.0]);

0 commit comments

Comments
 (0)