Skip to content

Commit 038de90

Browse files
committed
Added new test cases to deep_causality_tensor/src/types/a
usal_tensor/op_tensor_ein_sum/ein_sum_impl_tests.rs to cover error scenarios for the batch_mat_mul function, specifically for rank mismatch and batch size mismatch. Signed-off-by: Marvin Hansen <[email protected]>
1 parent cdb8f94 commit 038de90

File tree

4 files changed

+360
-304
lines changed

4 files changed

+360
-304
lines changed

deep_causality_tensor/src/types/causal_tensor/op_tensor_ein_sum/ein_sum_impl.rs

Lines changed: 10 additions & 304 deletions
Original file line numberDiff line numberDiff line change
@@ -511,14 +511,22 @@ where
511511
lhs: CausalTensor<T>,
512512
rhs: CausalTensor<T>,
513513
) -> Result<CausalTensor<T>, CausalTensorError> {
514-
if lhs.num_dim() < 3 || rhs.num_dim() < 3 {
514+
if lhs.num_dim() < 3 {
515515
return Err(CausalTensorError::EinSumError(
516516
EinSumValidationError::RankMismatch {
517-
expected: 3, // At least 3 dimensions for batch matmul (batch, rows, cols)
517+
expected: 3,
518518
found: lhs.num_dim(),
519519
},
520520
));
521521
}
522+
if rhs.num_dim() < 3 {
523+
return Err(CausalTensorError::EinSumError(
524+
EinSumValidationError::RankMismatch {
525+
expected: 3,
526+
found: rhs.num_dim(),
527+
},
528+
));
529+
}
522530

523531
let batch_size = lhs.shape()[0];
524532
if batch_size != rhs.shape()[0] {
@@ -546,305 +554,3 @@ where
546554
result_batches.stack(0)
547555
}
548556
}
549-
550-
mod tests {
551-
#![allow(unused_imports)]
552-
553-
use super::*;
554-
use crate::{EinSumOp, utils_tests};
555-
556-
#[test]
557-
fn test_get_binary_operands_success() {
558-
let lhs_tensor = utils_tests::scalar_tensor(1.0);
559-
let rhs_tensor = utils_tests::scalar_tensor(2.0);
560-
let lhs_ast = EinSumOp::tensor_source(lhs_tensor.clone());
561-
let rhs_ast = EinSumOp::tensor_source(rhs_tensor.clone());
562-
let children = vec![lhs_ast, rhs_ast];
563-
564-
let (res_lhs, res_rhs) = CausalTensor::get_binary_operands(&children).unwrap();
565-
assert_eq!(res_lhs, lhs_tensor);
566-
assert_eq!(res_rhs, rhs_tensor);
567-
}
568-
569-
#[test]
570-
fn test_get_binary_operands_invalid_children_count() {
571-
let lhs_tensor = utils_tests::scalar_tensor(1.0);
572-
let lhs_ast = EinSumOp::tensor_source(lhs_tensor.clone());
573-
let children = vec![lhs_ast]; // Only one child
574-
575-
let err = CausalTensor::get_binary_operands(&children).unwrap_err();
576-
assert!(matches!(
577-
err,
578-
CausalTensorError::EinSumError(EinSumValidationError::InvalidNumberOfChildren {
579-
expected: 2,
580-
found: 1
581-
})
582-
));
583-
}
584-
585-
#[test]
586-
fn test_get_unary_operand_success() {
587-
let operand_tensor = utils_tests::scalar_tensor(1.0);
588-
let operand_ast = EinSumOp::tensor_source(operand_tensor.clone());
589-
let children = vec![operand_ast];
590-
591-
let res_operand = CausalTensor::get_unary_operand(&children).unwrap();
592-
assert_eq!(res_operand, operand_tensor);
593-
}
594-
595-
#[test]
596-
fn test_get_unary_operand_invalid_children_count() {
597-
let lhs_tensor = utils_tests::scalar_tensor(1.0);
598-
let rhs_tensor = utils_tests::scalar_tensor(2.0);
599-
let lhs_ast = EinSumOp::tensor_source(lhs_tensor.clone());
600-
let rhs_ast = EinSumOp::tensor_source(rhs_tensor.clone());
601-
let children = vec![lhs_ast, rhs_ast]; // Two children
602-
603-
let err = CausalTensor::get_unary_operand(&children).unwrap_err();
604-
assert!(matches!(
605-
err,
606-
CausalTensorError::EinSumError(EinSumValidationError::InvalidNumberOfChildren {
607-
expected: 1,
608-
found: 2
609-
})
610-
));
611-
}
612-
613-
#[test]
614-
fn test_mat_mul_2d_success() {
615-
let lhs = utils_tests::matrix_tensor(vec![1.0, 2.0, 3.0, 4.0], 2, 2);
616-
let rhs = utils_tests::matrix_tensor(vec![5.0, 6.0, 7.0, 8.0], 2, 2);
617-
let expected = utils_tests::matrix_tensor(vec![19.0, 22.0, 43.0, 50.0], 2, 2);
618-
619-
let result = CausalTensor::mat_mul_2d(&lhs, &rhs).unwrap();
620-
assert_eq!(result, expected);
621-
}
622-
623-
#[test]
624-
fn test_mat_mul_2d_rank_mismatch() {
625-
let lhs = utils_tests::vector_tensor(vec![1.0, 2.0]);
626-
let rhs = utils_tests::matrix_tensor(vec![5.0, 6.0, 7.0, 8.0], 2, 2);
627-
628-
let err = CausalTensor::mat_mul_2d(&lhs, &rhs).unwrap_err();
629-
assert!(matches!(
630-
err,
631-
CausalTensorError::EinSumError(EinSumValidationError::RankMismatch {
632-
expected: 2,
633-
found: 1
634-
})
635-
));
636-
637-
let lhs = utils_tests::matrix_tensor(vec![1.0, 2.0, 3.0, 4.0], 2, 2);
638-
let rhs = utils_tests::vector_tensor(vec![5.0, 6.0]);
639-
640-
let err = CausalTensor::mat_mul_2d(&lhs, &rhs).unwrap_err();
641-
assert!(matches!(
642-
err,
643-
CausalTensorError::EinSumError(EinSumValidationError::RankMismatch {
644-
expected: 2,
645-
found: 1
646-
})
647-
));
648-
}
649-
650-
#[test]
651-
fn test_mat_mul_2d_shape_mismatch() {
652-
let lhs = utils_tests::matrix_tensor(vec![1.0, 2.0, 3.0, 4.0], 2, 2);
653-
let rhs = utils_tests::matrix_tensor(vec![5.0, 6.0, 7.0, 8.0, 9.0, 10.0], 3, 2);
654-
655-
let err = CausalTensor::mat_mul_2d(&lhs, &rhs).unwrap_err();
656-
assert!(matches!(
657-
err,
658-
CausalTensorError::EinSumError(EinSumValidationError::ShapeMismatch { message: _ })
659-
));
660-
}
661-
662-
#[test]
663-
fn test_contract_mat_mul_success() {
664-
let lhs = utils_tests::matrix_tensor(vec![1.0, 2.0, 3.0, 4.0], 2, 2);
665-
let rhs = utils_tests::matrix_tensor(vec![5.0, 6.0, 7.0, 8.0], 2, 2);
666-
let expected = utils_tests::matrix_tensor(vec![19.0, 22.0, 43.0, 50.0], 2, 2);
667-
668-
let result = CausalTensor::contract(&lhs, &rhs, &[1], &[0]).unwrap();
669-
assert_eq!(result, expected);
670-
}
671-
672-
#[test]
673-
fn test_contract_dot_prod_success() {
674-
let lhs = utils_tests::vector_tensor(vec![1.0, 2.0, 3.0]);
675-
let rhs = utils_tests::vector_tensor(vec![4.0, 5.0, 6.0]);
676-
let expected = utils_tests::scalar_tensor(32.0); // 1*4 + 2*5 + 3*6 = 4 + 10 + 18 = 32
677-
678-
let result = CausalTensor::contract(&lhs, &rhs, &[0], &[0]).unwrap();
679-
assert_eq!(result, expected);
680-
}
681-
682-
#[test]
683-
fn test_contract_invalid_axes_len() {
684-
let lhs = utils_tests::matrix_tensor(vec![1.0; 4], 2, 2);
685-
let rhs = utils_tests::matrix_tensor(vec![1.0; 4], 2, 2);
686-
687-
let err = CausalTensor::contract(&lhs, &rhs, &[0, 1], &[0]).unwrap_err();
688-
assert!(matches!(
689-
err,
690-
CausalTensorError::EinSumError(EinSumValidationError::InvalidAxesSpecification {
691-
message: _
692-
})
693-
));
694-
}
695-
696-
#[test]
697-
fn test_contract_axis_out_of_bounds() {
698-
let lhs = utils_tests::matrix_tensor(vec![1.0; 4], 2, 2);
699-
let rhs = utils_tests::matrix_tensor(vec![1.0; 4], 2, 2);
700-
701-
let err = CausalTensor::contract(&lhs, &rhs, &[0], &[2]).unwrap_err(); // RHS axis 2 is out of bounds
702-
assert!(matches!(
703-
err,
704-
CausalTensorError::EinSumError(EinSumValidationError::InvalidAxesSpecification {
705-
message: _
706-
})
707-
));
708-
709-
let err = CausalTensor::contract(&lhs, &rhs, &[2], &[0]).unwrap_err(); // LHS axis 2 is out of bounds
710-
assert!(matches!(
711-
err,
712-
CausalTensorError::EinSumError(EinSumValidationError::InvalidAxesSpecification {
713-
message: _
714-
})
715-
));
716-
}
717-
718-
#[test]
719-
fn test_contract_shape_mismatch() {
720-
let lhs = utils_tests::matrix_tensor(vec![1.0; 6], 2, 3);
721-
let rhs = utils_tests::matrix_tensor(vec![1.0; 4], 2, 2);
722-
723-
let err = CausalTensor::contract(&lhs, &rhs, &[1], &[0]).unwrap_err(); // LHS dim 1 (3) != RHS dim 0 (2)
724-
assert!(matches!(
725-
err,
726-
CausalTensorError::EinSumError(EinSumValidationError::ShapeMismatch { message: _ })
727-
));
728-
}
729-
730-
#[test]
731-
fn test_element_wise_mul_success() {
732-
let lhs = utils_tests::vector_tensor(vec![1.0, 2.0, 3.0]);
733-
let rhs = utils_tests::vector_tensor(vec![4.0, 5.0, 6.0]);
734-
let expected = utils_tests::vector_tensor(vec![4.0, 10.0, 18.0]);
735-
736-
let result = CausalTensor::element_wise_mul(&lhs, &rhs).unwrap();
737-
assert_eq!(result, expected);
738-
}
739-
740-
#[test]
741-
fn test_trace_success_matrix() {
742-
let operand = utils_tests::matrix_tensor(vec![1.0, 2.0, 3.0, 4.0], 2, 2);
743-
let expected = utils_tests::scalar_tensor(5.0); // 1.0 + 4.0
744-
745-
let result = CausalTensor::trace(&operand, 0, 1).unwrap();
746-
assert_eq!(result, expected);
747-
}
748-
749-
#[test]
750-
fn test_trace_success_3d_tensor() {
751-
let operand =
752-
CausalTensor::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], vec![2, 2, 2]).unwrap();
753-
// Trace over axes 1 and 2 (matrices within the batch)
754-
// Batch 0: [[1,2],[3,4]] -> 1+4 = 5
755-
// Batch 1: [[5,6],[7,8]] -> 5+8 = 13
756-
let expected = utils_tests::vector_tensor(vec![5.0, 13.0]);
757-
758-
let result = CausalTensor::trace(&operand, 1, 2).unwrap();
759-
assert_eq!(result, expected);
760-
}
761-
762-
#[test]
763-
fn test_trace_invalid_axes_out_of_bounds() {
764-
let operand = utils_tests::matrix_tensor(vec![1.0; 4], 2, 2);
765-
let err = CausalTensor::trace(&operand, 0, 2).unwrap_err();
766-
assert!(matches!(
767-
err,
768-
CausalTensorError::EinSumError(EinSumValidationError::InvalidAxesSpecification {
769-
message: _
770-
})
771-
));
772-
}
773-
774-
#[test]
775-
fn test_trace_invalid_axes_identical() {
776-
let operand = utils_tests::matrix_tensor(vec![1.0; 4], 2, 2);
777-
let err = CausalTensor::trace(&operand, 0, 0).unwrap_err();
778-
assert!(matches!(
779-
err,
780-
CausalTensorError::EinSumError(EinSumValidationError::InvalidAxesSpecification {
781-
message: _
782-
})
783-
));
784-
}
785-
786-
#[test]
787-
fn test_trace_shape_mismatch() {
788-
let operand = CausalTensor::new(vec![1.0; 6], vec![2, 3]).unwrap(); // 2x3 matrix
789-
let err = CausalTensor::trace(&operand, 0, 1).unwrap_err();
790-
assert!(matches!(
791-
err,
792-
CausalTensorError::EinSumError(EinSumValidationError::ShapeMismatch { message: _ })
793-
));
794-
}
795-
796-
#[test]
797-
fn test_diagonal_success_matrix() {
798-
let operand = utils_tests::matrix_tensor(vec![1.0, 2.0, 3.0, 4.0], 2, 2);
799-
let expected = utils_tests::vector_tensor(vec![1.0, 4.0]);
800-
801-
let result = CausalTensor::diagonal(&operand, 0, 1).unwrap();
802-
assert_eq!(result, expected);
803-
}
804-
#[test]
805-
fn test_diagonal_success_3d_tensor() {
806-
let operand =
807-
CausalTensor::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], vec![2, 2, 2]).unwrap();
808-
// Extract diagonal over axes 1 and 2 (matrices within the batch)
809-
// Batch 0: [[1,2],[3,4]] -> [1,4]
810-
// Batch 1: [[5,6],[7,8]] -> [5,8]
811-
let expected = CausalTensor::new(vec![1.0, 4.0, 5.0, 8.0], vec![2, 2]).unwrap();
812-
813-
let result = CausalTensor::diagonal(&operand, 1, 2).unwrap();
814-
assert_eq!(result, expected);
815-
}
816-
817-
#[test]
818-
fn test_diagonal_invalid_axes_out_of_bounds() {
819-
let operand = utils_tests::matrix_tensor(vec![1.0; 4], 2, 2);
820-
let err = CausalTensor::diagonal(&operand, 0, 2).unwrap_err();
821-
assert!(matches!(
822-
err,
823-
CausalTensorError::EinSumError(EinSumValidationError::InvalidAxesSpecification {
824-
message: _
825-
})
826-
));
827-
}
828-
829-
#[test]
830-
fn test_diagonal_invalid_axes_identical() {
831-
let operand = utils_tests::matrix_tensor(vec![1.0; 4], 2, 2);
832-
let err = CausalTensor::diagonal(&operand, 0, 0).unwrap_err();
833-
assert!(matches!(
834-
err,
835-
CausalTensorError::EinSumError(EinSumValidationError::InvalidAxesSpecification {
836-
message: _
837-
})
838-
));
839-
}
840-
841-
#[test]
842-
fn test_diagonal_shape_mismatch() {
843-
let operand = CausalTensor::new(vec![1.0; 6], vec![2, 3]).unwrap(); // 2x3 matrix
844-
let err = CausalTensor::diagonal(&operand, 0, 1).unwrap_err();
845-
assert!(matches!(
846-
err,
847-
CausalTensorError::EinSumError(EinSumValidationError::ShapeMismatch { message: _ })
848-
));
849-
}
850-
}

0 commit comments

Comments
 (0)