Skip to content

Commit 6a7da80

Browse files
authored
Merge branch 'main' into feat/detached-worktrees
2 parents d041dfb + f7adce8 commit 6a7da80

File tree

6 files changed

+563
-51
lines changed

6 files changed

+563
-51
lines changed

codeflash/api/cfapi.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,8 @@ def suggest_changes(
125125
generated_tests: str,
126126
trace_id: str,
127127
coverage_message: str,
128+
replay_tests: str = "",
129+
concolic_tests: str = "",
128130
) -> Response:
129131
"""Suggest changes to a pull request.
130132
@@ -148,6 +150,8 @@ def suggest_changes(
148150
"generatedTests": generated_tests,
149151
"traceId": trace_id,
150152
"coverage_message": coverage_message,
153+
"replayTests": replay_tests,
154+
"concolicTests": concolic_tests,
151155
}
152156
return make_cfapi_request(endpoint="/suggest-pr-changes", method="POST", payload=payload)
153157

@@ -162,6 +166,8 @@ def create_pr(
162166
generated_tests: str,
163167
trace_id: str,
164168
coverage_message: str,
169+
replay_tests: str = "",
170+
concolic_tests: str = "",
165171
) -> Response:
166172
"""Create a pull request, targeting the specified branch. (usually 'main').
167173
@@ -184,6 +190,8 @@ def create_pr(
184190
"generatedTests": generated_tests,
185191
"traceId": trace_id,
186192
"coverage_message": coverage_message,
193+
"replayTests": replay_tests,
194+
"concolicTests": concolic_tests,
187195
}
188196
return make_cfapi_request(endpoint="/create-pr", method="POST", payload=payload)
189197

codeflash/code_utils/code_extractor.py

Lines changed: 76 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,64 @@ def visit_SimpleStatementLine(self, node: cst.SimpleStatementLine) -> None:
195195
self.last_import_line = self.current_line
196196

197197

198+
class ConditionalImportCollector(cst.CSTVisitor):
199+
"""Collect imports inside top-level conditionals (e.g., if TYPE_CHECKING, try/except)."""
200+
201+
def __init__(self) -> None:
202+
self.imports: set[str] = set()
203+
self.depth = 0 # top-level
204+
205+
def get_full_dotted_name(self, expr: cst.BaseExpression) -> str:
206+
if isinstance(expr, cst.Name):
207+
return expr.value
208+
if isinstance(expr, cst.Attribute):
209+
return f"{self.get_full_dotted_name(expr.value)}.{expr.attr.value}"
210+
return ""
211+
212+
def _collect_imports_from_block(self, block: cst.IndentedBlock) -> None:
213+
for statement in block.body:
214+
if isinstance(statement, cst.SimpleStatementLine):
215+
for child in statement.body:
216+
if isinstance(child, cst.Import):
217+
for alias in child.names:
218+
module = self.get_full_dotted_name(alias.name)
219+
asname = alias.asname.name.value if alias.asname else alias.name.value
220+
self.imports.add(module if module == asname else f"{module}.{asname}")
221+
222+
elif isinstance(child, cst.ImportFrom):
223+
if child.module is None:
224+
continue
225+
module = self.get_full_dotted_name(child.module)
226+
for alias in child.names:
227+
if isinstance(alias, cst.ImportAlias):
228+
name = alias.name.value
229+
asname = alias.asname.name.value if alias.asname else name
230+
self.imports.add(f"{module}.{asname}")
231+
232+
def visit_Module(self, node: cst.Module) -> None:
233+
self.depth = 0
234+
235+
def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
236+
self.depth += 1
237+
238+
def leave_FunctionDef(self, node: cst.FunctionDef) -> None:
239+
self.depth -= 1
240+
241+
def visit_ClassDef(self, node: cst.ClassDef) -> None:
242+
self.depth += 1
243+
244+
def leave_ClassDef(self, node: cst.ClassDef) -> None:
245+
self.depth -= 1
246+
247+
def visit_If(self, node: cst.If) -> None:
248+
if self.depth == 0:
249+
self._collect_imports_from_block(node.body)
250+
251+
def visit_Try(self, node: cst.Try) -> None:
252+
if self.depth == 0:
253+
self._collect_imports_from_block(node.body)
254+
255+
198256
class ImportInserter(cst.CSTTransformer):
199257
"""Transformer that inserts global statements after the last import."""
200258

@@ -329,8 +387,19 @@ def add_needed_imports_from_module(
329387
except Exception as e:
330388
logger.error(f"Error parsing source module code: {e}")
331389
return dst_module_code
390+
391+
cond_import_collector = ConditionalImportCollector()
392+
try:
393+
parsed_dst_module = cst.parse_module(dst_module_code)
394+
parsed_dst_module.visit(cond_import_collector)
395+
except cst.ParserSyntaxError as e:
396+
logger.exception(f"Syntax error in destination module code: {e}")
397+
return dst_module_code # Return the original code if there's a syntax error
398+
332399
try:
333400
for mod in gatherer.module_imports:
401+
if mod in cond_import_collector.imports:
402+
continue
334403
AddImportsVisitor.add_needed_import(dst_context, mod)
335404
RemoveImportsVisitor.remove_unused_import(dst_context, mod)
336405
for mod, obj_seq in gatherer.object_mapping.items():
@@ -339,28 +408,29 @@ def add_needed_imports_from_module(
339408
f"{mod}.{obj}" in helper_functions_fqn or dst_context.full_module_name == mod # avoid circular deps
340409
):
341410
continue # Skip adding imports for helper functions already in the context
411+
if f"{mod}.{obj}" in cond_import_collector.imports:
412+
continue
342413
AddImportsVisitor.add_needed_import(dst_context, mod, obj)
343414
RemoveImportsVisitor.remove_unused_import(dst_context, mod, obj)
344415
except Exception as e:
345416
logger.exception(f"Error adding imports to destination module code: {e}")
346417
return dst_module_code
347418
for mod, asname in gatherer.module_aliases.items():
419+
if f"{mod}.{asname}" in cond_import_collector.imports:
420+
continue
348421
AddImportsVisitor.add_needed_import(dst_context, mod, asname=asname)
349422
RemoveImportsVisitor.remove_unused_import(dst_context, mod, asname=asname)
350423
for mod, alias_pairs in gatherer.alias_mapping.items():
351424
for alias_pair in alias_pairs:
352425
if f"{mod}.{alias_pair[0]}" in helper_functions_fqn:
353426
continue
427+
if f"{mod}.{alias_pair[1]}" in cond_import_collector.imports:
428+
continue
354429
AddImportsVisitor.add_needed_import(dst_context, mod, alias_pair[0], asname=alias_pair[1])
355430
RemoveImportsVisitor.remove_unused_import(dst_context, mod, alias_pair[0], asname=alias_pair[1])
356431

357432
try:
358-
parsed_module = cst.parse_module(dst_module_code)
359-
except cst.ParserSyntaxError as e:
360-
logger.exception(f"Syntax error in destination module code: {e}")
361-
return dst_module_code # Return the original code if there's a syntax error
362-
try:
363-
transformed_module = AddImportsVisitor(dst_context).transform_module(parsed_module)
433+
transformed_module = AddImportsVisitor(dst_context).transform_module(parsed_dst_module)
364434
transformed_module = RemoveImportsVisitor(dst_context).transform_module(transformed_module)
365435
return transformed_module.code.lstrip("\n")
366436
except Exception as e:

codeflash/optimization/function_optimizer.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1196,7 +1196,7 @@ def process_review(
11961196
if concolic_test_str:
11971197
generated_tests_str += "\n#------------------------------------------------\n" + concolic_test_str
11981198

1199-
existing_tests = existing_tests_source_for(
1199+
existing_tests, replay_tests, concolic_tests = existing_tests_source_for(
12001200
self.function_to_optimize.qualified_name_with_modules_from_root(self.project_root),
12011201
function_to_all_tests,
12021202
test_cfg=self.test_cfg,
@@ -1237,6 +1237,8 @@ def process_review(
12371237
if self.experiment_id
12381238
else self.function_trace_id,
12391239
"coverage_message": coverage_message,
1240+
"replay_tests": replay_tests,
1241+
"concolic_tests": concolic_tests,
12401242
}
12411243

12421244
raise_pr = not self.args.no_pr

codeflash/result/create_pr.py

Lines changed: 73 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,16 @@ def existing_tests_source_for(
3434
test_cfg: TestConfig,
3535
original_runtimes_all: dict[InvocationId, list[int]],
3636
optimized_runtimes_all: dict[InvocationId, list[int]],
37-
) -> str:
37+
) -> tuple[str, str, str]:
3838
test_files = function_to_tests.get(function_qualified_name_with_modules_from_root)
3939
if not test_files:
40-
return ""
41-
output: str = ""
42-
rows = []
40+
return "", "", ""
41+
output_existing: str = ""
42+
output_concolic: str = ""
43+
output_replay: str = ""
44+
rows_existing = []
45+
rows_concolic = []
46+
rows_replay = []
4347
headers = ["Test File::Test Function", "Original ⏱️", "Optimized ⏱️", "Speedup"]
4448
tests_root = test_cfg.tests_root
4549
original_tests_to_runtimes: dict[Path, dict[str, int]] = {}
@@ -99,28 +103,79 @@ def existing_tests_source_for(
99103
* 100
100104
)
101105
if greater:
102-
rows.append(
106+
if "__replay_test_" in str(print_filename):
107+
rows_replay.append(
108+
[
109+
f"`{print_filename}::{qualified_name}`",
110+
f"{print_original_runtime}",
111+
f"{print_optimized_runtime}",
112+
f"{perf_gain}%⚠️",
113+
]
114+
)
115+
elif "codeflash_concolic" in str(print_filename):
116+
rows_concolic.append(
117+
[
118+
f"`{print_filename}::{qualified_name}`",
119+
f"{print_original_runtime}",
120+
f"{print_optimized_runtime}",
121+
f"{perf_gain}%⚠️",
122+
]
123+
)
124+
else:
125+
rows_existing.append(
126+
[
127+
f"`{print_filename}::{qualified_name}`",
128+
f"{print_original_runtime}",
129+
f"{print_optimized_runtime}",
130+
f"{perf_gain}%⚠️",
131+
]
132+
)
133+
elif "__replay_test_" in str(print_filename):
134+
rows_replay.append(
103135
[
104136
f"`{print_filename}::{qualified_name}`",
105137
f"{print_original_runtime}",
106138
f"{print_optimized_runtime}",
107-
f"⚠️{perf_gain}%",
139+
f"{perf_gain}%✅",
140+
]
141+
)
142+
elif "codeflash_concolic" in str(print_filename):
143+
rows_concolic.append(
144+
[
145+
f"`{print_filename}::{qualified_name}`",
146+
f"{print_original_runtime}",
147+
f"{print_optimized_runtime}",
148+
f"{perf_gain}%✅",
108149
]
109150
)
110151
else:
111-
rows.append(
152+
rows_existing.append(
112153
[
113154
f"`{print_filename}::{qualified_name}`",
114155
f"{print_original_runtime}",
115156
f"{print_optimized_runtime}",
116-
f"{perf_gain}%",
157+
f"{perf_gain}%",
117158
]
118159
)
119-
output += tabulate( # type: ignore[no-untyped-call]
120-
headers=headers, tabular_data=rows, tablefmt="pipe", colglobalalign=None, preserve_whitespace=True
160+
output_existing += tabulate( # type: ignore[no-untyped-call]
161+
headers=headers, tabular_data=rows_existing, tablefmt="pipe", colglobalalign=None, preserve_whitespace=True
162+
)
163+
output_existing += "\n"
164+
if len(rows_existing) == 0:
165+
output_existing = ""
166+
output_concolic += tabulate( # type: ignore[no-untyped-call]
167+
headers=headers, tabular_data=rows_concolic, tablefmt="pipe", colglobalalign=None, preserve_whitespace=True
168+
)
169+
output_concolic += "\n"
170+
if len(rows_concolic) == 0:
171+
output_concolic = ""
172+
output_replay += tabulate( # type: ignore[no-untyped-call]
173+
headers=headers, tabular_data=rows_replay, tablefmt="pipe", colglobalalign=None, preserve_whitespace=True
121174
)
122-
output += "\n"
123-
return output
175+
output_replay += "\n"
176+
if len(rows_replay) == 0:
177+
output_replay = ""
178+
return output_existing, output_replay, output_concolic
124179

125180

126181
def check_create_pr(
@@ -131,6 +186,8 @@ def check_create_pr(
131186
generated_original_test_source: str,
132187
function_trace_id: str,
133188
coverage_message: str,
189+
replay_tests: str,
190+
concolic_tests: str,
134191
git_remote: Optional[str] = None,
135192
) -> None:
136193
pr_number: Optional[int] = env_utils.get_pr_number()
@@ -171,6 +228,8 @@ def check_create_pr(
171228
generated_tests=generated_original_test_source,
172229
trace_id=function_trace_id,
173230
coverage_message=coverage_message,
231+
replay_tests=replay_tests,
232+
concolic_tests=concolic_tests,
174233
)
175234
if response.ok:
176235
logger.info(f"Suggestions were successfully made to PR #{pr_number}")
@@ -218,6 +277,8 @@ def check_create_pr(
218277
generated_tests=generated_original_test_source,
219278
trace_id=function_trace_id,
220279
coverage_message=coverage_message,
280+
replay_tests=replay_tests,
281+
concolic_tests=concolic_tests,
221282
)
222283
if response.ok:
223284
pr_id = response.text

0 commit comments

Comments
 (0)