Skip to content

Commit 7ee1ab1

Browse files
committed
cleanup
1 parent bfe4179 commit 7ee1ab1

File tree

1 file changed

+128
-23
lines changed

1 file changed

+128
-23
lines changed

codeflash/verification/hypothesis_testing.py

Lines changed: 128 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,43 @@ def remove_functions_with_only_any_type(code_string: str) -> str:
6363
return ast.unparse(new_tree)
6464

6565

66+
def filter_hypothesis_tests_by_function_name(code: str, function_name: str) -> str:
67+
"""Filter hypothesis tests to only include tests matching the function name.
68+
69+
Preserves all imports, module-level assignments, and only test functions
70+
that contain the target function name.
71+
72+
Args:
73+
code: The hypothesis test code to filter
74+
function_name: The name of the function being tested
75+
76+
Returns:
77+
Filtered code with only matching tests
78+
"""
79+
tree = ast.parse(code)
80+
81+
class TestFunctionRemover(ast.NodeTransformer):
82+
def visit_Module(self, node): # noqa: ANN001, ANN202
83+
# Filter body to keep imports, module-level assignments, and matching test functions
84+
new_body = []
85+
for item in node.body:
86+
if isinstance(item, (ast.Import, ast.ImportFrom, ast.Assign)):
87+
# Keep all imports and module-level assignments
88+
new_body.append(item)
89+
elif isinstance(item, ast.FunctionDef):
90+
# Only keep test functions that match the function name
91+
if item.name.startswith("test_") and function_name in item.name:
92+
new_body.append(item)
93+
node.body = new_body
94+
return node
95+
96+
modified_tree = TestFunctionRemover().visit(tree)
97+
ast.fix_missing_locations(modified_tree)
98+
return ast.unparse(modified_tree)
99+
100+
66101
def make_hypothesis_tests_deterministic(code: str) -> str:
67-
"""Add @settings(derandomize=True) decorator to make Hypothesis tests deterministic."""
102+
"""Add @settings(derandomize=True) decorator and constrain strategies to make Hypothesis tests deterministic."""
68103
try:
69104
tree = ast.parse(code)
70105
except SyntaxError:
@@ -80,34 +115,94 @@ def make_hypothesis_tests_deterministic(code: str) -> str:
80115
if not settings_imported:
81116
tree.body.insert(0, ast.parse("from hypothesis import settings").body[0])
82117

118+
class StrategyConstrainer(ast.NodeTransformer):
119+
def visit_Call(self, node: ast.Call) -> ast.Call:
120+
self.generic_visit(node)
121+
122+
# Check if this is a strategy call (st.floats(), st.integers(), etc.)
123+
if (
124+
isinstance(node.func, ast.Attribute)
125+
and isinstance(node.func.value, ast.Name)
126+
and node.func.value.id == "st"
127+
):
128+
if node.func.attr == "floats" and not any(
129+
k.arg in ["min_value", "max_value", "allow_nan", "allow_infinity"]
130+
for k in node.keywords
131+
):
132+
# Constrain floats to reasonable bounds
133+
node.keywords.extend(
134+
[
135+
ast.keyword(
136+
arg="min_value",
137+
value=ast.UnaryOp(
138+
op=ast.USub(), operand=ast.Constant(value=1e6)
139+
),
140+
),
141+
ast.keyword(arg="max_value", value=ast.Constant(value=1e6)),
142+
ast.keyword(
143+
arg="allow_nan", value=ast.Constant(value=False)
144+
),
145+
ast.keyword(
146+
arg="allow_infinity", value=ast.Constant(value=False)
147+
),
148+
]
149+
)
150+
elif node.func.attr == "integers" and not any(
151+
k.arg in ["min_value", "max_value"] for k in node.keywords
152+
):
153+
# Constrain integers to reasonable bounds
154+
node.keywords.extend(
155+
[
156+
ast.keyword(arg="min_value", value=ast.Constant(value=0)),
157+
ast.keyword(
158+
arg="max_value", value=ast.Constant(value=10000)
159+
),
160+
]
161+
)
162+
return node
163+
164+
tree = StrategyConstrainer().visit(tree)
165+
ast.fix_missing_locations(tree)
166+
83167
for node in ast.walk(tree):
84168
if isinstance(node, ast.FunctionDef):
85169
settings_decorator = next(
86170
(
87171
d
88172
for d in node.decorator_list
89-
if isinstance(d, ast.Call) and isinstance(d.func, ast.Name) and d.func.id == "settings"
173+
if isinstance(d, ast.Call)
174+
and isinstance(d.func, ast.Name)
175+
and d.func.id == "settings"
90176
),
91177
None,
92178
)
93179

94180
if settings_decorator:
95181
if not any(k.arg == "derandomize" for k in settings_decorator.keywords):
96-
settings_decorator.keywords.append(ast.keyword(arg="derandomize", value=ast.Constant(value=True)))
182+
settings_decorator.keywords.append(
183+
ast.keyword(arg="derandomize", value=ast.Constant(value=True))
184+
)
97185
else:
98186
node.decorator_list.append(
99187
ast.Call(
100188
func=ast.Name(id="settings", ctx=ast.Load()),
101189
args=[],
102-
keywords=[ast.keyword(arg="derandomize", value=ast.Constant(value=True))],
190+
keywords=[
191+
ast.keyword(
192+
arg="derandomize", value=ast.Constant(value=True)
193+
)
194+
],
103195
)
104196
)
105197

106198
return ast.unparse(tree)
107199

108200

109201
def generate_hypothesis_tests(
110-
test_cfg: TestConfig, args: Namespace, function_to_optimize: FunctionToOptimize, function_to_optimize_ast: ast.AST
202+
test_cfg: TestConfig,
203+
args: Namespace,
204+
function_to_optimize: FunctionToOptimize,
205+
function_to_optimize_ast: ast.AST,
111206
) -> tuple[dict[str, list[FunctionCalledInTest]], str]:
112207
"""Generate property-based tests using Hypothesis ghostwriter.
113208
@@ -128,15 +223,19 @@ def generate_hypothesis_tests(
128223

129224
if (
130225
test_cfg.project_root_path
131-
and isinstance(function_to_optimize_ast, (ast.FunctionDef, ast.AsyncFunctionDef))
226+
and isinstance(
227+
function_to_optimize_ast, (ast.FunctionDef, ast.AsyncFunctionDef)
228+
)
132229
and has_typed_parameters(function_to_optimize_ast, function_to_optimize.parents)
133230
):
134231
logger.info("Generating Hypothesis tests for the original code…")
135232
console.rule()
136233

137234
try:
138235
qualified_function_path = get_qualified_function_path(
139-
function_to_optimize.file_path, args.project_root, function_to_optimize.qualified_name
236+
function_to_optimize.file_path,
237+
args.project_root,
238+
function_to_optimize.qualified_name,
140239
)
141240
logger.info(f"command: hypothesis write {qualified_function_path}")
142241

@@ -151,7 +250,9 @@ def generate_hypothesis_tests(
151250
except subprocess.TimeoutExpired:
152251
logger.debug("Hypothesis test generation timed out")
153252
end_time = time.perf_counter()
154-
logger.debug(f"Hypothesis test generation completed in {end_time - start_time:.2f} seconds")
253+
logger.debug(
254+
f"Hypothesis test generation completed in {end_time - start_time:.2f} seconds"
255+
)
155256
return function_to_hypothesis_tests, hypothesis_test_suite_code
156257

157258
if hypothesis_result.returncode == 0:
@@ -167,37 +268,39 @@ def generate_hypothesis_tests(
167268
test_framework=args.test_framework,
168269
pytest_cmd=args.pytest_cmd,
169270
)
170-
function_to_hypothesis_tests, num_discovered_hypothesis_tests, _ = discover_unit_tests(hypothesis_config)
271+
function_to_hypothesis_tests, num_discovered_hypothesis_tests, _ = (
272+
discover_unit_tests(hypothesis_config)
273+
)
171274
with hypothesis_path.open("r", encoding="utf-8") as f:
172-
tree = ast.parse(f.read())
173-
174-
class TestFunctionRemover(ast.NodeTransformer):
175-
def visit_FunctionDef(self, node): # noqa: ANN001, ANN202
176-
if node.name.startswith("test_") and function_to_optimize.function_name in node.name:
177-
return node
178-
return None
275+
original_code = f.read()
179276

180-
modified_tree = TestFunctionRemover().visit(tree)
181-
ast.fix_missing_locations(modified_tree)
182-
unparsed = ast.unparse(modified_tree)
277+
unparsed = filter_hypothesis_tests_by_function_name(
278+
original_code, function_to_optimize.function_name
279+
)
183280

184281
console.print(f"modified src: {unparsed}")
185282

186283
hypothesis_test_suite_code = format_code(
187284
args.formatter_cmds,
188285
hypothesis_path,
189-
optimized_code=make_hypothesis_tests_deterministic(remove_functions_with_only_any_type(unparsed)),
286+
optimized_code=make_hypothesis_tests_deterministic(
287+
remove_functions_with_only_any_type(unparsed)
288+
),
190289
)
191290
with hypothesis_path.open("w", encoding="utf-8") as f:
192291
f.write(hypothesis_test_suite_code)
193-
function_to_hypothesis_tests, num_discovered_hypothesis_tests, _ = discover_unit_tests(hypothesis_config)
292+
function_to_hypothesis_tests, num_discovered_hypothesis_tests, _ = (
293+
discover_unit_tests(hypothesis_config)
294+
)
194295
logger.info(
195296
f"Created {num_discovered_hypothesis_tests} "
196297
f"hypothesis unit test case{'s' if num_discovered_hypothesis_tests != 1 else ''} "
197298
)
198299
console.rule()
199300
end_time = time.perf_counter()
200-
logger.debug(f"Generated hypothesis tests in {end_time - start_time:.2f} seconds")
301+
logger.debug(
302+
f"Generated hypothesis tests in {end_time - start_time:.2f} seconds"
303+
)
201304
return function_to_hypothesis_tests, hypothesis_test_suite_code
202305

203306
logger.debug(
@@ -206,5 +309,7 @@ def visit_FunctionDef(self, node): # noqa: ANN001, ANN202
206309
console.rule()
207310

208311
end_time = time.perf_counter()
209-
logger.debug(f"Hypothesis test generation completed in {end_time - start_time:.2f} seconds")
312+
logger.debug(
313+
f"Hypothesis test generation completed in {end_time - start_time:.2f} seconds"
314+
)
210315
return function_to_hypothesis_tests, hypothesis_test_suite_code

0 commit comments

Comments
 (0)