Skip to content

Commit 7b5d72c

Browse files
committed
feat(deep_causality_tensor):
Refactored code organization and improved documentation of public API. Signed-off-by: Marvin Hansen <[email protected]>
1 parent f8cee3c commit 7b5d72c

File tree

6 files changed

+69
-7
lines changed

6 files changed

+69
-7
lines changed
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
/*
2+
* SPDX-License-Identifier: MIT
3+
* Copyright (c) "2025" . The DeepCausality Authors and Contributors. All Rights Reserved.
4+
*/
5+
use crate::{CausalTensor, CausalTensorError};
6+
use std::ops::Mul;
7+
8+
impl<T> CausalTensor<T>
9+
where
10+
T: Clone + Default + PartialOrd + Mul<Output = T>,
11+
{
12+
/// Computes the tensor product (also known as the outer product) of two `CausalTensor`s.
13+
///
14+
/// The tensor product combines two tensors into a new tensor whose rank is the sum of
15+
/// the ranks of the input tensors, and whose shape is the concatenation of their shapes.
16+
/// Each element of the resulting tensor is the product of an element from the left-hand side
17+
/// tensor and an element from the right-hand side tensor.
18+
///
19+
/// This operation is fundamental in linear algebra and tensor calculus, effectively
20+
/// creating all possible pairwise products between elements of the input tensors.
21+
///
22+
/// # Arguments
23+
///
24+
/// * `rhs` - The right-hand side `CausalTensor`.
25+
///
26+
/// # Returns
27+
///
28+
/// A `Result` which is:
29+
/// - `Ok(CausalTensor<T>)` containing the result of the tensor product.
30+
/// - `Err(CausalTensorError)` if an error occurs during the operation (e.g., memory allocation).
31+
///
32+
/// # Errors
33+
///
34+
/// This method can return `CausalTensorError` if the underlying `tensor_product_impl`
35+
/// encounters an issue, such as a failure during new tensor creation.
36+
///
37+
/// # Examples
38+
///
39+
/// ```
40+
/// use deep_causality_tensor::CausalTensor;
41+
///
42+
/// let lhs = CausalTensor::new(vec![1.0, 2.0], vec![2]).unwrap(); // Shape [2]
43+
/// let rhs = CausalTensor::new(vec![3.0, 4.0, 5.0], vec![3]).unwrap(); // Shape [3]
44+
///
45+
/// // Expected result:
46+
/// // [[1*3, 1*4, 1*5],
47+
/// // [2*3, 2*4, 2*5]]
48+
/// // which is [[3.0, 4.0, 5.0], [6.0, 8.0, 10.0]] with shape [2, 3]
49+
/// let result = lhs.tensor_product(&rhs).unwrap();
50+
///
51+
/// assert_eq!(result.shape(), &[2, 3]);
52+
/// assert_eq!(result.as_slice(), &[3.0, 4.0, 5.0, 6.0, 8.0, 10.0]);
53+
///
54+
/// // Tensor product with a scalar
55+
/// let scalar = CausalTensor::new(vec![10.0], vec![]).unwrap(); // Shape []
56+
/// let vector = CausalTensor::new(vec![1.0, 2.0], vec![2]).unwrap(); // Shape [2]
57+
/// let result_scalar_vec = scalar.tensor_product(&vector).unwrap();
58+
/// assert_eq!(result_scalar_vec.shape(), &[2]);
59+
/// assert_eq!(result_scalar_vec.as_slice(), &[10.0, 20.0]);
60+
/// ```
61+
pub fn tensor_product(
62+
&self,
63+
rhs: &CausalTensor<T>,
64+
) -> Result<CausalTensor<T>, CausalTensorError> {
65+
self.tensor_product_impl(rhs)
66+
}
67+
}

deep_causality_tensor/src/types/causal_tensor/api/api_tensor_reduction.rs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,6 @@ impl<T> CausalTensor<T>
99
where
1010
T: Clone + Default + PartialOrd,
1111
{
12-
// --- Reduction Operations ---
13-
1412
/// Sums the elements along one or more specified axes.
1513
///
1614
/// The dimensions corresponding to the `axes` provided will be removed from the

deep_causality_tensor/src/types/causal_tensor/api/api_tensor_shape.rs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,6 @@ impl<T> CausalTensor<T>
88
where
99
T: Clone + Default + PartialOrd,
1010
{
11-
// --- Shape Manipulation ---
12-
1311
/// Returns a new tensor with the same data but a different shape.
1412
///
1513
/// This is a metadata-only operation; it creates a new `CausalTensor` with a cloned copy

deep_causality_tensor/src/types/causal_tensor/api/api_tensor_view.rs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,6 @@ impl<T> CausalTensor<T>
88
where
99
T: Clone + Default + PartialOrd,
1010
{
11-
// --- View Operations ---
12-
1311
/// Creates a new tensor representing a slice of the original tensor along a specified axis.
1412
///
1513
/// This operation extracts a sub-tensor where one dimension has been fixed to a specific index.

deep_causality_tensor/src/types/causal_tensor/api/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
* Copyright (c) "2025" . The DeepCausality Authors and Contributors. All Rights Reserved.
44
*/
55
mod api_ein_sum;
6+
mod api_tensor_product;
67
mod api_tensor_reduction;
78
mod api_tensor_shape;
89
mod api_tensor_view;

deep_causality_tensor/src/types/causal_tensor/op_tensor_product/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ impl<T> CausalTensor<T>
99
where
1010
T: Clone + Default + PartialOrd + Mul<Output = T>,
1111
{
12-
pub fn tensor_product(
12+
pub(super) fn tensor_product_impl(
1313
&self,
1414
rhs: &CausalTensor<T>,
1515
) -> Result<CausalTensor<T>, CausalTensorError> {

0 commit comments

Comments
 (0)