Skip to content

Commit a5182c6

Browse files
authored
Merge branch 'main' into alpha-async
2 parents e292c5e + 9bc32ea commit a5182c6

File tree

7 files changed

+310
-46
lines changed

7 files changed

+310
-46
lines changed

README.md

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,7 @@
33
<a href="https://github.com/codeflash-ai/codeflash">
44
<img src="https://img.shields.io/github/commit-activity/m/codeflash-ai/codeflash" alt="GitHub commit activity">
55
</a>
6-
<a href="https://pypi.org/project/codeflash/">
7-
<img src="https://img.shields.io/pypi/dm/codeflash" alt="PyPI Downloads">
8-
</a>
6+
<a href="https://pypi.org/project/codeflash/"><img src="https://static.pepy.tech/badge/codeflash" alt="PyPI Downloads"></a>
97
<a href="https://pypi.org/project/codeflash/">
108
<img src="https://img.shields.io/pypi/v/codeflash?label=PyPI%20version" alt="PyPI version">
119
</a>

codeflash/code_utils/code_extractor.py

Lines changed: 31 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -195,8 +195,19 @@ def visit_SimpleStatementLine(self, node: cst.SimpleStatementLine) -> None:
195195
self.last_import_line = self.current_line
196196

197197

198-
class ConditionalImportCollector(cst.CSTVisitor):
199-
"""Collect imports inside top-level conditionals (e.g., if TYPE_CHECKING, try/except)."""
198+
class DottedImportCollector(cst.CSTVisitor):
199+
"""Collects all top-level imports from a Python module in normalized dotted format, including top-level conditional imports like `if TYPE_CHECKING:`.
200+
201+
Examples
202+
--------
203+
import os ==> "os"
204+
import dbt.adapters.factory ==> "dbt.adapters.factory"
205+
from pathlib import Path ==> "pathlib.Path"
206+
from recce.adapter.base import BaseAdapter ==> "recce.adapter.base.BaseAdapter"
207+
from typing import Any, List, Optional ==> "typing.Any", "typing.List", "typing.Optional"
208+
from recce.util.lineage import ( build_column_key, filter_dependency_maps) ==> "recce.util.lineage.build_column_key", "recce.util.lineage.filter_dependency_maps"
209+
210+
"""
200211

201212
def __init__(self) -> None:
202213
self.imports: set[str] = set()
@@ -217,7 +228,10 @@ def _collect_imports_from_block(self, block: cst.IndentedBlock) -> None:
217228
for alias in child.names:
218229
module = self.get_full_dotted_name(alias.name)
219230
asname = alias.asname.name.value if alias.asname else alias.name.value
220-
self.imports.add(module if module == asname else f"{module}.{asname}")
231+
if isinstance(asname, cst.Attribute):
232+
self.imports.add(module)
233+
else:
234+
self.imports.add(module if module == asname else f"{module}.{asname}")
221235

222236
elif isinstance(child, cst.ImportFrom):
223237
if child.module is None:
@@ -231,6 +245,7 @@ def _collect_imports_from_block(self, block: cst.IndentedBlock) -> None:
231245

232246
def visit_Module(self, node: cst.Module) -> None:
233247
self.depth = 0
248+
self._collect_imports_from_block(node)
234249

235250
def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
236251
self.depth += 1
@@ -388,45 +403,44 @@ def add_needed_imports_from_module(
388403
logger.error(f"Error parsing source module code: {e}")
389404
return dst_module_code
390405

391-
cond_import_collector = ConditionalImportCollector()
406+
dotted_import_collector = DottedImportCollector()
392407
try:
393408
parsed_dst_module = cst.parse_module(dst_module_code)
394-
parsed_dst_module.visit(cond_import_collector)
409+
parsed_dst_module.visit(dotted_import_collector)
395410
except cst.ParserSyntaxError as e:
396411
logger.exception(f"Syntax error in destination module code: {e}")
397412
return dst_module_code # Return the original code if there's a syntax error
398413

399414
try:
400415
for mod in gatherer.module_imports:
401-
if mod in cond_import_collector.imports:
402-
continue
403-
AddImportsVisitor.add_needed_import(dst_context, mod)
416+
if mod not in dotted_import_collector.imports:
417+
AddImportsVisitor.add_needed_import(dst_context, mod)
404418
RemoveImportsVisitor.remove_unused_import(dst_context, mod)
405419
for mod, obj_seq in gatherer.object_mapping.items():
406420
for obj in obj_seq:
407421
if (
408422
f"{mod}.{obj}" in helper_functions_fqn or dst_context.full_module_name == mod # avoid circular deps
409423
):
410424
continue # Skip adding imports for helper functions already in the context
411-
if f"{mod}.{obj}" in cond_import_collector.imports:
412-
continue
413-
AddImportsVisitor.add_needed_import(dst_context, mod, obj)
425+
if f"{mod}.{obj}" not in dotted_import_collector.imports:
426+
AddImportsVisitor.add_needed_import(dst_context, mod, obj)
414427
RemoveImportsVisitor.remove_unused_import(dst_context, mod, obj)
415428
except Exception as e:
416429
logger.exception(f"Error adding imports to destination module code: {e}")
417430
return dst_module_code
431+
418432
for mod, asname in gatherer.module_aliases.items():
419-
if f"{mod}.{asname}" in cond_import_collector.imports:
420-
continue
421-
AddImportsVisitor.add_needed_import(dst_context, mod, asname=asname)
433+
if f"{mod}.{asname}" not in dotted_import_collector.imports:
434+
AddImportsVisitor.add_needed_import(dst_context, mod, asname=asname)
422435
RemoveImportsVisitor.remove_unused_import(dst_context, mod, asname=asname)
436+
423437
for mod, alias_pairs in gatherer.alias_mapping.items():
424438
for alias_pair in alias_pairs:
425439
if f"{mod}.{alias_pair[0]}" in helper_functions_fqn:
426440
continue
427-
if f"{mod}.{alias_pair[1]}" in cond_import_collector.imports:
428-
continue
429-
AddImportsVisitor.add_needed_import(dst_context, mod, alias_pair[0], asname=alias_pair[1])
441+
442+
if f"{mod}.{alias_pair[1]}" not in dotted_import_collector.imports:
443+
AddImportsVisitor.add_needed_import(dst_context, mod, alias_pair[0], asname=alias_pair[1])
430444
RemoveImportsVisitor.remove_unused_import(dst_context, mod, alias_pair[0], asname=alias_pair[1])
431445

432446
try:

codeflash/code_utils/formatter.py

Lines changed: 31 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -44,19 +44,14 @@ def apply_formatter_cmds(
4444
test_dir_str: Optional[str],
4545
print_status: bool, # noqa
4646
exit_on_failure: bool = True, # noqa
47-
) -> tuple[Path, str]:
48-
# TODO: Only allow a particular whitelist of formatters here to prevent arbitrary code execution
49-
formatter_name = cmds[0].lower()
47+
) -> tuple[Path, str, bool]:
5048
should_make_copy = False
5149
file_path = path
5250

5351
if test_dir_str:
5452
should_make_copy = True
5553
file_path = Path(test_dir_str) / "temp.py"
5654

57-
if not cmds or formatter_name == "disabled":
58-
return path, path.read_text(encoding="utf8")
59-
6055
if not path.exists():
6156
msg = f"File {path} does not exist. Cannot apply formatter commands."
6257
raise FileNotFoundError(msg)
@@ -66,6 +61,7 @@ def apply_formatter_cmds(
6661

6762
file_token = "$file" # noqa: S105
6863

64+
changed = False
6965
for command in cmds:
7066
formatter_cmd_list = shlex.split(command, posix=os.name != "nt")
7167
formatter_cmd_list = [file_path.as_posix() if chunk == file_token else chunk for chunk in formatter_cmd_list]
@@ -74,6 +70,7 @@ def apply_formatter_cmds(
7470
if result.returncode == 0:
7571
if print_status:
7672
console.rule(f"Formatted Successfully with: {command.replace('$file', path.name)}")
73+
changed = True
7774
else:
7875
logger.error(f"Failed to format code with {' '.join(formatter_cmd_list)}")
7976
except FileNotFoundError as e:
@@ -88,7 +85,7 @@ def apply_formatter_cmds(
8885
if exit_on_failure:
8986
raise e from None
9087

91-
return file_path, file_path.read_text(encoding="utf8")
88+
return file_path, file_path.read_text(encoding="utf8"), changed
9289

9390

9491
def get_diff_lines_count(diff_output: str) -> int:
@@ -112,10 +109,16 @@ def format_code(
112109
if console.quiet:
113110
# lsp mode
114111
exit_on_failure = False
115-
with tempfile.TemporaryDirectory() as test_dir_str:
116-
if isinstance(path, str):
117-
path = Path(path)
118112

113+
if isinstance(path, str):
114+
path = Path(path)
115+
116+
# TODO: Only allow a particular whitelist of formatters here to prevent arbitrary code execution
117+
formatter_name = formatter_cmds[0].lower() if formatter_cmds else "disabled"
118+
if formatter_name == "disabled":
119+
return path.read_text(encoding="utf8")
120+
121+
with tempfile.TemporaryDirectory() as test_dir_str:
119122
original_code = path.read_text(encoding="utf8")
120123
original_code_lines = len(original_code.split("\n"))
121124

@@ -126,26 +129,39 @@ def format_code(
126129
original_temp = Path(test_dir_str) / "original_temp.py"
127130
original_temp.write_text(original_code_without_opfunc, encoding="utf8")
128131

129-
formatted_temp, formatted_code = apply_formatter_cmds(
130-
formatter_cmds, original_temp, test_dir_str, print_status=False
132+
formatted_temp, formatted_code, changed = apply_formatter_cmds(
133+
formatter_cmds, original_temp, test_dir_str, print_status=False, exit_on_failure=exit_on_failure
131134
)
132135

136+
if not changed:
137+
logger.warning(
138+
f"No changes detected in {path} after formatting, are you sure you have valid formatter commands?"
139+
)
140+
return original_code
141+
133142
diff_output = generate_unified_diff(
134143
original_code_without_opfunc, formatted_code, from_file=str(original_temp), to_file=str(formatted_temp)
135144
)
136145
diff_lines_count = get_diff_lines_count(diff_output)
137146

138147
max_diff_lines = min(int(original_code_lines * 0.3), 50)
139148

140-
if diff_lines_count > max_diff_lines and max_diff_lines != -1:
141-
logger.debug(
149+
if diff_lines_count > max_diff_lines:
150+
logger.warning(
142151
f"Skipping formatting {path}: {diff_lines_count} lines would change (max: {max_diff_lines})"
143152
)
144153
return original_code
154+
145155
# TODO: We can avoid formatting the whole file again by formatting only the optimized code standalone and replacing it in the formatted file above.
146-
_, formatted_code = apply_formatter_cmds(
156+
_, formatted_code, changed = apply_formatter_cmds(
147157
formatter_cmds, path, test_dir_str=None, print_status=print_status, exit_on_failure=exit_on_failure
148158
)
159+
if not changed:
160+
logger.warning(
161+
f"No changes detected in {path} after formatting, are you sure you have valid formatter commands?"
162+
)
163+
return original_code
164+
149165
logger.debug(f"Formatted {path} with commands: {formatter_cmds}")
150166
return formatted_code
151167

codeflash/optimization/function_optimizer.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -739,7 +739,9 @@ def reformat_code_and_helpers(
739739
file_to_code_context = optimized_context.file_to_path()
740740
optimized_code = file_to_code_context.get(str(path.relative_to(self.project_root)), "")
741741

742-
new_code = format_code(self.args.formatter_cmds, path, optimized_code=optimized_code, check_diff=True)
742+
new_code = format_code(
743+
self.args.formatter_cmds, path, optimized_code=optimized_code, check_diff=True, exit_on_failure=False
744+
)
743745
if should_sort_imports:
744746
new_code = sort_imports(new_code)
745747

@@ -748,7 +750,11 @@ def reformat_code_and_helpers(
748750
module_abspath = hp.file_path
749751
hp_source_code = hp.source_code
750752
formatted_helper_code = format_code(
751-
self.args.formatter_cmds, module_abspath, optimized_code=hp_source_code, check_diff=True
753+
self.args.formatter_cmds,
754+
module_abspath,
755+
optimized_code=hp_source_code,
756+
check_diff=True,
757+
exit_on_failure=False,
752758
)
753759
if should_sort_imports:
754760
formatted_helper_code = sort_imports(formatted_helper_code)

codeflash/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
# These version placeholders will be replaced by uv-dynamic-versioning during build.
2-
__version__ = "0.16.4"
2+
__version__ = "0.16.5"

0 commit comments

Comments
 (0)