|
| 1 | +/* |
| 2 | + * SPDX-License-Identifier: MIT |
| 3 | + * Copyright (c) "2025" . The DeepCausality Authors and Contributors. All Rights Reserved. |
| 4 | + */ |
| 5 | +use deep_causality_tensor::{CausalTensor, EinSumOp}; |
| 6 | + |
| 7 | +fn main() { |
| 8 | + // Example 1: Matrix Multiplication |
| 9 | + println!("--- Example 1: Matrix Multiplication ---"); |
| 10 | + let lhs_data = vec![1.0, 2.0, 3.0, 4.0]; |
| 11 | + let lhs_shape = vec![2, 2]; |
| 12 | + let lhs_tensor = CausalTensor::new(lhs_data, lhs_shape).unwrap(); |
| 13 | + |
| 14 | + let rhs_data = vec![5.0, 6.0, 7.0, 8.0]; |
| 15 | + let rhs_shape = vec![2, 2]; |
| 16 | + let rhs_tensor = CausalTensor::new(rhs_data, rhs_shape).unwrap(); |
| 17 | + |
| 18 | + println!("LHS Tensor:\n{:?}", lhs_tensor); |
| 19 | + println!("RHS Tensor:\n{:?}", rhs_tensor); |
| 20 | + |
| 21 | + let result_mat_mul = |
| 22 | + CausalTensor::ein_sum(&EinSumOp::mat_mul(lhs_tensor.clone(), rhs_tensor.clone())).unwrap(); |
| 23 | + println!("Result of Matrix Multiplication:\n{:?}", result_mat_mul); |
| 24 | + let expected_mat_mul = CausalTensor::new(vec![19.0, 22.0, 43.0, 50.0], vec![2, 2]).unwrap(); |
| 25 | + assert_eq!(result_mat_mul, expected_mat_mul); |
| 26 | + |
| 27 | + // Example 2: Dot Product |
| 28 | + println!("\n--- Example 2: Dot Product ---"); |
| 29 | + let vec1_data = vec![1.0, 2.0, 3.0]; |
| 30 | + let vec1_shape = vec![3]; |
| 31 | + let vec1_tensor = CausalTensor::new(vec1_data, vec1_shape).unwrap(); |
| 32 | + |
| 33 | + let vec2_data = vec![4.0, 5.0, 6.0]; |
| 34 | + let vec2_shape = vec![3]; |
| 35 | + let vec2_tensor = CausalTensor::new(vec2_data, vec2_shape).unwrap(); |
| 36 | + |
| 37 | + println!("Vector 1:\n{:?}", vec1_tensor); |
| 38 | + println!("Vector 2:\n{:?}", vec2_tensor); |
| 39 | + |
| 40 | + let result_dot_prod = CausalTensor::ein_sum(&EinSumOp::dot_prod( |
| 41 | + vec1_tensor.clone(), |
| 42 | + vec2_tensor.clone(), |
| 43 | + )) |
| 44 | + .unwrap(); |
| 45 | + println!("Result of Dot Product:\n{:?}", result_dot_prod); |
| 46 | + let expected_dot_prod = CausalTensor::new(vec![32.0], vec![]).unwrap(); |
| 47 | + assert_eq!(result_dot_prod, expected_dot_prod); |
| 48 | + |
| 49 | + // Example 3: Trace |
| 50 | + println!("\n--- Example 3: Trace ---"); |
| 51 | + let trace_data = vec![1.0, 2.0, 3.0, 4.0]; |
| 52 | + let trace_shape = vec![2, 2]; |
| 53 | + let trace_tensor = CausalTensor::new(trace_data, trace_shape).unwrap(); |
| 54 | + |
| 55 | + println!("Tensor for Trace:\n{:?}", trace_tensor); |
| 56 | + let result_trace = CausalTensor::ein_sum(&EinSumOp::trace(trace_tensor.clone(), 0, 1)).unwrap(); |
| 57 | + println!("Result of Trace (axes 0, 1):\n{:?}", result_trace); |
| 58 | + let expected_trace = CausalTensor::new(vec![5.0], vec![]).unwrap(); |
| 59 | + assert_eq!(result_trace, expected_trace); |
| 60 | + |
| 61 | + // Example 4: Element-wise Product |
| 62 | + println!("\n--- Example 4: Element-wise Product ---"); |
| 63 | + let ew_lhs_data = vec![1.0, 2.0, 3.0]; |
| 64 | + let ew_lhs_shape = vec![3]; |
| 65 | + let ew_lhs_tensor = CausalTensor::new(ew_lhs_data, ew_lhs_shape).unwrap(); |
| 66 | + |
| 67 | + let ew_rhs_data = vec![4.0, 5.0, 6.0]; |
| 68 | + let ew_rhs_shape = vec![3]; |
| 69 | + let ew_rhs_tensor = CausalTensor::new(ew_rhs_data, ew_rhs_shape).unwrap(); |
| 70 | + |
| 71 | + println!("LHS Tensor for Element-wise Product:\n{:?}", ew_lhs_tensor); |
| 72 | + println!("RHS Tensor for Element-wise Product:\n{:?}", ew_rhs_tensor); |
| 73 | + |
| 74 | + let result_ew_prod = CausalTensor::ein_sum(&EinSumOp::element_wise_product( |
| 75 | + ew_lhs_tensor.clone(), |
| 76 | + ew_rhs_tensor.clone(), |
| 77 | + )) |
| 78 | + .unwrap(); |
| 79 | + println!("Result of Element-wise Product:\n{:?}", result_ew_prod); |
| 80 | + let expected_ew_prod = CausalTensor::new(vec![4.0, 10.0, 18.0], vec![3]).unwrap(); |
| 81 | + assert_eq!(result_ew_prod, expected_ew_prod); |
| 82 | + |
| 83 | + // Example 5: Batch Matrix Multiplication |
| 84 | + println!("\n--- Example 5: Batch Matrix Multiplication ---"); |
| 85 | + // Batch of two 2x2 matrices |
| 86 | + let bmm_lhs_data = vec![ |
| 87 | + 1.0, 2.0, 3.0, 4.0, // First 2x2 matrix |
| 88 | + 5.0, 6.0, 7.0, 8.0, // Second 2x2 matrix |
| 89 | + ]; |
| 90 | + let bmm_lhs_shape = vec![2, 2, 2]; // 2 batches, 2 rows, 2 cols |
| 91 | + let bmm_lhs_tensor = CausalTensor::new(bmm_lhs_data, bmm_lhs_shape).unwrap(); |
| 92 | + |
| 93 | + let bmm_rhs_data = vec![ |
| 94 | + 9.0, 10.0, 11.0, 12.0, // First 2x2 matrix |
| 95 | + 13.0, 14.0, 15.0, 16.0, // Second 2x2 matrix |
| 96 | + ]; |
| 97 | + let bmm_rhs_shape = vec![2, 2, 2]; // 2 batches, 2 rows, 2 cols |
| 98 | + let bmm_rhs_tensor = CausalTensor::new(bmm_rhs_data, bmm_rhs_shape).unwrap(); |
| 99 | + |
| 100 | + println!("LHS Tensor for Batch MatMul:\n{:?}", bmm_lhs_tensor); |
| 101 | + println!("RHS Tensor for Batch MatMul:\n{:?}", bmm_rhs_tensor); |
| 102 | + |
| 103 | + let result_bmm = CausalTensor::ein_sum(&EinSumOp::batch_mat_mul( |
| 104 | + bmm_lhs_tensor.clone(), |
| 105 | + bmm_rhs_tensor.clone(), |
| 106 | + )) |
| 107 | + .unwrap(); |
| 108 | + println!("Result of Batch Matrix Multiplication:\n{:?}", result_bmm); |
| 109 | + let expected_bmm = CausalTensor::new( |
| 110 | + vec![ |
| 111 | + 31.0, 34.0, 71.0, 78.0, // First 2x2 matrix result |
| 112 | + 155.0, 166.0, 211.0, 226.0, // Second 2x2 matrix result |
| 113 | + ], |
| 114 | + vec![2, 2, 2], |
| 115 | + ) |
| 116 | + .unwrap(); |
| 117 | + assert_eq!(result_bmm, expected_bmm); |
| 118 | +} |
0 commit comments