@@ -63,8 +63,43 @@ def remove_functions_with_only_any_type(code_string: str) -> str:
6363 return ast .unparse (new_tree )
6464
6565
66+ def filter_hypothesis_tests_by_function_name (code : str , function_name : str ) -> str :
67+ """Filter hypothesis tests to only include tests matching the function name.
68+
69+ Preserves all imports, module-level assignments, and only test functions
70+ that contain the target function name.
71+
72+ Args:
73+ code: The hypothesis test code to filter
74+ function_name: The name of the function being tested
75+
76+ Returns:
77+ Filtered code with only matching tests
78+ """
79+ tree = ast .parse (code )
80+
81+ class TestFunctionRemover (ast .NodeTransformer ):
82+ def visit_Module (self , node ): # noqa: ANN001, ANN202
83+ # Filter body to keep imports, module-level assignments, and matching test functions
84+ new_body = []
85+ for item in node .body :
86+ if isinstance (item , (ast .Import , ast .ImportFrom , ast .Assign )):
87+ # Keep all imports and module-level assignments
88+ new_body .append (item )
89+ elif isinstance (item , ast .FunctionDef ):
90+ # Only keep test functions that match the function name
91+ if item .name .startswith ("test_" ) and function_name in item .name :
92+ new_body .append (item )
93+ node .body = new_body
94+ return node
95+
96+ modified_tree = TestFunctionRemover ().visit (tree )
97+ ast .fix_missing_locations (modified_tree )
98+ return ast .unparse (modified_tree )
99+
100+
66101def make_hypothesis_tests_deterministic (code : str ) -> str :
67- """Add @settings(derandomize=True) decorator to make Hypothesis tests deterministic."""
102+ """Add @settings(derandomize=True) decorator and constrain strategies to make Hypothesis tests deterministic."""
68103 try :
69104 tree = ast .parse (code )
70105 except SyntaxError :
@@ -80,34 +115,94 @@ def make_hypothesis_tests_deterministic(code: str) -> str:
80115 if not settings_imported :
81116 tree .body .insert (0 , ast .parse ("from hypothesis import settings" ).body [0 ])
82117
118+ class StrategyConstrainer (ast .NodeTransformer ):
119+ def visit_Call (self , node : ast .Call ) -> ast .Call :
120+ self .generic_visit (node )
121+
122+ # Check if this is a strategy call (st.floats(), st.integers(), etc.)
123+ if (
124+ isinstance (node .func , ast .Attribute )
125+ and isinstance (node .func .value , ast .Name )
126+ and node .func .value .id == "st"
127+ ):
128+ if node .func .attr == "floats" and not any (
129+ k .arg in ["min_value" , "max_value" , "allow_nan" , "allow_infinity" ]
130+ for k in node .keywords
131+ ):
132+ # Constrain floats to reasonable bounds
133+ node .keywords .extend (
134+ [
135+ ast .keyword (
136+ arg = "min_value" ,
137+ value = ast .UnaryOp (
138+ op = ast .USub (), operand = ast .Constant (value = 1e6 )
139+ ),
140+ ),
141+ ast .keyword (arg = "max_value" , value = ast .Constant (value = 1e6 )),
142+ ast .keyword (
143+ arg = "allow_nan" , value = ast .Constant (value = False )
144+ ),
145+ ast .keyword (
146+ arg = "allow_infinity" , value = ast .Constant (value = False )
147+ ),
148+ ]
149+ )
150+ elif node .func .attr == "integers" and not any (
151+ k .arg in ["min_value" , "max_value" ] for k in node .keywords
152+ ):
153+ # Constrain integers to reasonable bounds
154+ node .keywords .extend (
155+ [
156+ ast .keyword (arg = "min_value" , value = ast .Constant (value = 0 )),
157+ ast .keyword (
158+ arg = "max_value" , value = ast .Constant (value = 10000 )
159+ ),
160+ ]
161+ )
162+ return node
163+
164+ tree = StrategyConstrainer ().visit (tree )
165+ ast .fix_missing_locations (tree )
166+
83167 for node in ast .walk (tree ):
84168 if isinstance (node , ast .FunctionDef ):
85169 settings_decorator = next (
86170 (
87171 d
88172 for d in node .decorator_list
89- if isinstance (d , ast .Call ) and isinstance (d .func , ast .Name ) and d .func .id == "settings"
173+ if isinstance (d , ast .Call )
174+ and isinstance (d .func , ast .Name )
175+ and d .func .id == "settings"
90176 ),
91177 None ,
92178 )
93179
94180 if settings_decorator :
95181 if not any (k .arg == "derandomize" for k in settings_decorator .keywords ):
96- settings_decorator .keywords .append (ast .keyword (arg = "derandomize" , value = ast .Constant (value = True )))
182+ settings_decorator .keywords .append (
183+ ast .keyword (arg = "derandomize" , value = ast .Constant (value = True ))
184+ )
97185 else :
98186 node .decorator_list .append (
99187 ast .Call (
100188 func = ast .Name (id = "settings" , ctx = ast .Load ()),
101189 args = [],
102- keywords = [ast .keyword (arg = "derandomize" , value = ast .Constant (value = True ))],
190+ keywords = [
191+ ast .keyword (
192+ arg = "derandomize" , value = ast .Constant (value = True )
193+ )
194+ ],
103195 )
104196 )
105197
106198 return ast .unparse (tree )
107199
108200
109201def generate_hypothesis_tests (
110- test_cfg : TestConfig , args : Namespace , function_to_optimize : FunctionToOptimize , function_to_optimize_ast : ast .AST
202+ test_cfg : TestConfig ,
203+ args : Namespace ,
204+ function_to_optimize : FunctionToOptimize ,
205+ function_to_optimize_ast : ast .AST ,
111206) -> tuple [dict [str , list [FunctionCalledInTest ]], str ]:
112207 """Generate property-based tests using Hypothesis ghostwriter.
113208
@@ -128,15 +223,19 @@ def generate_hypothesis_tests(
128223
129224 if (
130225 test_cfg .project_root_path
131- and isinstance (function_to_optimize_ast , (ast .FunctionDef , ast .AsyncFunctionDef ))
226+ and isinstance (
227+ function_to_optimize_ast , (ast .FunctionDef , ast .AsyncFunctionDef )
228+ )
132229 and has_typed_parameters (function_to_optimize_ast , function_to_optimize .parents )
133230 ):
134231 logger .info ("Generating Hypothesis tests for the original code…" )
135232 console .rule ()
136233
137234 try :
138235 qualified_function_path = get_qualified_function_path (
139- function_to_optimize .file_path , args .project_root , function_to_optimize .qualified_name
236+ function_to_optimize .file_path ,
237+ args .project_root ,
238+ function_to_optimize .qualified_name ,
140239 )
141240 logger .info (f"command: hypothesis write { qualified_function_path } " )
142241
@@ -151,7 +250,9 @@ def generate_hypothesis_tests(
151250 except subprocess .TimeoutExpired :
152251 logger .debug ("Hypothesis test generation timed out" )
153252 end_time = time .perf_counter ()
154- logger .debug (f"Hypothesis test generation completed in { end_time - start_time :.2f} seconds" )
253+ logger .debug (
254+ f"Hypothesis test generation completed in { end_time - start_time :.2f} seconds"
255+ )
155256 return function_to_hypothesis_tests , hypothesis_test_suite_code
156257
157258 if hypothesis_result .returncode == 0 :
@@ -167,37 +268,39 @@ def generate_hypothesis_tests(
167268 test_framework = args .test_framework ,
168269 pytest_cmd = args .pytest_cmd ,
169270 )
170- function_to_hypothesis_tests , num_discovered_hypothesis_tests , _ = discover_unit_tests (hypothesis_config )
271+ function_to_hypothesis_tests , num_discovered_hypothesis_tests , _ = (
272+ discover_unit_tests (hypothesis_config )
273+ )
171274 with hypothesis_path .open ("r" , encoding = "utf-8" ) as f :
172- tree = ast .parse (f .read ())
173-
174- class TestFunctionRemover (ast .NodeTransformer ):
175- def visit_FunctionDef (self , node ): # noqa: ANN001, ANN202
176- if node .name .startswith ("test_" ) and function_to_optimize .function_name in node .name :
177- return node
178- return None
275+ original_code = f .read ()
179276
180- modified_tree = TestFunctionRemover (). visit ( tree )
181- ast . fix_missing_locations ( modified_tree )
182- unparsed = ast . unparse ( modified_tree )
277+ unparsed = filter_hypothesis_tests_by_function_name (
278+ original_code , function_to_optimize . function_name
279+ )
183280
184281 console .print (f"modified src: { unparsed } " )
185282
186283 hypothesis_test_suite_code = format_code (
187284 args .formatter_cmds ,
188285 hypothesis_path ,
189- optimized_code = make_hypothesis_tests_deterministic (remove_functions_with_only_any_type (unparsed )),
286+ optimized_code = make_hypothesis_tests_deterministic (
287+ remove_functions_with_only_any_type (unparsed )
288+ ),
190289 )
191290 with hypothesis_path .open ("w" , encoding = "utf-8" ) as f :
192291 f .write (hypothesis_test_suite_code )
193- function_to_hypothesis_tests , num_discovered_hypothesis_tests , _ = discover_unit_tests (hypothesis_config )
292+ function_to_hypothesis_tests , num_discovered_hypothesis_tests , _ = (
293+ discover_unit_tests (hypothesis_config )
294+ )
194295 logger .info (
195296 f"Created { num_discovered_hypothesis_tests } "
196297 f"hypothesis unit test case{ 's' if num_discovered_hypothesis_tests != 1 else '' } "
197298 )
198299 console .rule ()
199300 end_time = time .perf_counter ()
200- logger .debug (f"Generated hypothesis tests in { end_time - start_time :.2f} seconds" )
301+ logger .debug (
302+ f"Generated hypothesis tests in { end_time - start_time :.2f} seconds"
303+ )
201304 return function_to_hypothesis_tests , hypothesis_test_suite_code
202305
203306 logger .debug (
@@ -206,5 +309,7 @@ def visit_FunctionDef(self, node): # noqa: ANN001, ANN202
206309 console .rule ()
207310
208311 end_time = time .perf_counter ()
209- logger .debug (f"Hypothesis test generation completed in { end_time - start_time :.2f} seconds" )
312+ logger .debug (
313+ f"Hypothesis test generation completed in { end_time - start_time :.2f} seconds"
314+ )
210315 return function_to_hypothesis_tests , hypothesis_test_suite_code
0 commit comments