
Commit d61ba78

Merge pull request #375 from marvin-hansen/main

feat(deep_causality_tensor): Implement and document Einstein Sum Convention

2 parents: 8c23dbf + 695299b

31 files changed (+2709, -71 lines)

.bazelignore

Lines changed: 1 addition & 0 deletions

```diff
@@ -1,5 +1,6 @@
 .cargo
 .gemini
+.gemini_security
 .git
 .github
 .specify
```

.bazelversion

Lines changed: 1 addition & 1 deletion

```diff
@@ -1 +1 @@
-8.4.1
+8.4.2
```

Cargo.lock

Lines changed: 1 addition & 0 deletions
Generated file; diff not rendered by default.

MODULE.bazel.lock

Lines changed: 1 addition & 0 deletions
Generated file; diff not rendered by default.

deep_causality_tensor/BUILD.bazel

Lines changed: 1 addition & 0 deletions

```diff
@@ -11,6 +11,7 @@ rust_library(
     ],
     visibility = ["//visibility:public"],
     deps = [
+        "//deep_causality_ast",
         "//deep_causality_haft",
         "//deep_causality_num",
     ],
```

deep_causality_tensor/Cargo.toml

Lines changed: 7 additions & 0 deletions

```diff
@@ -27,10 +27,17 @@ path = "examples/basic_causal_tensor.rs"
 name = "effect_system_causal_tensor"
 path = "examples/effect_system_causal_tensor.rs"
 
+[[example]]
+name = "ein_sum_causal_tensor"
+path = "examples/ein_sum_causal_tensor.rs"
+
 [[example]]
 name = "functor_causal_tensor"
 path = "examples/functor_causal_tensor.rs"
 
+[dependencies.deep_causality_ast]
+path = "../deep_causality_ast"
+version = "0.1"
 
 [dependencies.deep_causality_haft]
 path = "../deep_causality_haft"
```

deep_causality_tensor/README.md

Lines changed: 70 additions & 0 deletions

````diff
@@ -17,6 +17,31 @@ tasks.
 * [Examples](../deep_causality_tensor/examples)
 * [Test](../deep_causality_tensor/tests)
 
+## Examples
+
+To run the examples, use `cargo run --example <example_name>`.
+
+* **Applicative Causal Tensor**
+  ```bash
+  cargo run --example applicative_causal_tensor
+  ```
+* **Basic Causal Tensor**
+  ```bash
+  cargo run --example causal_tensor
+  ```
+* **Effect System Causal Tensor**
+  ```bash
+  cargo run --example effect_system_causal_tensor
+  ```
+* **Einstein Summation Causal Tensor**
+  ```bash
+  cargo run --example ein_sum_causal_tensor
+  ```
+* **Functor Causal Tensor**
+  ```bash
+  cargo run --example functor_causal_tensor
+  ```
+
 ## Usage
 
 `CausalTensor` is straightforward to use. You create it from a flat vector of data and a vector defining its shape.
@@ -62,6 +87,51 @@ fn main() {
 }
 ```
 
+## Einstein Sum (ein_sum)
+
+The `ein_sum` function provides a powerful and flexible way to perform various tensor operations, including matrix multiplication, dot products, and more, by constructing an Abstract Syntax Tree (AST) of operations.
+
+```rust
+use deep_causality_tensor::CausalTensor;
+use deep_causality_tensor::types::causal_tensor::op_tensor_ein_sum::EinSumOp;
+
+fn main() {
+    // Example: Matrix Multiplication using ein_sum
+    let lhs_data = vec![1.0, 2.0, 3.0, 4.0];
+    let lhs_tensor = CausalTensor::new(lhs_data, vec![2, 2]).unwrap();
+
+    let rhs_data = vec![5.0, 6.0, 7.0, 8.0];
+    let rhs_tensor = CausalTensor::new(rhs_data, vec![2, 2]).unwrap();
+
+    // Construct the AST for matrix multiplication
+    let mat_mul_ast = EinSumOp::mat_mul(lhs_tensor, rhs_tensor);
+
+    // Execute the Einstein summation
+    let result = CausalTensor::ein_sum(&mat_mul_ast).unwrap();
+
+    println!("Result of Matrix Multiplication:\n{:?}", result);
+    // Expected: CausalTensor { data: [19.0, 22.0, 43.0, 50.0], shape: [2, 2], strides: [2, 1] }
+
+    // Example: Dot Product
+    let vec1_data = vec![1.0, 2.0, 3.0];
+    let vec1_shape = vec![3];
+    let vec1_tensor = CausalTensor::new(vec1_data, vec1_shape).unwrap();
+
+    let vec2_data = vec![4.0, 5.0, 6.0];
+    let vec2_shape = vec![3];
+    let vec2_tensor = CausalTensor::new(vec2_data, vec2_shape).unwrap();
+
+    // Execute the Einstein summation for dot product
+    let result_dot_prod = CausalTensor::ein_sum(&EinSumOp::dot_prod(
+        vec1_tensor.clone(),
+        vec2_tensor.clone(),
+    ))
+    .unwrap();
+    println!("Result of Dot Product:\n{:?}", result_dot_prod);
+}
+```
+
+
 ## Functional Composition
 
 Causal Tensor implements a Higher Kinded Type via the `deep_causality_haft` crate as Witness Type. When imported, the CausalTensorWitness type allows monadic composition and abstract type programming. For example, one can write generic functions that uniformly process tensors and other types:
````
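
For reference, the expected values in the README hunk above follow directly from the Einstein summation convention, where a repeated index implies a sum over it; the arithmetic for the two snippets is:

```latex
% Matrix multiplication: C_{ik} = A_{ij} B_{jk} (sum over the repeated index j)
C_{00} = 1 \cdot 5 + 2 \cdot 7 = 19, \qquad C_{01} = 1 \cdot 6 + 2 \cdot 8 = 22,
C_{10} = 3 \cdot 5 + 4 \cdot 7 = 43, \qquad C_{11} = 3 \cdot 6 + 4 \cdot 8 = 50.

% Dot product: s = v_i w_i (sum over i)
s = 1 \cdot 4 + 2 \cdot 5 + 3 \cdot 6 = 32.
```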
deep_causality_tensor/examples/ein_sum_causal_tensor.rs (new file)

Lines changed: 118 additions & 0 deletions

```rust
/*
 * SPDX-License-Identifier: MIT
 * Copyright (c) "2025" . The DeepCausality Authors and Contributors. All Rights Reserved.
 */
use deep_causality_tensor::{CausalTensor, EinSumOp};

fn main() {
    // Example 1: Matrix Multiplication
    println!("--- Example 1: Matrix Multiplication ---");
    let lhs_data = vec![1.0, 2.0, 3.0, 4.0];
    let lhs_shape = vec![2, 2];
    let lhs_tensor = CausalTensor::new(lhs_data, lhs_shape).unwrap();

    let rhs_data = vec![5.0, 6.0, 7.0, 8.0];
    let rhs_shape = vec![2, 2];
    let rhs_tensor = CausalTensor::new(rhs_data, rhs_shape).unwrap();

    println!("LHS Tensor:\n{:?}", lhs_tensor);
    println!("RHS Tensor:\n{:?}", rhs_tensor);

    let result_mat_mul =
        CausalTensor::ein_sum(&EinSumOp::mat_mul(lhs_tensor.clone(), rhs_tensor.clone())).unwrap();
    println!("Result of Matrix Multiplication:\n{:?}", result_mat_mul);
    let expected_mat_mul = CausalTensor::new(vec![19.0, 22.0, 43.0, 50.0], vec![2, 2]).unwrap();
    assert_eq!(result_mat_mul, expected_mat_mul);

    // Example 2: Dot Product
    println!("\n--- Example 2: Dot Product ---");
    let vec1_data = vec![1.0, 2.0, 3.0];
    let vec1_shape = vec![3];
    let vec1_tensor = CausalTensor::new(vec1_data, vec1_shape).unwrap();

    let vec2_data = vec![4.0, 5.0, 6.0];
    let vec2_shape = vec![3];
    let vec2_tensor = CausalTensor::new(vec2_data, vec2_shape).unwrap();

    println!("Vector 1:\n{:?}", vec1_tensor);
    println!("Vector 2:\n{:?}", vec2_tensor);

    let result_dot_prod = CausalTensor::ein_sum(&EinSumOp::dot_prod(
        vec1_tensor.clone(),
        vec2_tensor.clone(),
    ))
    .unwrap();
    println!("Result of Dot Product:\n{:?}", result_dot_prod);
    let expected_dot_prod = CausalTensor::new(vec![32.0], vec![]).unwrap();
    assert_eq!(result_dot_prod, expected_dot_prod);

    // Example 3: Trace
    println!("\n--- Example 3: Trace ---");
    let trace_data = vec![1.0, 2.0, 3.0, 4.0];
    let trace_shape = vec![2, 2];
    let trace_tensor = CausalTensor::new(trace_data, trace_shape).unwrap();

    println!("Tensor for Trace:\n{:?}", trace_tensor);
    let result_trace = CausalTensor::ein_sum(&EinSumOp::trace(trace_tensor.clone(), 0, 1)).unwrap();
    println!("Result of Trace (axes 0, 1):\n{:?}", result_trace);
    let expected_trace = CausalTensor::new(vec![5.0], vec![]).unwrap();
    assert_eq!(result_trace, expected_trace);

    // Example 4: Element-wise Product
    println!("\n--- Example 4: Element-wise Product ---");
    let ew_lhs_data = vec![1.0, 2.0, 3.0];
    let ew_lhs_shape = vec![3];
    let ew_lhs_tensor = CausalTensor::new(ew_lhs_data, ew_lhs_shape).unwrap();

    let ew_rhs_data = vec![4.0, 5.0, 6.0];
    let ew_rhs_shape = vec![3];
    let ew_rhs_tensor = CausalTensor::new(ew_rhs_data, ew_rhs_shape).unwrap();

    println!("LHS Tensor for Element-wise Product:\n{:?}", ew_lhs_tensor);
    println!("RHS Tensor for Element-wise Product:\n{:?}", ew_rhs_tensor);

    let result_ew_prod = CausalTensor::ein_sum(&EinSumOp::element_wise_product(
        ew_lhs_tensor.clone(),
        ew_rhs_tensor.clone(),
    ))
    .unwrap();
    println!("Result of Element-wise Product:\n{:?}", result_ew_prod);
    let expected_ew_prod = CausalTensor::new(vec![4.0, 10.0, 18.0], vec![3]).unwrap();
    assert_eq!(result_ew_prod, expected_ew_prod);

    // Example 5: Batch Matrix Multiplication
    println!("\n--- Example 5: Batch Matrix Multiplication ---");
    // Batch of two 2x2 matrices
    let bmm_lhs_data = vec![
        1.0, 2.0, 3.0, 4.0, // First 2x2 matrix
        5.0, 6.0, 7.0, 8.0, // Second 2x2 matrix
    ];
    let bmm_lhs_shape = vec![2, 2, 2]; // 2 batches, 2 rows, 2 cols
    let bmm_lhs_tensor = CausalTensor::new(bmm_lhs_data, bmm_lhs_shape).unwrap();

    let bmm_rhs_data = vec![
        9.0, 10.0, 11.0, 12.0, // First 2x2 matrix
        13.0, 14.0, 15.0, 16.0, // Second 2x2 matrix
    ];
    let bmm_rhs_shape = vec![2, 2, 2]; // 2 batches, 2 rows, 2 cols
    let bmm_rhs_tensor = CausalTensor::new(bmm_rhs_data, bmm_rhs_shape).unwrap();

    println!("LHS Tensor for Batch MatMul:\n{:?}", bmm_lhs_tensor);
    println!("RHS Tensor for Batch MatMul:\n{:?}", bmm_rhs_tensor);

    let result_bmm = CausalTensor::ein_sum(&EinSumOp::batch_mat_mul(
        bmm_lhs_tensor.clone(),
        bmm_rhs_tensor.clone(),
    ))
    .unwrap();
    println!("Result of Batch Matrix Multiplication:\n{:?}", result_bmm);
    let expected_bmm = CausalTensor::new(
        vec![
            31.0, 34.0, 71.0, 78.0, // First 2x2 matrix result
            155.0, 166.0, 211.0, 226.0, // Second 2x2 matrix result
        ],
        vec![2, 2, 2],
    )
    .unwrap();
    assert_eq!(result_bmm, expected_bmm);
}
```

deep_causality_tensor/src/errors/causal_tensor_error.rs

Lines changed: 6 additions & 0 deletions

```diff
@@ -2,6 +2,7 @@
  * SPDX-License-Identifier: MIT
  * Copyright (c) "2025" . The DeepCausality Authors and Contributors. All Rights Reserved.
  */
+use crate::EinSumValidationError;
 use std::error::Error;
 
 /// Errors that can occur during tensor operations.
@@ -15,6 +16,8 @@ pub enum CausalTensorError {
     InvalidOperation,
     UnorderableValue,
     InvalidParameter(String),
+    /// Encapsulates errors specific to EinSum AST validation and execution.
+    EinSumError(EinSumValidationError),
 }
 
 impl Error for CausalTensorError {}
@@ -44,6 +47,9 @@ impl std::fmt::Display for CausalTensorError {
             CausalTensorError::DivisionByZero => {
                 write!(f, "CausalTensorError: Division by zero error")
             }
+            CausalTensorError::EinSumError(e) => {
+                write!(f, "CausalTensorError: EinSumError: {}", e)
+            }
         }
     }
 }
```
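
As a rough sketch of how calling code might surface the new `EinSumError` variant: this assumes `ein_sum` returns `Result<_, CausalTensorError>`, that `CausalTensor`, `CausalTensorError`, and `EinSumOp` are exported from the crate root, and that a length mismatch in `dot_prod` is reported through `EinSumError`; none of these details are spelled out in the diff above.

```rust
// Hypothetical sketch, not part of this commit. Assumes ein_sum returns
// Result<_, CausalTensorError> and that the types below are exported at the crate root.
use deep_causality_tensor::{CausalTensor, CausalTensorError, EinSumOp};

fn main() {
    // Vectors of different lengths, so the dot-product contraction cannot proceed.
    let a = CausalTensor::new(vec![1.0, 2.0, 3.0], vec![3]).unwrap();
    let b = CausalTensor::new(vec![4.0, 5.0], vec![2]).unwrap();

    match CausalTensor::ein_sum(&EinSumOp::dot_prod(a, b)) {
        Ok(result) => println!("Unexpected success: {:?}", result),
        // The new variant carries the EinSum-specific cause (e.g. a shape mismatch).
        Err(CausalTensorError::EinSumError(cause)) => println!("EinSum failed: {}", cause),
        Err(other) => println!("Other tensor error: {}", other),
    }
}
```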
Lines changed: 60 additions & 0 deletions

```rust
/*
 * SPDX-License-Identifier: MIT
 * Copyright (c) "2025" . The DeepCausality Authors and Contributors. All Rights Reserved.
 */
use std::error::Error;

/// Specific errors that can occur during EinSum AST validation or execution.
#[derive(Debug, Clone, PartialOrd, PartialEq)]
pub enum EinSumValidationError {
    /// Indicates an incorrect number of child nodes for an AST operation.
    InvalidNumberOfChildren { expected: usize, found: usize },
    /// Indicates an issue with the specified axes for an operation (e.g., out of bounds, duplicate).
    InvalidAxesSpecification { message: String },
    /// Indicates an operation that is not yet implemented or is used in an unsupported context.
    UnsupportedOperation { operation: String },
    /// Indicates a mismatch in tensor shapes that prevents an operation from proceeding.
    ShapeMismatch { message: String },
    /// Indicates that a tensor has an unexpected rank for a given operation.
    RankMismatch { expected: usize, found: usize },
}

impl std::fmt::Display for EinSumValidationError {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        match self {
            EinSumValidationError::InvalidNumberOfChildren { expected, found } => {
                write!(
                    f,
                    "EinSumValidationError: Invalid number of children. Expected {}, found {}",
                    expected, found
                )
            }
            EinSumValidationError::InvalidAxesSpecification { message } => {
                write!(
                    f,
                    "EinSumValidationError: Invalid axes specification: {}",
                    message
                )
            }
            EinSumValidationError::UnsupportedOperation { operation } => {
                write!(
                    f,
                    "EinSumValidationError: Unsupported operation: {}",
                    operation
                )
            }
            EinSumValidationError::ShapeMismatch { message } => {
                write!(f, "EinSumValidationError: Shape mismatch: {}", message)
            }
            EinSumValidationError::RankMismatch { expected, found } => {
                write!(
                    f,
                    "EinSumValidationError: Rank mismatch. Expected {}, found {}",
                    expected, found
                )
            }
        }
    }
}

impl Error for EinSumValidationError {}
```
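
A minimal sketch of the `Display` output defined above, assuming `EinSumValidationError` is publicly re-exported from the `deep_causality_tensor` crate root (as the `use crate::EinSumValidationError;` import in the error enum suggests); the `"outer_product"` operation name is purely illustrative.

```rust
// Sketch only: assumes EinSumValidationError is exported from the crate root.
use deep_causality_tensor::EinSumValidationError;

fn main() {
    let err = EinSumValidationError::RankMismatch { expected: 2, found: 1 };
    // Per the Display impl: "EinSumValidationError: Rank mismatch. Expected 2, found 1"
    println!("{}", err);

    let err = EinSumValidationError::UnsupportedOperation {
        operation: "outer_product".to_string(), // illustrative name, not from the crate
    };
    // Per the Display impl: "EinSumValidationError: Unsupported operation: outer_product"
    println!("{}", err);
}
```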
