Skip to content

Commit b7faf81

Browse files
committed
modify equivalence for hypothesis tests
1 parent 4866d82 commit b7faf81

File tree

2 files changed

+25
-60
lines changed

2 files changed

+25
-60
lines changed

codeflash/verification/equivalence.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from collections import defaultdict
33

44
from codeflash.cli_cmds.console import logger
5-
from codeflash.models.models import TestResults, TestType, VerificationType
5+
from codeflash.models.models import FunctionTestInvocation, TestResults, TestType, VerificationType
66
from codeflash.verification.comparator import comparator
77

88
INCREASED_RECURSION_LIMIT = 5000
@@ -139,7 +139,7 @@ def _compare_hypothesis_tests_semantic(original_hypothesis: list, candidate_hypo
139139
"""
140140

141141
# Group by test function (excluding loop index and iteration_id from comparison)
142-
def get_test_key(test_result):
142+
def get_test_key(test_result: FunctionTestInvocation) -> tuple[str, str, str, str]:
143143
"""Get unique key for a Hypothesis test function."""
144144
return (
145145
test_result.id.test_module_path,

codeflash/verification/hypothesis_testing.py

Lines changed: 23 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ def filter_hypothesis_tests_by_function_name(code: str, function_name: str) -> s
7575
7676
Returns:
7777
Filtered code with only matching tests
78+
7879
"""
7980
tree = ast.parse(code)
8081

@@ -86,10 +87,9 @@ def visit_Module(self, node): # noqa: ANN001, ANN202
8687
if isinstance(item, (ast.Import, ast.ImportFrom, ast.Assign)):
8788
# Keep all imports and module-level assignments
8889
new_body.append(item)
89-
elif isinstance(item, ast.FunctionDef):
90+
elif isinstance(item, ast.FunctionDef) and item.name.startswith("test_") and function_name in item.name:
9091
# Only keep test functions that match the function name
91-
if item.name.startswith("test_") and function_name in item.name:
92-
new_body.append(item)
92+
new_body.append(item)
9393
node.body = new_body
9494
return node
9595

@@ -126,25 +126,17 @@ def visit_Call(self, node: ast.Call) -> ast.Call:
126126
and node.func.value.id == "st"
127127
):
128128
if node.func.attr == "floats" and not any(
129-
k.arg in ["min_value", "max_value", "allow_nan", "allow_infinity"]
130-
for k in node.keywords
129+
k.arg in ["min_value", "max_value", "allow_nan", "allow_infinity"] for k in node.keywords
131130
):
132131
# Constrain floats to reasonable bounds
133132
node.keywords.extend(
134133
[
135134
ast.keyword(
136-
arg="min_value",
137-
value=ast.UnaryOp(
138-
op=ast.USub(), operand=ast.Constant(value=1e6)
139-
),
135+
arg="min_value", value=ast.UnaryOp(op=ast.USub(), operand=ast.Constant(value=1e6))
140136
),
141137
ast.keyword(arg="max_value", value=ast.Constant(value=1e6)),
142-
ast.keyword(
143-
arg="allow_nan", value=ast.Constant(value=False)
144-
),
145-
ast.keyword(
146-
arg="allow_infinity", value=ast.Constant(value=False)
147-
),
138+
ast.keyword(arg="allow_nan", value=ast.Constant(value=False)),
139+
ast.keyword(arg="allow_infinity", value=ast.Constant(value=False)),
148140
]
149141
)
150142
elif node.func.attr == "integers" and not any(
@@ -154,9 +146,7 @@ def visit_Call(self, node: ast.Call) -> ast.Call:
154146
node.keywords.extend(
155147
[
156148
ast.keyword(arg="min_value", value=ast.Constant(value=-10000)),
157-
ast.keyword(
158-
arg="max_value", value=ast.Constant(value=10000)
159-
),
149+
ast.keyword(arg="max_value", value=ast.Constant(value=10000)),
160150
]
161151
)
162152
return node
@@ -170,39 +160,28 @@ def visit_Call(self, node: ast.Call) -> ast.Call:
170160
(
171161
d
172162
for d in node.decorator_list
173-
if isinstance(d, ast.Call)
174-
and isinstance(d.func, ast.Name)
175-
and d.func.id == "settings"
163+
if isinstance(d, ast.Call) and isinstance(d.func, ast.Name) and d.func.id == "settings"
176164
),
177165
None,
178166
)
179167

180168
if settings_decorator:
181169
if not any(k.arg == "derandomize" for k in settings_decorator.keywords):
182-
settings_decorator.keywords.append(
183-
ast.keyword(arg="derandomize", value=ast.Constant(value=True))
184-
)
170+
settings_decorator.keywords.append(ast.keyword(arg="derandomize", value=ast.Constant(value=True)))
185171
else:
186172
node.decorator_list.append(
187173
ast.Call(
188174
func=ast.Name(id="settings", ctx=ast.Load()),
189175
args=[],
190-
keywords=[
191-
ast.keyword(
192-
arg="derandomize", value=ast.Constant(value=True)
193-
)
194-
],
176+
keywords=[ast.keyword(arg="derandomize", value=ast.Constant(value=True))],
195177
)
196178
)
197179

198180
return ast.unparse(tree)
199181

200182

201183
def generate_hypothesis_tests(
202-
test_cfg: TestConfig,
203-
args: Namespace,
204-
function_to_optimize: FunctionToOptimize,
205-
function_to_optimize_ast: ast.AST,
184+
test_cfg: TestConfig, args: Namespace, function_to_optimize: FunctionToOptimize, function_to_optimize_ast: ast.AST
206185
) -> tuple[dict[str, list[FunctionCalledInTest]], str]:
207186
"""Generate property-based tests using Hypothesis ghostwriter.
208187
@@ -223,19 +202,15 @@ def generate_hypothesis_tests(
223202

224203
if (
225204
test_cfg.project_root_path
226-
and isinstance(
227-
function_to_optimize_ast, (ast.FunctionDef, ast.AsyncFunctionDef)
228-
)
205+
and isinstance(function_to_optimize_ast, (ast.FunctionDef, ast.AsyncFunctionDef))
229206
and has_typed_parameters(function_to_optimize_ast, function_to_optimize.parents)
230207
):
231208
logger.info("Generating Hypothesis tests for the original code…")
232209
console.rule()
233210

234211
try:
235212
qualified_function_path = get_qualified_function_path(
236-
function_to_optimize.file_path,
237-
args.project_root,
238-
function_to_optimize.qualified_name,
213+
function_to_optimize.file_path, args.project_root, function_to_optimize.qualified_name
239214
)
240215
logger.info(f"command: hypothesis write {qualified_function_path}")
241216

@@ -250,9 +225,7 @@ def generate_hypothesis_tests(
250225
except subprocess.TimeoutExpired:
251226
logger.debug("Hypothesis test generation timed out")
252227
end_time = time.perf_counter()
253-
logger.debug(
254-
f"Hypothesis test generation completed in {end_time - start_time:.2f} seconds"
255-
)
228+
logger.debug(f"Hypothesis test generation completed in {end_time - start_time:.2f} seconds")
256229
return function_to_hypothesis_tests, hypothesis_test_suite_code
257230

258231
if hypothesis_result.returncode == 0:
@@ -269,39 +242,33 @@ def generate_hypothesis_tests(
269242
pytest_cmd=args.pytest_cmd,
270243
)
271244
file_to_funcs = {function_to_optimize.file_path: [function_to_optimize]}
272-
function_to_hypothesis_tests, num_discovered_hypothesis_tests, _ = (
273-
discover_unit_tests(hypothesis_config, file_to_funcs_to_optimize=file_to_funcs)
245+
function_to_hypothesis_tests, num_discovered_hypothesis_tests, _ = discover_unit_tests(
246+
hypothesis_config, file_to_funcs_to_optimize=file_to_funcs
274247
)
275248
with hypothesis_path.open("r", encoding="utf-8") as f:
276249
original_code = f.read()
277250

278-
unparsed = filter_hypothesis_tests_by_function_name(
279-
original_code, function_to_optimize.function_name
280-
)
251+
unparsed = filter_hypothesis_tests_by_function_name(original_code, function_to_optimize.function_name)
281252

282253
console.print(f"modified src: {unparsed}")
283254

284255
hypothesis_test_suite_code = format_code(
285256
args.formatter_cmds,
286257
hypothesis_path,
287-
optimized_code=make_hypothesis_tests_deterministic(
288-
remove_functions_with_only_any_type(unparsed)
289-
),
258+
optimized_code=make_hypothesis_tests_deterministic(remove_functions_with_only_any_type(unparsed)),
290259
)
291260
with hypothesis_path.open("w", encoding="utf-8") as f:
292261
f.write(hypothesis_test_suite_code)
293-
function_to_hypothesis_tests, num_discovered_hypothesis_tests, _ = (
294-
discover_unit_tests(hypothesis_config, file_to_funcs_to_optimize=file_to_funcs)
262+
function_to_hypothesis_tests, num_discovered_hypothesis_tests, _ = discover_unit_tests(
263+
hypothesis_config, file_to_funcs_to_optimize=file_to_funcs
295264
)
296265
logger.info(
297266
f"Created {num_discovered_hypothesis_tests} "
298267
f"hypothesis unit test case{'s' if num_discovered_hypothesis_tests != 1 else ''} "
299268
)
300269
console.rule()
301270
end_time = time.perf_counter()
302-
logger.debug(
303-
f"Generated hypothesis tests in {end_time - start_time:.2f} seconds"
304-
)
271+
logger.debug(f"Generated hypothesis tests in {end_time - start_time:.2f} seconds")
305272
return function_to_hypothesis_tests, hypothesis_test_suite_code
306273

307274
logger.debug(
@@ -310,7 +277,5 @@ def generate_hypothesis_tests(
310277
console.rule()
311278

312279
end_time = time.perf_counter()
313-
logger.debug(
314-
f"Hypothesis test generation completed in {end_time - start_time:.2f} seconds"
315-
)
280+
logger.debug(f"Hypothesis test generation completed in {end_time - start_time:.2f} seconds")
316281
return function_to_hypothesis_tests, hypothesis_test_suite_code

0 commit comments

Comments
 (0)