@@ -511,14 +511,22 @@ where
511511 lhs : CausalTensor < T > ,
512512 rhs : CausalTensor < T > ,
513513 ) -> Result < CausalTensor < T > , CausalTensorError > {
514- if lhs. num_dim ( ) < 3 || rhs . num_dim ( ) < 3 {
514+ if lhs. num_dim ( ) < 3 {
515515 return Err ( CausalTensorError :: EinSumError (
516516 EinSumValidationError :: RankMismatch {
517- expected : 3 , // At least 3 dimensions for batch matmul (batch, rows, cols)
517+ expected : 3 ,
518518 found : lhs. num_dim ( ) ,
519519 } ,
520520 ) ) ;
521521 }
522+ if rhs. num_dim ( ) < 3 {
523+ return Err ( CausalTensorError :: EinSumError (
524+ EinSumValidationError :: RankMismatch {
525+ expected : 3 ,
526+ found : rhs. num_dim ( ) ,
527+ } ,
528+ ) ) ;
529+ }
522530
523531 let batch_size = lhs. shape ( ) [ 0 ] ;
524532 if batch_size != rhs. shape ( ) [ 0 ] {
@@ -546,305 +554,3 @@ where
546554 result_batches. stack ( 0 )
547555 }
548556}
549-
550- mod tests {
551- #![ allow( unused_imports) ]
552-
553- use super :: * ;
554- use crate :: { EinSumOp , utils_tests} ;
555-
556- #[ test]
557- fn test_get_binary_operands_success ( ) {
558- let lhs_tensor = utils_tests:: scalar_tensor ( 1.0 ) ;
559- let rhs_tensor = utils_tests:: scalar_tensor ( 2.0 ) ;
560- let lhs_ast = EinSumOp :: tensor_source ( lhs_tensor. clone ( ) ) ;
561- let rhs_ast = EinSumOp :: tensor_source ( rhs_tensor. clone ( ) ) ;
562- let children = vec ! [ lhs_ast, rhs_ast] ;
563-
564- let ( res_lhs, res_rhs) = CausalTensor :: get_binary_operands ( & children) . unwrap ( ) ;
565- assert_eq ! ( res_lhs, lhs_tensor) ;
566- assert_eq ! ( res_rhs, rhs_tensor) ;
567- }
568-
569- #[ test]
570- fn test_get_binary_operands_invalid_children_count ( ) {
571- let lhs_tensor = utils_tests:: scalar_tensor ( 1.0 ) ;
572- let lhs_ast = EinSumOp :: tensor_source ( lhs_tensor. clone ( ) ) ;
573- let children = vec ! [ lhs_ast] ; // Only one child
574-
575- let err = CausalTensor :: get_binary_operands ( & children) . unwrap_err ( ) ;
576- assert ! ( matches!(
577- err,
578- CausalTensorError :: EinSumError ( EinSumValidationError :: InvalidNumberOfChildren {
579- expected: 2 ,
580- found: 1
581- } )
582- ) ) ;
583- }
584-
585- #[ test]
586- fn test_get_unary_operand_success ( ) {
587- let operand_tensor = utils_tests:: scalar_tensor ( 1.0 ) ;
588- let operand_ast = EinSumOp :: tensor_source ( operand_tensor. clone ( ) ) ;
589- let children = vec ! [ operand_ast] ;
590-
591- let res_operand = CausalTensor :: get_unary_operand ( & children) . unwrap ( ) ;
592- assert_eq ! ( res_operand, operand_tensor) ;
593- }
594-
595- #[ test]
596- fn test_get_unary_operand_invalid_children_count ( ) {
597- let lhs_tensor = utils_tests:: scalar_tensor ( 1.0 ) ;
598- let rhs_tensor = utils_tests:: scalar_tensor ( 2.0 ) ;
599- let lhs_ast = EinSumOp :: tensor_source ( lhs_tensor. clone ( ) ) ;
600- let rhs_ast = EinSumOp :: tensor_source ( rhs_tensor. clone ( ) ) ;
601- let children = vec ! [ lhs_ast, rhs_ast] ; // Two children
602-
603- let err = CausalTensor :: get_unary_operand ( & children) . unwrap_err ( ) ;
604- assert ! ( matches!(
605- err,
606- CausalTensorError :: EinSumError ( EinSumValidationError :: InvalidNumberOfChildren {
607- expected: 1 ,
608- found: 2
609- } )
610- ) ) ;
611- }
612-
613- #[ test]
614- fn test_mat_mul_2d_success ( ) {
615- let lhs = utils_tests:: matrix_tensor ( vec ! [ 1.0 , 2.0 , 3.0 , 4.0 ] , 2 , 2 ) ;
616- let rhs = utils_tests:: matrix_tensor ( vec ! [ 5.0 , 6.0 , 7.0 , 8.0 ] , 2 , 2 ) ;
617- let expected = utils_tests:: matrix_tensor ( vec ! [ 19.0 , 22.0 , 43.0 , 50.0 ] , 2 , 2 ) ;
618-
619- let result = CausalTensor :: mat_mul_2d ( & lhs, & rhs) . unwrap ( ) ;
620- assert_eq ! ( result, expected) ;
621- }
622-
623- #[ test]
624- fn test_mat_mul_2d_rank_mismatch ( ) {
625- let lhs = utils_tests:: vector_tensor ( vec ! [ 1.0 , 2.0 ] ) ;
626- let rhs = utils_tests:: matrix_tensor ( vec ! [ 5.0 , 6.0 , 7.0 , 8.0 ] , 2 , 2 ) ;
627-
628- let err = CausalTensor :: mat_mul_2d ( & lhs, & rhs) . unwrap_err ( ) ;
629- assert ! ( matches!(
630- err,
631- CausalTensorError :: EinSumError ( EinSumValidationError :: RankMismatch {
632- expected: 2 ,
633- found: 1
634- } )
635- ) ) ;
636-
637- let lhs = utils_tests:: matrix_tensor ( vec ! [ 1.0 , 2.0 , 3.0 , 4.0 ] , 2 , 2 ) ;
638- let rhs = utils_tests:: vector_tensor ( vec ! [ 5.0 , 6.0 ] ) ;
639-
640- let err = CausalTensor :: mat_mul_2d ( & lhs, & rhs) . unwrap_err ( ) ;
641- assert ! ( matches!(
642- err,
643- CausalTensorError :: EinSumError ( EinSumValidationError :: RankMismatch {
644- expected: 2 ,
645- found: 1
646- } )
647- ) ) ;
648- }
649-
650- #[ test]
651- fn test_mat_mul_2d_shape_mismatch ( ) {
652- let lhs = utils_tests:: matrix_tensor ( vec ! [ 1.0 , 2.0 , 3.0 , 4.0 ] , 2 , 2 ) ;
653- let rhs = utils_tests:: matrix_tensor ( vec ! [ 5.0 , 6.0 , 7.0 , 8.0 , 9.0 , 10.0 ] , 3 , 2 ) ;
654-
655- let err = CausalTensor :: mat_mul_2d ( & lhs, & rhs) . unwrap_err ( ) ;
656- assert ! ( matches!(
657- err,
658- CausalTensorError :: EinSumError ( EinSumValidationError :: ShapeMismatch { message: _ } )
659- ) ) ;
660- }
661-
662- #[ test]
663- fn test_contract_mat_mul_success ( ) {
664- let lhs = utils_tests:: matrix_tensor ( vec ! [ 1.0 , 2.0 , 3.0 , 4.0 ] , 2 , 2 ) ;
665- let rhs = utils_tests:: matrix_tensor ( vec ! [ 5.0 , 6.0 , 7.0 , 8.0 ] , 2 , 2 ) ;
666- let expected = utils_tests:: matrix_tensor ( vec ! [ 19.0 , 22.0 , 43.0 , 50.0 ] , 2 , 2 ) ;
667-
668- let result = CausalTensor :: contract ( & lhs, & rhs, & [ 1 ] , & [ 0 ] ) . unwrap ( ) ;
669- assert_eq ! ( result, expected) ;
670- }
671-
672- #[ test]
673- fn test_contract_dot_prod_success ( ) {
674- let lhs = utils_tests:: vector_tensor ( vec ! [ 1.0 , 2.0 , 3.0 ] ) ;
675- let rhs = utils_tests:: vector_tensor ( vec ! [ 4.0 , 5.0 , 6.0 ] ) ;
676- let expected = utils_tests:: scalar_tensor ( 32.0 ) ; // 1*4 + 2*5 + 3*6 = 4 + 10 + 18 = 32
677-
678- let result = CausalTensor :: contract ( & lhs, & rhs, & [ 0 ] , & [ 0 ] ) . unwrap ( ) ;
679- assert_eq ! ( result, expected) ;
680- }
681-
682- #[ test]
683- fn test_contract_invalid_axes_len ( ) {
684- let lhs = utils_tests:: matrix_tensor ( vec ! [ 1.0 ; 4 ] , 2 , 2 ) ;
685- let rhs = utils_tests:: matrix_tensor ( vec ! [ 1.0 ; 4 ] , 2 , 2 ) ;
686-
687- let err = CausalTensor :: contract ( & lhs, & rhs, & [ 0 , 1 ] , & [ 0 ] ) . unwrap_err ( ) ;
688- assert ! ( matches!(
689- err,
690- CausalTensorError :: EinSumError ( EinSumValidationError :: InvalidAxesSpecification {
691- message: _
692- } )
693- ) ) ;
694- }
695-
696- #[ test]
697- fn test_contract_axis_out_of_bounds ( ) {
698- let lhs = utils_tests:: matrix_tensor ( vec ! [ 1.0 ; 4 ] , 2 , 2 ) ;
699- let rhs = utils_tests:: matrix_tensor ( vec ! [ 1.0 ; 4 ] , 2 , 2 ) ;
700-
701- let err = CausalTensor :: contract ( & lhs, & rhs, & [ 0 ] , & [ 2 ] ) . unwrap_err ( ) ; // RHS axis 2 is out of bounds
702- assert ! ( matches!(
703- err,
704- CausalTensorError :: EinSumError ( EinSumValidationError :: InvalidAxesSpecification {
705- message: _
706- } )
707- ) ) ;
708-
709- let err = CausalTensor :: contract ( & lhs, & rhs, & [ 2 ] , & [ 0 ] ) . unwrap_err ( ) ; // LHS axis 2 is out of bounds
710- assert ! ( matches!(
711- err,
712- CausalTensorError :: EinSumError ( EinSumValidationError :: InvalidAxesSpecification {
713- message: _
714- } )
715- ) ) ;
716- }
717-
718- #[ test]
719- fn test_contract_shape_mismatch ( ) {
720- let lhs = utils_tests:: matrix_tensor ( vec ! [ 1.0 ; 6 ] , 2 , 3 ) ;
721- let rhs = utils_tests:: matrix_tensor ( vec ! [ 1.0 ; 4 ] , 2 , 2 ) ;
722-
723- let err = CausalTensor :: contract ( & lhs, & rhs, & [ 1 ] , & [ 0 ] ) . unwrap_err ( ) ; // LHS dim 1 (3) != RHS dim 0 (2)
724- assert ! ( matches!(
725- err,
726- CausalTensorError :: EinSumError ( EinSumValidationError :: ShapeMismatch { message: _ } )
727- ) ) ;
728- }
729-
730- #[ test]
731- fn test_element_wise_mul_success ( ) {
732- let lhs = utils_tests:: vector_tensor ( vec ! [ 1.0 , 2.0 , 3.0 ] ) ;
733- let rhs = utils_tests:: vector_tensor ( vec ! [ 4.0 , 5.0 , 6.0 ] ) ;
734- let expected = utils_tests:: vector_tensor ( vec ! [ 4.0 , 10.0 , 18.0 ] ) ;
735-
736- let result = CausalTensor :: element_wise_mul ( & lhs, & rhs) . unwrap ( ) ;
737- assert_eq ! ( result, expected) ;
738- }
739-
740- #[ test]
741- fn test_trace_success_matrix ( ) {
742- let operand = utils_tests:: matrix_tensor ( vec ! [ 1.0 , 2.0 , 3.0 , 4.0 ] , 2 , 2 ) ;
743- let expected = utils_tests:: scalar_tensor ( 5.0 ) ; // 1.0 + 4.0
744-
745- let result = CausalTensor :: trace ( & operand, 0 , 1 ) . unwrap ( ) ;
746- assert_eq ! ( result, expected) ;
747- }
748-
749- #[ test]
750- fn test_trace_success_3d_tensor ( ) {
751- let operand =
752- CausalTensor :: new ( vec ! [ 1.0 , 2.0 , 3.0 , 4.0 , 5.0 , 6.0 , 7.0 , 8.0 ] , vec ! [ 2 , 2 , 2 ] ) . unwrap ( ) ;
753- // Trace over axes 1 and 2 (matrices within the batch)
754- // Batch 0: [[1,2],[3,4]] -> 1+4 = 5
755- // Batch 1: [[5,6],[7,8]] -> 5+8 = 13
756- let expected = utils_tests:: vector_tensor ( vec ! [ 5.0 , 13.0 ] ) ;
757-
758- let result = CausalTensor :: trace ( & operand, 1 , 2 ) . unwrap ( ) ;
759- assert_eq ! ( result, expected) ;
760- }
761-
762- #[ test]
763- fn test_trace_invalid_axes_out_of_bounds ( ) {
764- let operand = utils_tests:: matrix_tensor ( vec ! [ 1.0 ; 4 ] , 2 , 2 ) ;
765- let err = CausalTensor :: trace ( & operand, 0 , 2 ) . unwrap_err ( ) ;
766- assert ! ( matches!(
767- err,
768- CausalTensorError :: EinSumError ( EinSumValidationError :: InvalidAxesSpecification {
769- message: _
770- } )
771- ) ) ;
772- }
773-
774- #[ test]
775- fn test_trace_invalid_axes_identical ( ) {
776- let operand = utils_tests:: matrix_tensor ( vec ! [ 1.0 ; 4 ] , 2 , 2 ) ;
777- let err = CausalTensor :: trace ( & operand, 0 , 0 ) . unwrap_err ( ) ;
778- assert ! ( matches!(
779- err,
780- CausalTensorError :: EinSumError ( EinSumValidationError :: InvalidAxesSpecification {
781- message: _
782- } )
783- ) ) ;
784- }
785-
786- #[ test]
787- fn test_trace_shape_mismatch ( ) {
788- let operand = CausalTensor :: new ( vec ! [ 1.0 ; 6 ] , vec ! [ 2 , 3 ] ) . unwrap ( ) ; // 2x3 matrix
789- let err = CausalTensor :: trace ( & operand, 0 , 1 ) . unwrap_err ( ) ;
790- assert ! ( matches!(
791- err,
792- CausalTensorError :: EinSumError ( EinSumValidationError :: ShapeMismatch { message: _ } )
793- ) ) ;
794- }
795-
796- #[ test]
797- fn test_diagonal_success_matrix ( ) {
798- let operand = utils_tests:: matrix_tensor ( vec ! [ 1.0 , 2.0 , 3.0 , 4.0 ] , 2 , 2 ) ;
799- let expected = utils_tests:: vector_tensor ( vec ! [ 1.0 , 4.0 ] ) ;
800-
801- let result = CausalTensor :: diagonal ( & operand, 0 , 1 ) . unwrap ( ) ;
802- assert_eq ! ( result, expected) ;
803- }
804- #[ test]
805- fn test_diagonal_success_3d_tensor ( ) {
806- let operand =
807- CausalTensor :: new ( vec ! [ 1.0 , 2.0 , 3.0 , 4.0 , 5.0 , 6.0 , 7.0 , 8.0 ] , vec ! [ 2 , 2 , 2 ] ) . unwrap ( ) ;
808- // Extract diagonal over axes 1 and 2 (matrices within the batch)
809- // Batch 0: [[1,2],[3,4]] -> [1,4]
810- // Batch 1: [[5,6],[7,8]] -> [5,8]
811- let expected = CausalTensor :: new ( vec ! [ 1.0 , 4.0 , 5.0 , 8.0 ] , vec ! [ 2 , 2 ] ) . unwrap ( ) ;
812-
813- let result = CausalTensor :: diagonal ( & operand, 1 , 2 ) . unwrap ( ) ;
814- assert_eq ! ( result, expected) ;
815- }
816-
817- #[ test]
818- fn test_diagonal_invalid_axes_out_of_bounds ( ) {
819- let operand = utils_tests:: matrix_tensor ( vec ! [ 1.0 ; 4 ] , 2 , 2 ) ;
820- let err = CausalTensor :: diagonal ( & operand, 0 , 2 ) . unwrap_err ( ) ;
821- assert ! ( matches!(
822- err,
823- CausalTensorError :: EinSumError ( EinSumValidationError :: InvalidAxesSpecification {
824- message: _
825- } )
826- ) ) ;
827- }
828-
829- #[ test]
830- fn test_diagonal_invalid_axes_identical ( ) {
831- let operand = utils_tests:: matrix_tensor ( vec ! [ 1.0 ; 4 ] , 2 , 2 ) ;
832- let err = CausalTensor :: diagonal ( & operand, 0 , 0 ) . unwrap_err ( ) ;
833- assert ! ( matches!(
834- err,
835- CausalTensorError :: EinSumError ( EinSumValidationError :: InvalidAxesSpecification {
836- message: _
837- } )
838- ) ) ;
839- }
840-
841- #[ test]
842- fn test_diagonal_shape_mismatch ( ) {
843- let operand = CausalTensor :: new ( vec ! [ 1.0 ; 6 ] , vec ! [ 2 , 3 ] ) . unwrap ( ) ; // 2x3 matrix
844- let err = CausalTensor :: diagonal ( & operand, 0 , 1 ) . unwrap_err ( ) ;
845- assert ! ( matches!(
846- err,
847- CausalTensorError :: EinSumError ( EinSumValidationError :: ShapeMismatch { message: _ } )
848- ) ) ;
849- }
850- }
0 commit comments