Skip to content

Commit ecfa89f

Browse files
cleaner
1 parent 0b2d894 commit ecfa89f

File tree

1 file changed

+31
-39
lines changed

1 file changed

+31
-39
lines changed

codeflash/verification/equivalence.py

Lines changed: 31 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,14 @@ class TestDiffScope(Enum):
2525
@dataclass
2626
class TestDiff:
2727
scope: TestDiffScope
28-
pytest_error: str
2928
original_value: any
3029
candidate_value: any
3130
original_pass: bool
3231
candidate_pass: bool
32+
3333
test_src_code: Optional[str] = None
34+
candidate_pytest_error: Optional[str] = None
35+
original_pytest_error: Optional[str] = None
3436

3537

3638
def compare_test_results(original_results: TestResults, candidate_results: TestResults) -> tuple[bool, list[TestDiff]]:
@@ -49,15 +51,15 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR
4951
original_test_result = original_results.get_by_unique_invocation_loop_id(test_id)
5052
cdd_test_result = candidate_results.get_by_unique_invocation_loop_id(test_id)
5153
candidate_test_failures = candidate_results.test_failures
52-
# original_test_failures = original_results.test_failures
54+
original_test_failures = original_results.test_failures
5355
cdd_pytest_error = (
5456
candidate_test_failures.get(original_test_result.id.test_function_name, "")
5557
if candidate_test_failures
5658
else ""
5759
)
58-
# original_pytest_error = (
59-
# original_test_failures.get(original_test_result.id.test_function_name, "") if original_test_failures else ""
60-
# )
60+
original_pytest_error = (
61+
original_test_failures.get(original_test_result.id.test_function_name, "") if original_test_failures else ""
62+
)
6163

6264
if cdd_test_result is not None and original_test_result is None:
6365
continue
@@ -79,22 +81,26 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR
7981
in {VerificationType.INIT_STATE_HELPER, VerificationType.INIT_STATE_FTO}
8082
):
8183
superset_obj = True
84+
8285
test_src_code = original_test_result.id.get_src_code(original_test_result.file_name)
86+
test_diff = TestDiff(
87+
scope=TestDiffScope.RETURN_VALUE,
88+
original_value=original_test_result.return_value,
89+
candidate_value=cdd_test_result.return_value,
90+
test_src_code=test_src_code,
91+
candidate_pytest_error=cdd_pytest_error,
92+
original_pass=original_test_result.did_pass,
93+
candidate_pass=cdd_test_result.did_pass,
94+
original_pytest_error=original_pytest_error,
95+
)
8396
if not comparator(original_test_result.return_value, cdd_test_result.return_value, superset_obj=superset_obj):
84-
test_diffs.append(
85-
TestDiff(
86-
scope=TestDiffScope.RETURN_VALUE,
87-
test_src_code=test_src_code,
88-
original_value=original_test_result.return_value,
89-
candidate_value=cdd_test_result.return_value,
90-
pytest_error=cdd_pytest_error,
91-
original_pass=original_test_result.did_pass,
92-
candidate_pass=cdd_test_result.did_pass,
93-
)
94-
)
97+
test_diff.scope = TestDiffScope.RETURN_VALUE
98+
test_diff.original_value = original_test_result.return_value
99+
test_diff.candidate_value = cdd_test_result.return_value
100+
test_diffs.append(test_diff)
95101

96102
try:
97-
print(
103+
logger.debug(
98104
f"File Name: {original_test_result.file_name}\n"
99105
f"Test Type: {original_test_result.test_type}\n"
100106
f"Verification Type: {original_test_result.verification_type}\n"
@@ -108,17 +114,10 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR
108114
if (original_test_result.stdout and cdd_test_result.stdout) and not comparator(
109115
original_test_result.stdout, cdd_test_result.stdout
110116
):
111-
test_diffs.append(
112-
TestDiff(
113-
scope=TestDiffScope.STDOUT,
114-
test_src_code=test_src_code,
115-
original_value=original_test_result.stdout,
116-
candidate_value=cdd_test_result.stdout,
117-
pytest_error=cdd_pytest_error,
118-
original_pass=original_test_result.did_pass,
119-
candidate_pass=cdd_test_result.did_pass,
120-
)
121-
)
117+
test_diff.scope = TestDiffScope.STDOUT
118+
test_diff.original_value = original_test_result.stdout
119+
test_diff.candidate_value = cdd_test_result.stdout
120+
test_diffs.append(test_diff)
122121
break
123122

124123
if original_test_result.test_type in {
@@ -127,17 +126,10 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR
127126
TestType.GENERATED_REGRESSION,
128127
TestType.REPLAY_TEST,
129128
} and (cdd_test_result.did_pass != original_test_result.did_pass):
130-
test_diffs.append(
131-
TestDiff(
132-
scope=TestDiffScope.DID_PASS,
133-
test_src_code=test_src_code,
134-
original_value=original_test_result.did_pass,
135-
candidate_value=cdd_test_result.did_pass,
136-
pytest_error=cdd_pytest_error,
137-
original_pass=original_test_result.did_pass,
138-
candidate_pass=cdd_test_result.did_pass,
139-
)
140-
)
129+
test_diff.scope = TestDiffScope.DID_PASS
130+
test_diff.original_value = original_test_result.did_pass
131+
test_diff.candidate_value = cdd_test_result.did_pass
132+
test_diffs.append(test_diff)
141133
break
142134
sys.setrecursionlimit(original_recursion_limit)
143135
if did_all_timeout:

0 commit comments

Comments
 (0)