@@ -11,7 +11,29 @@ impl<T> CausalTensor<T>
1111where
1212 T : Clone + Default + PartialOrd + Add < Output = T > + Mul < Output = T > ,
1313{
14- /// Helper to get two operands from the AST children.
14+ /// Extracts two operands from the provided Abstract Syntax Tree (AST) children.
15+ ///
16+ /// This helper function is used to retrieve the left-hand side (LHS) and right-hand side (RHS)
17+ /// `CausalTensor` operands from a slice of `EinSumAST` nodes. It expects exactly two children
18+ /// in the AST slice.
19+ ///
20+ /// # Arguments
21+ ///
22+ /// * `children` - A slice of `EinSumAST<T>` representing the children nodes of an AST operation.
23+ /// Expected to contain exactly two elements, each resolving to a `CausalTensor<T>`.
24+ ///
25+ /// # Returns
26+ ///
27+ /// A `Result` which is:
28+ /// - `Ok((CausalTensor<T>, CausalTensor<T>))` containing the LHS and RHS tensors if successful.
29+ /// - `Err(CausalTensorError)` if the number of children is not two, or if `execute_ein_sum`
30+ /// fails for either child.
31+ ///
32+ /// # Errors
33+ ///
34+ /// Returns `CausalTensorError::EinSumError(EinSumValidationError::InvalidNumberOfChildren)`
35+ /// if `children.len()` is not equal to 2.
36+ /// Returns errors propagated from `CausalTensor::execute_ein_sum`.
1537 pub ( super ) fn get_binary_operands (
1638 children : & [ EinSumAST < T > ] ,
1739 ) -> Result < ( CausalTensor < T > , CausalTensor < T > ) , CausalTensorError > {
2850 Ok ( ( lhs, rhs) )
2951 }
3052
31- /// Helper to get a single operand from the AST children.
53+ /// Extracts a single operand from the provided Abstract Syntax Tree (AST) children.
54+ ///
55+ /// This helper function is used to retrieve a single `CausalTensor` operand from a slice
56+ /// of `EinSumAST` nodes. It expects exactly one child in the AST slice.
57+ ///
58+ /// # Arguments
59+ ///
60+ /// * `children` - A slice of `EinSumAST<T>` representing the children nodes of an AST operation.
61+ /// Expected to contain exactly one element, resolving to a `CausalTensor<T>`.
62+ ///
63+ /// # Returns
64+ ///
65+ /// A `Result` which is:
66+ /// - `Ok(CausalTensor<T>)` containing the single operand tensor if successful.
67+ /// - `Err(CausalTensorError)` if the number of children is not one, or if `execute_ein_sum`
68+ /// fails for the child.
69+ ///
70+ /// # Errors
71+ ///
72+ /// Returns `CausalTensorError::EinSumError(EinSumValidationError::InvalidNumberOfChildren)`
73+ /// if `children.len()` is not equal to 1.
74+ /// Returns errors propagated from `CausalTensor::execute_ein_sum`.
3275 pub ( super ) fn get_unary_operand (
3376 children : & [ EinSumAST < T > ] ,
3477 ) -> Result < CausalTensor < T > , CausalTensorError > {
4386 CausalTensor :: execute_ein_sum ( & children[ 0 ] )
4487 }
4588
46- /// Private method for generic tensor contraction.
47- /// This optimized version uses permutation and reshaping to reduce contraction to matrix multiplication.
89+ /// Performs a generic tensor contraction between two `CausalTensor`s.
90+ ///
91+ /// This method implements an optimized tensor contraction by leveraging permutation and
92+ /// reshaping operations to reduce the problem to a standard 2D matrix multiplication.
93+ /// It identifies common axes between the two tensors and sums over their products.
94+ ///
95+ /// # Arguments
96+ ///
97+ /// * `lhs` - The left-hand side `CausalTensor`.
98+ /// * `rhs` - The right-hand side `CausalTensor`.
99+ /// * `lhs_contract_axes` - A slice of `usize` indicating the axes of `lhs` to contract over.
100+ /// * `rhs_contract_axes` - A slice of `usize` indicating the axes of `rhs` to contract over.
101+ ///
102+ /// # Returns
103+ ///
104+ /// A `Result` which is:
105+ /// - `Ok(CausalTensor<T>)` containing the result of the tensor contraction.
106+ /// - `Err(CausalTensorError)` if validation fails or an underlying operation encounters an error.
107+ ///
108+ /// # Errors
109+ ///
110+ /// Returns `CausalTensorError::EinSumError` if:
111+ /// - The number of `lhs_contract_axes` and `rhs_contract_axes` do not match.
112+ /// - Any specified axis is out of bounds for its respective tensor.
113+ /// - The dimensions of the contracted axes in `lhs` and `rhs` do not match.
114+ /// - Errors are propagated from `permute_axes`, `reshape`, or `mat_mul_2d`.
48115 pub ( super ) fn contract (
49116 lhs : & CausalTensor < T > ,
50117 rhs : & CausalTensor < T > ,
@@ -126,7 +193,28 @@ where
126193 result_matrix. reshape ( & final_shape)
127194 }
128195
129- /// Private helper for 2D matrix multiplication.
196+ /// Performs 2D matrix multiplication between two `CausalTensor`s.
197+ ///
198+ /// This private helper function computes the matrix product of two rank-2 tensors.
199+ /// It validates that both input tensors are indeed 2D and that their inner dimensions
200+ /// are compatible for multiplication (i.e., `lhs.cols == rhs.rows`).
201+ ///
202+ /// # Arguments
203+ ///
204+ /// * `lhs` - The left-hand side `CausalTensor` (matrix).
205+ /// * `rhs` - The right-hand side `CausalTensor` (matrix).
206+ ///
207+ /// # Returns
208+ ///
209+ /// A `Result` which is:
210+ /// - `Ok(CausalTensor<T>)` containing the resulting product matrix.
211+ /// - `Err(CausalTensorError)` if validation fails.
212+ ///
213+ /// # Errors
214+ ///
215+ /// Returns `CausalTensorError::EinSumError` if:
216+ /// - Either `lhs` or `rhs` is not a 2D tensor (rank mismatch).
217+ /// - The inner dimensions of `lhs` and `rhs` do not match (shape mismatch).
130218 pub ( super ) fn mat_mul_2d (
131219 lhs : & CausalTensor < T > ,
132220 rhs : & CausalTensor < T > ,
@@ -179,15 +267,57 @@ where
179267 CausalTensor :: new ( result_data, vec ! [ m, n] )
180268 }
181269
182- /// Private method for element-wise multiplication.
270+ /// Performs element-wise multiplication (Hadamard product) between two `CausalTensor`s.
271+ ///
272+ /// This private method multiplies corresponding elements of two tensors. It leverages
273+ /// the `broadcast_op` method to handle broadcasting rules if the tensor shapes are not identical.
274+ ///
275+ /// # Arguments
276+ ///
277+ /// * `lhs` - The left-hand side `CausalTensor`.
278+ /// * `rhs` - The right-hand side `CausalTensor`.
279+ ///
280+ /// # Returns
281+ ///
282+ /// A `Result` which is:
283+ /// - `Ok(CausalTensor<T>)` containing the result of the element-wise multiplication.
284+ /// - `Err(CausalTensorError)` if broadcasting fails or an element-wise operation encounters an error.
285+ ///
286+ /// # Errors
287+ ///
288+ /// Returns errors propagated from `CausalTensor::broadcast_op`.
183289 pub ( super ) fn element_wise_mul (
184290 lhs : & CausalTensor < T > ,
185291 rhs : & CausalTensor < T > ,
186292 ) -> Result < CausalTensor < T > , CausalTensorError > {
187293 lhs. broadcast_op ( rhs, |a, b| Ok ( a * b) )
188294 }
189295
190- /// Private method for tracing (summing over diagonal axes).
296+ /// Computes the trace of a `CausalTensor` over two specified axes.
297+ ///
298+ /// The trace operation sums the elements along the diagonal of a 2D slice
299+ /// defined by `axis1` and `axis2`. If the tensor has more than two dimensions,
300+ /// this operation is applied to all 2D slices formed by the specified axes,
301+ /// effectively reducing the tensor's rank by two.
302+ ///
303+ /// # Arguments
304+ ///
305+ /// * `tensor` - The input `CausalTensor`.
306+ /// * `axis1` - The first axis to trace over.
307+ /// * `axis2` - The second axis to trace over.
308+ ///
309+ /// # Returns
310+ ///
311+ /// A `Result` which is:
312+ /// - `Ok(CausalTensor<T>)` containing the result of the trace operation.
313+ /// - `Err(CausalTensorError)` if validation fails.
314+ ///
315+ /// # Errors
316+ ///
317+ /// Returns `CausalTensorError::EinSumError` if:
318+ /// - `axis1` or `axis2` are out of bounds for the tensor's dimensions.
319+ /// - `axis1` and `axis2` are identical.
320+ /// - The dimensions of `axis1` and `axis2` are not equal (shape mismatch).
191321 pub ( super ) fn trace (
192322 tensor : & CausalTensor < T > ,
193323 axis1 : usize ,
@@ -259,7 +389,30 @@ where
259389 Ok ( result_tensor)
260390 }
261391
262- /// Private method for extracting a diagonal.
392+ /// Extracts the diagonal elements of a `CausalTensor` over two specified axes.
393+ ///
394+ /// This private method extracts the diagonal elements from 2D slices of the input tensor
395+ /// defined by `axis1` and `axis2`. The resulting tensor will have a rank reduced by one
396+ /// compared to the input tensor, with the new last dimension representing the diagonal.
397+ ///
398+ /// # Arguments
399+ ///
400+ /// * `tensor` - The input `CausalTensor`.
401+ /// * `axis1` - The first axis defining the 2D plane from which to extract the diagonal.
402+ /// * `axis2` - The second axis defining the 2D plane from which to extract the diagonal.
403+ ///
404+ /// # Returns
405+ ///
406+ /// A `Result` which is:
407+ /// - `Ok(CausalTensor<T>)` containing the tensor with extracted diagonal elements.
408+ /// - `Err(CausalTensorError)` if validation fails.
409+ ///
410+ /// # Errors
411+ ///
412+ /// Returns `CausalTensorError::EinSumError` if:
413+ /// - `axis1` or `axis2` are out of bounds for the tensor's dimensions.
414+ /// - `axis1` and `axis2` are identical.
415+ /// - The dimensions of `axis1` and `axis2` are not equal (shape mismatch).
263416 pub ( super ) fn diagonal (
264417 tensor : & CausalTensor < T > ,
265418 axis1 : usize ,
@@ -329,6 +482,29 @@ where
329482 CausalTensor :: new ( result_data, new_shape)
330483 }
331484
485+ /// Performs batch matrix multiplication between two `CausalTensor`s.
486+ ///
487+ /// This method expects both input tensors to have at least 3 dimensions, where the first
488+ /// dimension represents the batch size. It performs a 2D matrix multiplication for each
489+ /// corresponding pair of matrices within the batch and stacks the results.
490+ ///
491+ /// # Arguments
492+ ///
493+ /// * `lhs` - The left-hand side `CausalTensor` with a batch dimension.
494+ /// * `rhs` - The right-hand side `CausalTensor` with a batch dimension.
495+ ///
496+ /// # Returns
497+ ///
498+ /// A `Result` which is:
499+ /// - `Ok(CausalTensor<T>)` containing the result of the batch matrix multiplication.
500+ /// - `Err(CausalTensorError)` if validation fails or an underlying operation encounters an error.
501+ ///
502+ /// # Errors
503+ ///
504+ /// Returns `CausalTensorError::EinSumError` if:
505+ /// - Either `lhs` or `rhs` has fewer than 3 dimensions (rank mismatch).
506+ /// - The batch sizes of `lhs` and `rhs` do not match (shape mismatch).
507+ /// - Errors are propagated from `slice`, `mat_mul_2d`, or `stack`.
332508 pub ( super ) fn batch_mat_mul (
333509 lhs : CausalTensor < T > ,
334510 rhs : CausalTensor < T > ,
0 commit comments