Skip to content

Commit 781886b

Browse files
committed
modified imports
1 parent 3aae8c2 commit 781886b

21 files changed

+498
-265
lines changed

codeflash/code_utils/instrument_existing_tests.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,7 @@
99
from codeflash.cli_cmds.console import logger
1010
from codeflash.code_utils.code_utils import get_run_tmp_file, module_name_from_file_path
1111
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
12-
from codeflash.models.models import FunctionParent, TestingMode
13-
from codeflash.verification.test_results import VerificationType
12+
from codeflash.models.models import FunctionParent, TestingMode, VerificationType
1413

1514
if TYPE_CHECKING:
1615
from collections.abc import Iterable

codeflash/discovery/discover_unit_tests.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,7 @@
1616
from codeflash.cli_cmds.console import console, logger
1717
from codeflash.code_utils.code_utils import get_run_tmp_file, module_name_from_file_path
1818
from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE
19-
from codeflash.models.models import CodePosition, FunctionCalledInTest, TestsInFile
20-
from codeflash.verification.test_results import TestType
19+
from codeflash.models.models import CodePosition, FunctionCalledInTest, TestsInFile, TestType
2120

2221
if TYPE_CHECKING:
2322
from codeflash.verification.verification_utils import TestConfig

codeflash/github/PrComment.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from pydantic.dataclasses import dataclass
55

66
from codeflash.code_utils.time_utils import humanize_runtime
7-
from codeflash.verification.test_results import TestResults
7+
from codeflash.models.models import TestResults
88

99

1010
@dataclass(frozen=True, config={"arbitrary_types_allowed": True})

codeflash/models/models.py

Lines changed: 244 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,22 @@
11
from __future__ import annotations
22

3+
from typing import TYPE_CHECKING, Optional, cast
4+
5+
from rich.tree import Tree
6+
7+
from codeflash.cli_cmds.console import DEBUG_MODE, logger
8+
9+
if TYPE_CHECKING:
10+
from collections.abc import Iterator
311
import enum
412
import json
513
import re
14+
import sys
615
from collections.abc import Collection, Iterator
716
from enum import Enum, IntEnum
817
from pathlib import Path
918
from re import Pattern
10-
from typing import Annotated, Any, Optional, Union
19+
from typing import Annotated, Any, Optional, Union, cast
1120

1221
import sentry_sdk
1322
from coverage.exceptions import NoDataError
@@ -23,7 +32,7 @@
2332
generate_candidates,
2433
)
2534
from codeflash.code_utils.env_utils import is_end_to_end
26-
from codeflash.verification.test_results import TestResults, TestType
35+
from codeflash.verification.comparator import comparator
2736

2837
# If the method spam is in the class Ham, which is at the top level of the module eggs in the package foo, the fully
2938
# qualified name of the method is foo.eggs.Ham.spam, its qualified name is Ham.spam, and its name is spam. The full name
@@ -511,3 +520,236 @@ class FunctionCoverage:
511520
class TestingMode(enum.Enum):
512521
BEHAVIOR = "behavior"
513522
PERFORMANCE = "performance"
523+
524+
525+
class VerificationType(str, Enum):
526+
FUNCTION_CALL = (
527+
"function_call" # Correctness verification for a test function, checks input values and output values)
528+
)
529+
INIT_STATE_FTO = "init_state_fto" # Correctness verification for fto class instance attributes after init
530+
INIT_STATE_HELPER = "init_state_helper" # Correctness verification for helper class instance attributes after init
531+
532+
533+
class TestType(Enum):
534+
EXISTING_UNIT_TEST = 1
535+
INSPIRED_REGRESSION = 2
536+
GENERATED_REGRESSION = 3
537+
REPLAY_TEST = 4
538+
CONCOLIC_COVERAGE_TEST = 5
539+
INIT_STATE_TEST = 6
540+
541+
def to_name(self) -> str:
542+
if self is TestType.INIT_STATE_TEST:
543+
return ""
544+
names = {
545+
TestType.EXISTING_UNIT_TEST: "⚙️ Existing Unit Tests",
546+
TestType.INSPIRED_REGRESSION: "🎨 Inspired Regression Tests",
547+
TestType.GENERATED_REGRESSION: "🌀 Generated Regression Tests",
548+
TestType.REPLAY_TEST: "⏪ Replay Tests",
549+
TestType.CONCOLIC_COVERAGE_TEST: "🔎 Concolic Coverage Tests",
550+
}
551+
return names[self]
552+
553+
554+
@dataclass(frozen=True)
555+
class InvocationId:
556+
test_module_path: str # The fully qualified name of the test module
557+
test_class_name: Optional[str] # The name of the class where the test is defined
558+
test_function_name: Optional[str] # The name of the test_function. Does not include the components of the file_name
559+
function_getting_tested: str
560+
iteration_id: Optional[str]
561+
562+
# test_module_path:TestSuiteClass.test_function_name:function_tested:iteration_id
563+
def id(self) -> str:
564+
class_prefix = f"{self.test_class_name}." if self.test_class_name else ""
565+
return (
566+
f"{self.test_module_path}:{class_prefix}{self.test_function_name}:"
567+
f"{self.function_getting_tested}:{self.iteration_id}"
568+
)
569+
570+
@staticmethod
571+
def from_str_id(string_id: str, iteration_id: Optional[str] = None) -> InvocationId:
572+
components = string_id.split(":")
573+
assert len(components) == 4
574+
second_components = components[1].split(".")
575+
if len(second_components) == 1:
576+
test_class_name = None
577+
test_function_name = second_components[0]
578+
else:
579+
test_class_name = second_components[0]
580+
test_function_name = second_components[1]
581+
return InvocationId(
582+
test_module_path=components[0],
583+
test_class_name=test_class_name,
584+
test_function_name=test_function_name,
585+
function_getting_tested=components[2],
586+
iteration_id=iteration_id if iteration_id else components[3],
587+
)
588+
589+
590+
@dataclass(frozen=True)
591+
class FunctionTestInvocation:
592+
loop_index: int # The loop index of the function invocation, starts at 1
593+
id: InvocationId # The fully qualified name of the function invocation (id)
594+
file_name: Path # The file where the test is defined
595+
did_pass: bool # Whether the test this function invocation was part of, passed or failed
596+
runtime: Optional[int] # Time in nanoseconds
597+
test_framework: str # unittest or pytest
598+
test_type: TestType
599+
return_value: Optional[object] # The return value of the function invocation
600+
timed_out: Optional[bool]
601+
verification_type: Optional[str] = VerificationType.FUNCTION_CALL
602+
stdout: Optional[str] = None
603+
604+
@property
605+
def unique_invocation_loop_id(self) -> str:
606+
return f"{self.loop_index}:{self.id.id()}"
607+
608+
609+
class TestResults(BaseModel):
610+
# don't modify these directly, use the add method
611+
# also we don't support deletion of test results elements - caution is advised
612+
test_results: list[FunctionTestInvocation] = []
613+
test_result_idx: dict[str, int] = {}
614+
615+
def add(self, function_test_invocation: FunctionTestInvocation) -> None:
616+
unique_id = function_test_invocation.unique_invocation_loop_id
617+
if unique_id in self.test_result_idx:
618+
if DEBUG_MODE:
619+
logger.warning(f"Test result with id {unique_id} already exists. SKIPPING")
620+
return
621+
self.test_result_idx[unique_id] = len(self.test_results)
622+
self.test_results.append(function_test_invocation)
623+
624+
def merge(self, other: TestResults) -> None:
625+
original_len = len(self.test_results)
626+
self.test_results.extend(other.test_results)
627+
for k, v in other.test_result_idx.items():
628+
if k in self.test_result_idx:
629+
msg = f"Test result with id {k} already exists."
630+
raise ValueError(msg)
631+
self.test_result_idx[k] = v + original_len
632+
633+
def get_by_unique_invocation_loop_id(self, unique_invocation_loop_id: str) -> FunctionTestInvocation | None:
634+
try:
635+
return self.test_results[self.test_result_idx[unique_invocation_loop_id]]
636+
except (IndexError, KeyError):
637+
return None
638+
639+
def get_all_ids(self) -> set[InvocationId]:
640+
return {test_result.id for test_result in self.test_results}
641+
642+
def get_all_unique_invocation_loop_ids(self) -> set[str]:
643+
return {test_result.unique_invocation_loop_id for test_result in self.test_results}
644+
645+
def number_of_loops(self) -> int:
646+
if not self.test_results:
647+
return 0
648+
return max(test_result.loop_index for test_result in self.test_results)
649+
650+
def get_test_pass_fail_report_by_type(self) -> dict[TestType, dict[str, int]]:
651+
report = {}
652+
for test_type in TestType:
653+
report[test_type] = {"passed": 0, "failed": 0}
654+
for test_result in self.test_results:
655+
if test_result.loop_index == 1:
656+
if test_result.did_pass:
657+
report[test_result.test_type]["passed"] += 1
658+
else:
659+
report[test_result.test_type]["failed"] += 1
660+
return report
661+
662+
@staticmethod
663+
def report_to_string(report: dict[TestType, dict[str, int]]) -> str:
664+
return " ".join(
665+
[
666+
f"{test_type.to_name()}- (Passed: {report[test_type]['passed']}, Failed: {report[test_type]['failed']})"
667+
for test_type in TestType
668+
]
669+
)
670+
671+
@staticmethod
672+
def report_to_tree(report: dict[TestType, dict[str, int]], title: str) -> Tree:
673+
tree = Tree(title)
674+
for test_type in TestType:
675+
if test_type is TestType.INIT_STATE_TEST:
676+
continue
677+
tree.add(
678+
f"{test_type.to_name()} - Passed: {report[test_type]['passed']}, Failed: {report[test_type]['failed']}"
679+
)
680+
return tree
681+
682+
def usable_runtime_data_by_test_case(self) -> dict[InvocationId, list[int]]:
683+
for result in self.test_results:
684+
if result.did_pass and not result.runtime:
685+
msg = (
686+
f"Ignoring test case that passed but had no runtime -> {result.id}, "
687+
f"Loop # {result.loop_index}, Test Type: {result.test_type}, "
688+
f"Verification Type: {result.verification_type}"
689+
)
690+
logger.debug(msg)
691+
692+
usable_runtimes = [
693+
(result.id, result.runtime) for result in self.test_results if result.did_pass and result.runtime
694+
]
695+
return {
696+
usable_id: [runtime[1] for runtime in usable_runtimes if runtime[0] == usable_id]
697+
for usable_id in {runtime[0] for runtime in usable_runtimes}
698+
}
699+
700+
def total_passed_runtime(self) -> int:
701+
"""Calculate the sum of runtimes of all test cases that passed.
702+
703+
A testcase runtime is the minimum value of all looped execution runtimes.
704+
705+
:return: The runtime in nanoseconds.
706+
"""
707+
return sum(
708+
[min(usable_runtime_data) for _, usable_runtime_data in self.usable_runtime_data_by_test_case().items()]
709+
)
710+
711+
def __iter__(self) -> Iterator[FunctionTestInvocation]:
712+
return iter(self.test_results)
713+
714+
def __len__(self) -> int:
715+
return len(self.test_results)
716+
717+
def __getitem__(self, index: int) -> FunctionTestInvocation:
718+
return self.test_results[index]
719+
720+
def __setitem__(self, index: int, value: FunctionTestInvocation) -> None:
721+
self.test_results[index] = value
722+
723+
def __contains__(self, value: FunctionTestInvocation) -> bool:
724+
return value in self.test_results
725+
726+
def __bool__(self) -> bool:
727+
return bool(self.test_results)
728+
729+
def __eq__(self, other: object) -> bool:
730+
# Unordered comparison
731+
if type(self) is not type(other):
732+
return False
733+
if len(self) != len(other):
734+
return False
735+
original_recursion_limit = sys.getrecursionlimit()
736+
cast(TestResults, other)
737+
for test_result in self:
738+
other_test_result = other.get_by_unique_invocation_loop_id(test_result.unique_invocation_loop_id)
739+
if other_test_result is None:
740+
return False
741+
742+
if original_recursion_limit < 5000:
743+
sys.setrecursionlimit(5000)
744+
if (
745+
test_result.file_name != other_test_result.file_name
746+
or test_result.did_pass != other_test_result.did_pass
747+
or test_result.runtime != other_test_result.runtime
748+
or test_result.test_framework != other_test_result.test_framework
749+
or test_result.test_type != other_test_result.test_type
750+
or not comparator(test_result.return_value, other_test_result.return_value)
751+
):
752+
sys.setrecursionlimit(original_recursion_limit)
753+
return False
754+
sys.setrecursionlimit(original_recursion_limit)
755+
return True

codeflash/optimization/function_optimizer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@
5656
TestFile,
5757
TestFiles,
5858
TestingMode,
59+
TestResults,
60+
TestType
5961
)
6062
from codeflash.result.create_pr import check_create_pr, existing_tests_source_for
6163
from codeflash.result.critic import coverage_critic, performance_gain, quantity_of_tests_critic, speedup_critic
@@ -65,7 +67,6 @@
6567
from codeflash.verification.equivalence import compare_test_results
6668
from codeflash.verification.instrument_codeflash_capture import instrument_codeflash_capture
6769
from codeflash.verification.parse_test_output import parse_test_results
68-
from codeflash.verification.test_results import TestResults, TestType
6970
from codeflash.verification.test_runner import run_behavioral_tests, run_benchmarking_tests
7071
from codeflash.verification.verification_utils import get_test_file_path
7172
from codeflash.verification.verifier import generate_tests

codeflash/optimization/optimizer.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,9 @@
1616
from codeflash.discovery.discover_unit_tests import discover_unit_tests
1717
from codeflash.discovery.functions_to_optimize import get_functions_to_optimize
1818
from codeflash.either import is_successful
19-
from codeflash.models.models import ValidCode
19+
from codeflash.models.models import ValidCode, TestType
2020
from codeflash.optimization.function_optimizer import FunctionOptimizer
2121
from codeflash.telemetry.posthog_cf import ph
22-
from codeflash.verification.test_results import TestType
2322
from codeflash.verification.verification_utils import TestConfig
2423

2524
if TYPE_CHECKING:

codeflash/result/critic.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,7 @@
33
from codeflash.cli_cmds.console import logger
44
from codeflash.code_utils import env_utils
55
from codeflash.code_utils.config_consts import COVERAGE_THRESHOLD, MIN_IMPROVEMENT_THRESHOLD
6-
from codeflash.models.models import CoverageData, OptimizedCandidateResult
7-
from codeflash.verification.test_results import TestType
8-
6+
from codeflash.models.models import CoverageData, OptimizedCandidateResult, TestType
97

108
def performance_gain(*, original_runtime_ns: int, optimized_runtime_ns: int) -> float:
119
"""Calculate the performance gain of an optimized code over the original code.

codeflash/result/explanation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from pydantic.dataclasses import dataclass
44

55
from codeflash.code_utils.time_utils import humanize_runtime
6-
from codeflash.verification.test_results import TestResults
6+
from codeflash.models.models import TestResults
77

88

99
@dataclass(frozen=True, config={"arbitrary_types_allowed": True})

codeflash/verification/codeflash_capture.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
import dill as pickle
1212

13-
from codeflash.verification.test_results import VerificationType
13+
from codeflash.models.models import VerificationType
1414

1515

1616
def get_test_info_from_stack(tests_root: str) -> tuple[str, str | None, str, str]:

codeflash/verification/equivalence.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from codeflash.cli_cmds.console import console, logger
55
from codeflash.verification.comparator import comparator
6-
from codeflash.verification.test_results import TestResults, TestType, VerificationType
6+
from codeflash.models.models import TestResults, TestType, VerificationType
77

88
INCREASED_RECURSION_LIMIT = 5000
99

0 commit comments

Comments
 (0)