11from __future__ import annotations
22
3+ from typing import TYPE_CHECKING , Optional , cast
4+
5+ from rich .tree import Tree
6+
7+ from codeflash .cli_cmds .console import DEBUG_MODE , logger
8+
9+ if TYPE_CHECKING :
10+ from collections .abc import Iterator
311import enum
412import json
513import re
14+ import sys
615from collections .abc import Collection , Iterator
716from enum import Enum , IntEnum
817from pathlib import Path
918from re import Pattern
10- from typing import Annotated , Any , Optional , Union
19+ from typing import Annotated , Any , Optional , Union , cast
1120
1221import sentry_sdk
1322from coverage .exceptions import NoDataError
2332 generate_candidates ,
2433)
2534from codeflash .code_utils .env_utils import is_end_to_end
26- from codeflash .verification .test_results import TestResults , TestType
35+ from codeflash .verification .comparator import comparator
2736
2837# If the method spam is in the class Ham, which is at the top level of the module eggs in the package foo, the fully
2938# qualified name of the method is foo.eggs.Ham.spam, its qualified name is Ham.spam, and its name is spam. The full name
@@ -511,3 +520,236 @@ class FunctionCoverage:
class TestingMode(enum.Enum):
    """Testing modes, each identified by a lowercase string value."""

    BEHAVIOR = "behavior"
    PERFORMANCE = "performance"
523+
524+
class VerificationType(str, Enum):
    """Kinds of correctness verification; every member is also a plain ``str``."""

    # Correctness verification for a test function; checks input values and output values.
    FUNCTION_CALL = "function_call"
    # Correctness verification for fto class instance attributes after init.
    INIT_STATE_FTO = "init_state_fto"
    # Correctness verification for helper class instance attributes after init.
    INIT_STATE_HELPER = "init_state_helper"
531+
532+
class TestType(Enum):
    """Categories of tests, numbered 1 through 6."""

    EXISTING_UNIT_TEST = 1
    INSPIRED_REGRESSION = 2
    GENERATED_REGRESSION = 3
    REPLAY_TEST = 4
    CONCOLIC_COVERAGE_TEST = 5
    INIT_STATE_TEST = 6

    def to_name(self) -> str:
        """Return the display label for this test type.

        ``INIT_STATE_TEST`` has no label, so it yields an empty string.
        """
        labels = {
            TestType.EXISTING_UNIT_TEST: "⚙️ Existing Unit Tests",
            TestType.INSPIRED_REGRESSION: "🎨 Inspired Regression Tests",
            TestType.GENERATED_REGRESSION: "🌀 Generated Regression Tests",
            TestType.REPLAY_TEST: "⏪ Replay Tests",
            TestType.CONCOLIC_COVERAGE_TEST: "🔎 Concolic Coverage Tests",
        }
        # Only INIT_STATE_TEST is absent from the table, so the default covers exactly that member.
        return labels.get(self, "")
552+
553+
@dataclass(frozen=True)
class InvocationId:
    """Immutable identifier of one invocation of a tested function inside a test.

    The string form produced by ``id()`` is
    ``test_module_path:TestSuiteClass.test_function_name:function_tested:iteration_id``.
    """

    test_module_path: str  # The fully qualified name of the test module
    test_class_name: Optional[str]  # The name of the class where the test is defined
    test_function_name: Optional[str]  # The name of the test_function. Does not include the components of the file_name
    function_getting_tested: str
    iteration_id: Optional[str]

    def id(self) -> str:
        """Return the colon-separated string form of this identifier.

        The class name and its trailing dot are omitted when ``test_class_name``
        is empty or ``None``.
        """
        if self.test_class_name:
            qualified_test = f"{self.test_class_name}.{self.test_function_name}"
        else:
            qualified_test = f"{self.test_function_name}"
        return f"{self.test_module_path}:{qualified_test}:{self.function_getting_tested}:{self.iteration_id}"

    @staticmethod
    def from_str_id(string_id: str, iteration_id: Optional[str] = None) -> InvocationId:
        """Parse a string produced by ``id()`` back into an ``InvocationId``.

        :param string_id: Text with exactly four colon-separated components.
        :param iteration_id: When truthy, overrides the fourth component of ``string_id``.
        :return: The parsed identifier.
        :raises AssertionError: If ``string_id`` does not split into four components.
        """
        components = string_id.split(":")
        assert len(components) == 4
        # The second component is either "test_function" or "TestClass.test_function".
        test_path = components[1].split(".")
        has_class = len(test_path) != 1
        return InvocationId(
            test_module_path=components[0],
            test_class_name=test_path[0] if has_class else None,
            test_function_name=test_path[1] if has_class else test_path[0],
            function_getting_tested=components[2],
            iteration_id=iteration_id or components[3],
        )
588+
589+
@dataclass(frozen=True)
class FunctionTestInvocation:
    """Immutable record of a single function invocation captured from a test."""

    loop_index: int  # The loop index of the function invocation, starts at 1
    id: InvocationId  # The fully qualified name of the function invocation (id)
    file_name: Path  # The file where the test is defined
    did_pass: bool  # Whether the test this function invocation was part of, passed or failed
    runtime: Optional[int]  # Time in nanoseconds
    test_framework: str  # unittest or pytest
    test_type: TestType
    return_value: Optional[object]  # The return value of the function invocation
    timed_out: Optional[bool]
    verification_type: Optional[str] = VerificationType.FUNCTION_CALL
    stdout: Optional[str] = None

    @property
    def unique_invocation_loop_id(self) -> str:
        """Return ``"<loop_index>:<invocation id string>"`` for this record."""
        invocation_key = self.id.id()
        return f"{self.loop_index}:{invocation_key}"
607+
608+
class TestResults(BaseModel):
    """Ordered collection of ``FunctionTestInvocation`` records with an id index.

    ``test_result_idx`` maps each record's ``unique_invocation_loop_id`` to its
    position in ``test_results``.
    """

    # don't modify these directly, use the add method
    # also we don't support deletion of test results elements - caution is advised
    # NOTE(review): BaseModel is imported outside this view; presumably pydantic's, which gives each
    # instance its own copy of these mutable defaults - confirm.
    test_results: list[FunctionTestInvocation] = []
    test_result_idx: dict[str, int] = {}

    def add(self, function_test_invocation: FunctionTestInvocation) -> None:
        """Append a record unless one with the same unique invocation loop id exists.

        Duplicates are skipped; a warning is logged only when ``DEBUG_MODE`` is set.
        """
        unique_id = function_test_invocation.unique_invocation_loop_id
        if unique_id in self.test_result_idx:
            if DEBUG_MODE:
                logger.warning(f"Test result with id {unique_id} already exists. SKIPPING")
            return
        self.test_result_idx[unique_id] = len(self.test_results)
        self.test_results.append(function_test_invocation)

    def merge(self, other: TestResults) -> None:
        """Append every record of ``other`` to this collection.

        :raises ValueError: If any id of ``other`` is already present here. The
            check runs before anything is modified, so a failed merge leaves
            this collection unchanged.
        """
        for key in other.test_result_idx:
            if key in self.test_result_idx:
                msg = f"Test result with id {key} already exists."
                raise ValueError(msg)
        # Positions recorded by ``other`` are shifted past the records already held here.
        offset = len(self.test_results)
        self.test_results.extend(other.test_results)
        for key, position in other.test_result_idx.items():
            self.test_result_idx[key] = position + offset

    def get_by_unique_invocation_loop_id(self, unique_invocation_loop_id: str) -> FunctionTestInvocation | None:
        """Return the record stored under the given id, or ``None`` when there is none."""
        try:
            return self.test_results[self.test_result_idx[unique_invocation_loop_id]]
        except (IndexError, KeyError):
            return None

    def get_all_ids(self) -> set[InvocationId]:
        """Return the set of invocation ids of all records."""
        return {test_result.id for test_result in self.test_results}

    def get_all_unique_invocation_loop_ids(self) -> set[str]:
        """Return the set of unique invocation loop ids of all records."""
        return {test_result.unique_invocation_loop_id for test_result in self.test_results}

    def number_of_loops(self) -> int:
        """Return the highest ``loop_index`` among the records, or 0 when empty."""
        return max((test_result.loop_index for test_result in self.test_results), default=0)

    def get_test_pass_fail_report_by_type(self) -> dict[TestType, dict[str, int]]:
        """Count passed and failed records per test type.

        Only records with ``loop_index == 1`` are counted. Every ``TestType``
        member gets an entry, even when its counts are zero.
        """
        report = {test_type: {"passed": 0, "failed": 0} for test_type in TestType}
        for test_result in self.test_results:
            if test_result.loop_index != 1:
                continue
            outcome = "passed" if test_result.did_pass else "failed"
            report[test_result.test_type][outcome] += 1
        return report

    @staticmethod
    def report_to_string(report: dict[TestType, dict[str, int]]) -> str:
        """Render a pass/fail report as a single space-separated line.

        Every ``TestType`` member is included, also ``INIT_STATE_TEST`` whose
        display name is empty.
        """
        return " ".join(
            [
                f"{test_type.to_name()} - (Passed: {report[test_type]['passed']}, Failed: {report[test_type]['failed']})"
                for test_type in TestType
            ]
        )

    @staticmethod
    def report_to_tree(report: dict[TestType, dict[str, int]], title: str) -> Tree:
        """Render a pass/fail report as a ``Tree`` titled ``title``, skipping ``INIT_STATE_TEST``."""
        tree = Tree(title)
        for test_type in TestType:
            if test_type is TestType.INIT_STATE_TEST:
                continue
            tree.add(
                f"{test_type.to_name()} - Passed: {report[test_type]['passed']}, Failed: {report[test_type]['failed']}"
            )
        return tree

    def usable_runtime_data_by_test_case(self) -> dict[InvocationId, list[int]]:
        """Group the runtimes of passing records by invocation id.

        Records that passed but have a falsy runtime (``None`` or 0) are left
        out and reported through ``logger.debug``. Failed records are ignored.
        """
        runtimes_by_id: dict[InvocationId, list[int]] = {}
        # Single pass: group directly instead of rescanning the whole list once per id.
        for result in self.test_results:
            if not result.did_pass:
                continue
            if result.runtime:
                runtimes_by_id.setdefault(result.id, []).append(result.runtime)
            else:
                msg = (
                    f"Ignoring test case that passed but had no runtime -> {result.id}, "
                    f"Loop # {result.loop_index}, Test Type: {result.test_type}, "
                    f"Verification Type: {result.verification_type}"
                )
                logger.debug(msg)
        return runtimes_by_id

    def total_passed_runtime(self) -> int:
        """Calculate the sum of runtimes of all test cases that passed.

        A testcase runtime is the minimum value of all looped execution runtimes.

        :return: The runtime in nanoseconds.
        """
        return sum(min(runtimes) for runtimes in self.usable_runtime_data_by_test_case().values())

    def __iter__(self) -> Iterator[FunctionTestInvocation]:
        """Iterate over the records in insertion order."""
        return iter(self.test_results)

    def __len__(self) -> int:
        """Return the number of records."""
        return len(self.test_results)

    def __getitem__(self, index: int) -> FunctionTestInvocation:
        """Return the record at position ``index``."""
        return self.test_results[index]

    def __setitem__(self, index: int, value: FunctionTestInvocation) -> None:
        """Replace the record at position ``index``; ``test_result_idx`` is not updated."""
        self.test_results[index] = value

    def __contains__(self, value: FunctionTestInvocation) -> bool:
        """Return whether ``value`` equals one of the records."""
        return value in self.test_results

    def __bool__(self) -> bool:
        """Return ``True`` when at least one record is held."""
        return bool(self.test_results)

    def __eq__(self, other: object) -> bool:
        """Compare two collections without regard to record order.

        Records are paired by unique invocation loop id. A pair matches when
        ``file_name``, ``did_pass``, ``runtime``, ``test_framework`` and
        ``test_type`` are equal and ``comparator`` accepts both return values.
        """
        # Unordered comparison
        if type(self) is not type(other):
            return False
        other_results = cast(TestResults, other)
        if len(self) != len(other_results):
            return False
        original_recursion_limit = sys.getrecursionlimit()
        # Raise the limit once for the whole comparison rather than on every iteration.
        if original_recursion_limit < 5000:
            sys.setrecursionlimit(5000)
        try:
            for test_result in self:
                other_test_result = other_results.get_by_unique_invocation_loop_id(
                    test_result.unique_invocation_loop_id
                )
                if other_test_result is None:
                    return False
                if (
                    test_result.file_name != other_test_result.file_name
                    or test_result.did_pass != other_test_result.did_pass
                    or test_result.runtime != other_test_result.runtime
                    or test_result.test_framework != other_test_result.test_framework
                    or test_result.test_type != other_test_result.test_type
                    or not comparator(test_result.return_value, other_test_result.return_value)
                ):
                    return False
            return True
        finally:
            # Restore the limit on every exit path, including an exception raised by comparator.
            sys.setrecursionlimit(original_recursion_limit)
0 commit comments