Skip to content

Commit dd4e700

Browse files
committed
doc(deep_causality_tensor): Documented implementation for Einstein Sum Convention.
Signed-off-by: Marvin Hansen <[email protected]>
1 parent beac1e6 commit dd4e700

File tree

3 files changed

+358
-8
lines changed

3 files changed

+358
-8
lines changed

deep_causality_tensor/src/types/causal_tensor/op_tensor_ein_sum/ein_sum_impl.rs

Lines changed: 184 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,29 @@ impl<T> CausalTensor<T>
1111
where
1212
T: Clone + Default + PartialOrd + Add<Output = T> + Mul<Output = T>,
1313
{
14-
/// Helper to get two operands from the AST children.
14+
/// Extracts two operands from the provided Abstract Syntax Tree (AST) children.
15+
///
16+
/// This helper function is used to retrieve the left-hand side (LHS) and right-hand side (RHS)
17+
/// `CausalTensor` operands from a slice of `EinSumAST` nodes. It expects exactly two children
18+
/// in the AST slice.
19+
///
20+
/// # Arguments
21+
///
22+
/// * `children` - A slice of `EinSumAST<T>` representing the children nodes of an AST operation.
23+
/// Expected to contain exactly two elements, each resolving to a `CausalTensor<T>`.
24+
///
25+
/// # Returns
26+
///
27+
/// A `Result` which is:
28+
/// - `Ok((CausalTensor<T>, CausalTensor<T>))` containing the LHS and RHS tensors if successful.
29+
/// - `Err(CausalTensorError)` if the number of children is not two, or if `execute_ein_sum`
30+
/// fails for either child.
31+
///
32+
/// # Errors
33+
///
34+
/// Returns `CausalTensorError::EinSumError(EinSumValidationError::InvalidNumberOfChildren)`
35+
/// if `children.len()` is not equal to 2.
36+
/// Returns errors propagated from `CausalTensor::execute_ein_sum`.
1537
pub(super) fn get_binary_operands(
1638
children: &[EinSumAST<T>],
1739
) -> Result<(CausalTensor<T>, CausalTensor<T>), CausalTensorError> {
@@ -28,7 +50,28 @@ where
2850
Ok((lhs, rhs))
2951
}
3052

31-
/// Helper to get a single operand from the AST children.
53+
/// Extracts a single operand from the provided Abstract Syntax Tree (AST) children.
54+
///
55+
/// This helper function is used to retrieve a single `CausalTensor` operand from a slice
56+
/// of `EinSumAST` nodes. It expects exactly one child in the AST slice.
57+
///
58+
/// # Arguments
59+
///
60+
/// * `children` - A slice of `EinSumAST<T>` representing the children nodes of an AST operation.
61+
/// Expected to contain exactly one element, resolving to a `CausalTensor<T>`.
62+
///
63+
/// # Returns
64+
///
65+
/// A `Result` which is:
66+
/// - `Ok(CausalTensor<T>)` containing the single operand tensor if successful.
67+
/// - `Err(CausalTensorError)` if the number of children is not one, or if `execute_ein_sum`
68+
/// fails for the child.
69+
///
70+
/// # Errors
71+
///
72+
/// Returns `CausalTensorError::EinSumError(EinSumValidationError::InvalidNumberOfChildren)`
73+
/// if `children.len()` is not equal to 1.
74+
/// Returns errors propagated from `CausalTensor::execute_ein_sum`.
3275
pub(super) fn get_unary_operand(
3376
children: &[EinSumAST<T>],
3477
) -> Result<CausalTensor<T>, CausalTensorError> {
@@ -43,8 +86,32 @@ where
4386
CausalTensor::execute_ein_sum(&children[0])
4487
}
4588

46-
/// Private method for generic tensor contraction.
47-
/// This optimized version uses permutation and reshaping to reduce contraction to matrix multiplication.
89+
/// Performs a generic tensor contraction between two `CausalTensor`s.
90+
///
91+
/// This method implements an optimized tensor contraction by leveraging permutation and
92+
/// reshaping operations to reduce the problem to a standard 2D matrix multiplication.
93+
/// It identifies common axes between the two tensors and sums over their products.
94+
///
95+
/// # Arguments
96+
///
97+
/// * `lhs` - The left-hand side `CausalTensor`.
98+
/// * `rhs` - The right-hand side `CausalTensor`.
99+
/// * `lhs_contract_axes` - A slice of `usize` indicating the axes of `lhs` to contract over.
100+
/// * `rhs_contract_axes` - A slice of `usize` indicating the axes of `rhs` to contract over.
101+
///
102+
/// # Returns
103+
///
104+
/// A `Result` which is:
105+
/// - `Ok(CausalTensor<T>)` containing the result of the tensor contraction.
106+
/// - `Err(CausalTensorError)` if validation fails or an underlying operation encounters an error.
107+
///
108+
/// # Errors
109+
///
110+
/// Returns `CausalTensorError::EinSumError` if:
111+
/// - The number of `lhs_contract_axes` and `rhs_contract_axes` do not match.
112+
/// - Any specified axis is out of bounds for its respective tensor.
113+
/// - The dimensions of the contracted axes in `lhs` and `rhs` do not match.
114+
/// - Errors are propagated from `permute_axes`, `reshape`, or `mat_mul_2d`.
48115
pub(super) fn contract(
49116
lhs: &CausalTensor<T>,
50117
rhs: &CausalTensor<T>,
@@ -126,7 +193,28 @@ where
126193
result_matrix.reshape(&final_shape)
127194
}
128195

129-
/// Private helper for 2D matrix multiplication.
196+
/// Performs 2D matrix multiplication between two `CausalTensor`s.
197+
///
198+
/// This private helper function computes the matrix product of two rank-2 tensors.
199+
/// It validates that both input tensors are indeed 2D and that their inner dimensions
200+
/// are compatible for multiplication (i.e., `lhs.cols == rhs.rows`).
201+
///
202+
/// # Arguments
203+
///
204+
/// * `lhs` - The left-hand side `CausalTensor` (matrix).
205+
/// * `rhs` - The right-hand side `CausalTensor` (matrix).
206+
///
207+
/// # Returns
208+
///
209+
/// A `Result` which is:
210+
/// - `Ok(CausalTensor<T>)` containing the resulting product matrix.
211+
/// - `Err(CausalTensorError)` if validation fails.
212+
///
213+
/// # Errors
214+
///
215+
/// Returns `CausalTensorError::EinSumError` if:
216+
/// - Either `lhs` or `rhs` is not a 2D tensor (rank mismatch).
217+
/// - The inner dimensions of `lhs` and `rhs` do not match (shape mismatch).
130218
pub(super) fn mat_mul_2d(
131219
lhs: &CausalTensor<T>,
132220
rhs: &CausalTensor<T>,
@@ -179,15 +267,57 @@ where
179267
CausalTensor::new(result_data, vec![m, n])
180268
}
181269

182-
/// Private method for element-wise multiplication.
270+
/// Performs element-wise multiplication (Hadamard product) between two `CausalTensor`s.
271+
///
272+
/// This private method multiplies corresponding elements of two tensors. It leverages
273+
/// the `broadcast_op` method to handle broadcasting rules if the tensor shapes are not identical.
274+
///
275+
/// # Arguments
276+
///
277+
/// * `lhs` - The left-hand side `CausalTensor`.
278+
/// * `rhs` - The right-hand side `CausalTensor`.
279+
///
280+
/// # Returns
281+
///
282+
/// A `Result` which is:
283+
/// - `Ok(CausalTensor<T>)` containing the result of the element-wise multiplication.
284+
/// - `Err(CausalTensorError)` if broadcasting fails or an element-wise operation encounters an error.
285+
///
286+
/// # Errors
287+
///
288+
/// Returns errors propagated from `CausalTensor::broadcast_op`.
183289
pub(super) fn element_wise_mul(
184290
lhs: &CausalTensor<T>,
185291
rhs: &CausalTensor<T>,
186292
) -> Result<CausalTensor<T>, CausalTensorError> {
187293
lhs.broadcast_op(rhs, |a, b| Ok(a * b))
188294
}
189295

190-
/// Private method for tracing (summing over diagonal axes).
296+
/// Computes the trace of a `CausalTensor` over two specified axes.
297+
///
298+
/// The trace operation sums the elements along the diagonal of a 2D slice
299+
/// defined by `axis1` and `axis2`. If the tensor has more than two dimensions,
300+
/// this operation is applied to all 2D slices formed by the specified axes,
301+
/// effectively reducing the tensor's rank by two.
302+
///
303+
/// # Arguments
304+
///
305+
/// * `tensor` - The input `CausalTensor`.
306+
/// * `axis1` - The first axis to trace over.
307+
/// * `axis2` - The second axis to trace over.
308+
///
309+
/// # Returns
310+
///
311+
/// A `Result` which is:
312+
/// - `Ok(CausalTensor<T>)` containing the result of the trace operation.
313+
/// - `Err(CausalTensorError)` if validation fails.
314+
///
315+
/// # Errors
316+
///
317+
/// Returns `CausalTensorError::EinSumError` if:
318+
/// - `axis1` or `axis2` are out of bounds for the tensor's dimensions.
319+
/// - `axis1` and `axis2` are identical.
320+
/// - The dimensions of `axis1` and `axis2` are not equal (shape mismatch).
191321
pub(super) fn trace(
192322
tensor: &CausalTensor<T>,
193323
axis1: usize,
@@ -259,7 +389,30 @@ where
259389
Ok(result_tensor)
260390
}
261391

262-
/// Private method for extracting a diagonal.
392+
/// Extracts the diagonal elements of a `CausalTensor` over two specified axes.
393+
///
394+
/// This private method extracts the diagonal elements from 2D slices of the input tensor
395+
/// defined by `axis1` and `axis2`. The resulting tensor will have a rank reduced by one
396+
/// compared to the input tensor, with the new last dimension representing the diagonal.
397+
///
398+
/// # Arguments
399+
///
400+
/// * `tensor` - The input `CausalTensor`.
401+
/// * `axis1` - The first axis defining the 2D plane from which to extract the diagonal.
402+
/// * `axis2` - The second axis defining the 2D plane from which to extract the diagonal.
403+
///
404+
/// # Returns
405+
///
406+
/// A `Result` which is:
407+
/// - `Ok(CausalTensor<T>)` containing the tensor with extracted diagonal elements.
408+
/// - `Err(CausalTensorError)` if validation fails.
409+
///
410+
/// # Errors
411+
///
412+
/// Returns `CausalTensorError::EinSumError` if:
413+
/// - `axis1` or `axis2` are out of bounds for the tensor's dimensions.
414+
/// - `axis1` and `axis2` are identical.
415+
/// - The dimensions of `axis1` and `axis2` are not equal (shape mismatch).
263416
pub(super) fn diagonal(
264417
tensor: &CausalTensor<T>,
265418
axis1: usize,
@@ -329,6 +482,29 @@ where
329482
CausalTensor::new(result_data, new_shape)
330483
}
331484

485+
/// Performs batch matrix multiplication between two `CausalTensor`s.
486+
///
487+
/// This method expects both input tensors to have at least 3 dimensions, where the first
488+
/// dimension represents the batch size. It performs a 2D matrix multiplication for each
489+
/// corresponding pair of matrices within the batch and stacks the results.
490+
///
491+
/// # Arguments
492+
///
493+
/// * `lhs` - The left-hand side `CausalTensor` with a batch dimension.
494+
/// * `rhs` - The right-hand side `CausalTensor` with a batch dimension.
495+
///
496+
/// # Returns
497+
///
498+
/// A `Result` which is:
499+
/// - `Ok(CausalTensor<T>)` containing the result of the batch matrix multiplication.
500+
/// - `Err(CausalTensorError)` if validation fails or an underlying operation encounters an error.
501+
///
502+
/// # Errors
503+
///
504+
/// Returns `CausalTensorError::EinSumError` if:
505+
/// - Either `lhs` or `rhs` has fewer than 3 dimensions (rank mismatch).
506+
/// - The batch sizes of `lhs` and `rhs` do not match (shape mismatch).
507+
/// - Errors are propagated from `slice`, `mat_mul_2d`, or `stack`.
332508
pub(super) fn batch_mat_mul(
333509
lhs: CausalTensor<T>,
334510
rhs: CausalTensor<T>,

0 commit comments

Comments
 (0)