Skip to content

Commit f766ac6

Browse files
Merge branch 'main' of github.com:codeflash-ai/codeflash into lsp/verbose-quiet-logs
2 parents 8d54ab6 + 9ac5d34 commit f766ac6

File tree

13 files changed

+1820
-821
lines changed

13 files changed

+1820
-821
lines changed

codeflash/api/aiservice.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,51 @@ def get_new_explanation( # noqa: D417
349349
console.rule()
350350
return ""
351351

352+
def generate_ranking( # noqa: D417
353+
self, trace_id: str, diffs: list[str], optimization_ids: list[str], speedups: list[float]
354+
) -> list[int] | None:
355+
"""Optimize the given python code for performance by making a request to the Django endpoint.
356+
357+
Parameters
358+
----------
359+
- trace_id : unique uuid of function
360+
- diffs : list of unified diff strings of opt candidates
361+
- speedups : list of speedups of opt candidates
362+
363+
Returns
364+
-------
365+
- List[int]: Ranking of opt candidates in decreasing order
366+
367+
"""
368+
payload = {
369+
"trace_id": trace_id,
370+
"diffs": diffs,
371+
"speedups": speedups,
372+
"optimization_ids": optimization_ids,
373+
"python_version": platform.python_version(),
374+
}
375+
logger.info("Generating ranking")
376+
console.rule()
377+
try:
378+
response = self.make_ai_service_request("/rank", payload=payload, timeout=60)
379+
except requests.exceptions.RequestException as e:
380+
logger.exception(f"Error generating ranking: {e}")
381+
ph("cli-optimize-error-caught", {"error": str(e)})
382+
return None
383+
384+
if response.status_code == 200:
385+
ranking: list[int] = response.json()["ranking"]
386+
console.rule()
387+
return ranking
388+
try:
389+
error = response.json()["error"]
390+
except Exception:
391+
error = response.text
392+
logger.error(f"Error generating ranking: {response.status_code} - {error}")
393+
ph("cli-optimize-error-response", {"response_status_code": response.status_code, "error": error})
394+
console.rule()
395+
return None
396+
352397
def log_results( # noqa: D417
353398
self,
354399
function_trace_id: str,

codeflash/code_utils/code_utils.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,23 @@
2020
ImportErrorPattern = re.compile(r"ModuleNotFoundError.*$", re.MULTILINE)
2121

2222

23+
def unified_diff_strings(code1: str, code2: str, fromfile: str = "original", tofile: str = "modified") -> str:
24+
"""Return the unified diff between two code strings as a single string.
25+
26+
:param code1: First code string (original).
27+
:param code2: Second code string (modified).
28+
:param fromfile: Label for the first code string.
29+
:param tofile: Label for the second code string.
30+
:return: Unified diff as a string.
31+
"""
32+
code1_lines = code1.splitlines(keepends=True)
33+
code2_lines = code2.splitlines(keepends=True)
34+
35+
diff = difflib.unified_diff(code1_lines, code2_lines, fromfile=fromfile, tofile=tofile, lineterm="")
36+
37+
return "".join(diff)
38+
39+
2340
def diff_length(a: str, b: str) -> int:
2441
"""Compute the length (in characters) of the unified diff between two strings.
2542
Lines changed: 247 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,247 @@
1+
import ast
2+
import hashlib
3+
from typing import Dict, Set
4+
5+
6+
class VariableNormalizer(ast.NodeTransformer):
7+
"""Normalizes only local variable names in AST to canonical forms like var_0, var_1, etc.
8+
Preserves function names, class names, parameters, built-ins, and imported names.
9+
"""
10+
11+
def __init__(self):
12+
self.var_counter = 0
13+
self.var_mapping: Dict[str, str] = {}
14+
self.scope_stack = []
15+
self.builtins = set(dir(__builtins__))
16+
self.imports: Set[str] = set()
17+
self.global_vars: Set[str] = set()
18+
self.nonlocal_vars: Set[str] = set()
19+
self.parameters: Set[str] = set() # Track function parameters
20+
21+
def enter_scope(self):
22+
"""Enter a new scope (function/class)"""
23+
self.scope_stack.append(
24+
{"var_mapping": dict(self.var_mapping), "var_counter": self.var_counter, "parameters": set(self.parameters)}
25+
)
26+
27+
def exit_scope(self):
28+
"""Exit current scope and restore parent scope"""
29+
if self.scope_stack:
30+
scope = self.scope_stack.pop()
31+
self.var_mapping = scope["var_mapping"]
32+
self.var_counter = scope["var_counter"]
33+
self.parameters = scope["parameters"]
34+
35+
def get_normalized_name(self, name: str) -> str:
36+
"""Get or create normalized name for a variable"""
37+
# Don't normalize if it's a builtin, import, global, nonlocal, or parameter
38+
if (
39+
name in self.builtins
40+
or name in self.imports
41+
or name in self.global_vars
42+
or name in self.nonlocal_vars
43+
or name in self.parameters
44+
):
45+
return name
46+
47+
# Only normalize local variables
48+
if name not in self.var_mapping:
49+
self.var_mapping[name] = f"var_{self.var_counter}"
50+
self.var_counter += 1
51+
return self.var_mapping[name]
52+
53+
def visit_Import(self, node):
54+
"""Track imported names"""
55+
for alias in node.names:
56+
name = alias.asname if alias.asname else alias.name
57+
self.imports.add(name.split(".")[0])
58+
return node
59+
60+
def visit_ImportFrom(self, node):
61+
"""Track imported names from modules"""
62+
for alias in node.names:
63+
name = alias.asname if alias.asname else alias.name
64+
self.imports.add(name)
65+
return node
66+
67+
def visit_Global(self, node):
68+
"""Track global variable declarations"""
69+
# Avoid repeated .add calls by using set.update with list
70+
self.global_vars.update(node.names)
71+
return node
72+
73+
def visit_Nonlocal(self, node):
74+
"""Track nonlocal variable declarations"""
75+
for name in node.names:
76+
self.nonlocal_vars.add(name)
77+
return node
78+
79+
def visit_FunctionDef(self, node):
80+
"""Process function but keep function name and parameters unchanged"""
81+
self.enter_scope()
82+
83+
# Track all parameters (don't modify them)
84+
for arg in node.args.args:
85+
self.parameters.add(arg.arg)
86+
if node.args.vararg:
87+
self.parameters.add(node.args.vararg.arg)
88+
if node.args.kwarg:
89+
self.parameters.add(node.args.kwarg.arg)
90+
for arg in node.args.kwonlyargs:
91+
self.parameters.add(arg.arg)
92+
93+
# Visit function body
94+
node = self.generic_visit(node)
95+
self.exit_scope()
96+
return node
97+
98+
def visit_AsyncFunctionDef(self, node):
99+
"""Handle async functions same as regular functions"""
100+
return self.visit_FunctionDef(node)
101+
102+
def visit_ClassDef(self, node):
103+
"""Process class but keep class name unchanged"""
104+
self.enter_scope()
105+
node = self.generic_visit(node)
106+
self.exit_scope()
107+
return node
108+
109+
def visit_Name(self, node):
110+
"""Normalize variable names in Name nodes"""
111+
if isinstance(node.ctx, (ast.Store, ast.Del)):
112+
# For assignments and deletions, check if we should normalize
113+
if (
114+
node.id not in self.builtins
115+
and node.id not in self.imports
116+
and node.id not in self.parameters
117+
and node.id not in self.global_vars
118+
and node.id not in self.nonlocal_vars
119+
):
120+
node.id = self.get_normalized_name(node.id)
121+
elif isinstance(node.ctx, ast.Load):
122+
# For loading, use existing mapping if available
123+
if node.id in self.var_mapping:
124+
node.id = self.var_mapping[node.id]
125+
return node
126+
127+
def visit_ExceptHandler(self, node):
128+
"""Normalize exception variable names"""
129+
if node.name:
130+
node.name = self.get_normalized_name(node.name)
131+
return self.generic_visit(node)
132+
133+
def visit_comprehension(self, node):
134+
"""Normalize comprehension target variables"""
135+
# Create new scope for comprehension
136+
old_mapping = dict(self.var_mapping)
137+
old_counter = self.var_counter
138+
139+
# Process the comprehension
140+
node = self.generic_visit(node)
141+
142+
# Restore scope
143+
self.var_mapping = old_mapping
144+
self.var_counter = old_counter
145+
return node
146+
147+
def visit_For(self, node):
148+
"""Handle for loop target variables"""
149+
# The target in a for loop is a local variable that should be normalized
150+
return self.generic_visit(node)
151+
152+
def visit_With(self, node):
153+
"""Handle with statement as variables"""
154+
return self.generic_visit(node)
155+
156+
157+
def normalize_code(code: str, remove_docstrings: bool = True, return_ast_dump: bool = False) -> str:
158+
"""Normalize Python code by parsing, cleaning, and normalizing only variable names.
159+
Function names, class names, and parameters are preserved.
160+
161+
Args:
162+
code: Python source code as string
163+
remove_docstrings: Whether to remove docstrings
164+
165+
Returns:
166+
Normalized code as string
167+
168+
"""
169+
try:
170+
# Parse the code
171+
tree = ast.parse(code)
172+
173+
# Remove docstrings if requested
174+
if remove_docstrings:
175+
remove_docstrings_from_ast(tree)
176+
177+
# Normalize variable names
178+
normalizer = VariableNormalizer()
179+
normalized_tree = normalizer.visit(tree)
180+
if return_ast_dump:
181+
# This is faster than unparsing etc
182+
return ast.dump(normalized_tree, annotate_fields=False, include_attributes=False)
183+
184+
# Fix missing locations in the AST
185+
ast.fix_missing_locations(normalized_tree)
186+
187+
# Unparse back to code
188+
return ast.unparse(normalized_tree)
189+
except SyntaxError as e:
190+
msg = f"Invalid Python syntax: {e}"
191+
raise ValueError(msg) from e
192+
193+
194+
def remove_docstrings_from_ast(node):
195+
"""Remove docstrings from AST nodes."""
196+
# Only FunctionDef, AsyncFunctionDef, ClassDef, and Module can contain docstrings in their body[0]
197+
node_types = (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef, ast.Module)
198+
# Use our own stack-based DFS instead of ast.walk for efficiency
199+
stack = [node]
200+
while stack:
201+
current_node = stack.pop()
202+
if isinstance(current_node, node_types):
203+
# Remove docstring if it's the first stmt in body
204+
body = current_node.body
205+
if (
206+
body
207+
and isinstance(body[0], ast.Expr)
208+
and isinstance(body[0].value, ast.Constant)
209+
and isinstance(body[0].value.value, str)
210+
):
211+
current_node.body = body[1:]
212+
# Only these nodes can nest more docstring-containing nodes
213+
# Add their body elements to stack, avoiding unnecessary traversal
214+
stack.extend([child for child in body if isinstance(child, node_types)])
215+
216+
217+
def get_code_fingerprint(code: str) -> str:
218+
"""Generate a fingerprint for normalized code.
219+
220+
Args:
221+
code: Python source code
222+
223+
Returns:
224+
SHA-256 hash of normalized code
225+
226+
"""
227+
normalized = normalize_code(code)
228+
return hashlib.sha256(normalized.encode()).hexdigest()
229+
230+
231+
def are_codes_duplicate(code1: str, code2: str) -> bool:
232+
"""Check if two code segments are duplicates after normalization.
233+
234+
Args:
235+
code1: First code segment
236+
code2: Second code segment
237+
238+
Returns:
239+
True if codes are structurally identical (ignoring local variable names)
240+
241+
"""
242+
try:
243+
normalized1 = normalize_code(code1, return_ast_dump=True)
244+
normalized2 = normalize_code(code2, return_ast_dump=True)
245+
return normalized1 == normalized2
246+
except Exception:
247+
return False

0 commit comments

Comments
 (0)