Skip to content

Commit 52715ba

Browse files
committed
Merge remote-tracking branch 'origin/main' into line-profiler
2 parents 4247780 + fb27a4c commit 52715ba

File tree

15 files changed: +523 additions, -188 deletions

code_to_optimize/code_directories/simple_tracer_e2e/workload.py

Lines changed: 10 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,4 @@
1+
from concurrent.futures import ThreadPoolExecutor
12
def funcA(number):
23
k = 0
34
for i in range(number * 100):
@@ -8,7 +9,14 @@ def funcA(number):
89
# Use a generator expression directly in join for more efficiency
910
return " ".join(str(i) for i in range(number))
1011

12+
def test_threadpool() -> None:
13+
pool = ThreadPoolExecutor(max_workers=3)
14+
args = list(range(10, 31, 10))
15+
result = pool.map(funcA, args)
16+
17+
for r in result:
18+
print(r)
19+
1120

1221
if __name__ == "__main__":
13-
for i in range(10, 31, 10):
14-
funcA(10)
22+
test_threadpool()

codeflash/cli_cmds/cmd_init.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -595,7 +595,7 @@ def configure_pyproject_toml(setup_info: SetupInfo) -> None:
595595
if formatter in ["black", "ruff"]:
596596
try:
597597
subprocess.run([formatter], capture_output=True, check=False)
598-
except FileNotFoundError:
598+
except (FileNotFoundError, NotADirectoryError):
599599
click.echo(f"⚠️ Formatter not found: {formatter}, please ensure it is installed")
600600
codeflash_section["formatter-cmds"] = formatter_cmds
601601
# Add the 'codeflash' section, ensuring 'tool' section exists

codeflash/code_utils/code_extractor.py

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -235,20 +235,20 @@ def extract_code(functions_to_optimize: list[FunctionToOptimize]) -> tuple[str |
235235
return edited_code, contextual_dunder_methods
236236

237237

238-
def find_preexisting_objects(source_code: str) -> list[tuple[str, list[FunctionParent]]]:
239-
"""Find all preexisting functions, classes or class methods in the source code"""
240-
preexisting_objects: list[tuple[str, list[FunctionParent]]] = []
238+
def find_preexisting_objects(source_code: str) -> set[tuple[str, tuple[FunctionParent, ...]]]:
239+
"""Find all preexisting functions, classes or class methods in the source code."""
240+
preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] = set()
241241
try:
242242
module_node: ast.Module = ast.parse(source_code)
243243
except SyntaxError:
244244
logger.exception("find_preexisting_objects - Syntax error while parsing code")
245245
return preexisting_objects
246246
for node in module_node.body:
247247
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
248-
preexisting_objects.append((node.name, []))
248+
preexisting_objects.add((node.name, ()))
249249
elif isinstance(node, ast.ClassDef):
250-
preexisting_objects.append((node.name, []))
250+
preexisting_objects.add((node.name, ()))
251251
for cnode in node.body:
252252
if isinstance(cnode, (ast.FunctionDef, ast.AsyncFunctionDef)):
253-
preexisting_objects.append((cnode.name, [FunctionParent(node.name, "ClassDef")]))
253+
preexisting_objects.add((cnode.name, (FunctionParent(node.name, "ClassDef"),)))
254254
return preexisting_objects

codeflash/code_utils/code_replacer.py

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -39,11 +39,11 @@ class OptimFunctionCollector(cst.CSTVisitor):
3939

4040
def __init__(
4141
self,
42-
preexisting_objects: list[tuple[str, list[FunctionParent]]] | None = None,
42+
preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] | None = None,
4343
function_names: set[tuple[str | None, str]] | None = None,
4444
) -> None:
4545
super().__init__()
46-
self.preexisting_objects = preexisting_objects if preexisting_objects is not None else []
46+
self.preexisting_objects = preexisting_objects if preexisting_objects is not None else set()
4747

4848
self.function_names = function_names # set of (class_name, function_name)
4949
self.modified_functions: dict[
@@ -61,7 +61,7 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> bool:
6161
self.modified_init_functions[self.current_class] = node
6262
elif (
6363
self.preexisting_objects
64-
and (node.name.value, []) not in self.preexisting_objects
64+
and (node.name.value, ()) not in self.preexisting_objects
6565
and self.current_class is None
6666
):
6767
self.new_functions.append(node)
@@ -72,7 +72,7 @@ def visit_ClassDef(self, node: cst.ClassDef) -> bool:
7272
return False # If already in a class, do not recurse deeper
7373
self.current_class = node.name.value
7474

75-
parents = [FunctionParent(name=node.name.value, type="ClassDef")]
75+
parents = (FunctionParent(name=node.name.value, type="ClassDef"),)
7676
for child_node in node.body.body:
7777
if (
7878
self.preexisting_objects
@@ -160,7 +160,7 @@ def replace_functions_in_file(
160160
source_code: str,
161161
original_function_names: list[str],
162162
optimized_code: str,
163-
preexisting_objects: list[tuple[str, list[FunctionParent]]],
163+
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]],
164164
) -> str:
165165
parsed_function_names = []
166166
for original_function_name in original_function_names:
@@ -196,7 +196,7 @@ def replace_functions_and_add_imports(
196196
function_names: list[str],
197197
optimized_code: str,
198198
module_abspath: Path,
199-
preexisting_objects: list[tuple[str, list[FunctionParent]]],
199+
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]],
200200
project_root_path: Path,
201201
) -> str:
202202
return add_needed_imports_from_module(
@@ -212,7 +212,7 @@ def replace_function_definitions_in_module(
212212
function_names: list[str],
213213
optimized_code: str,
214214
module_abspath: Path,
215-
preexisting_objects: list[tuple[str, list[FunctionParent]]],
215+
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]],
216216
project_root_path: Path,
217217
) -> bool:
218218
source_code: str = module_abspath.read_text(encoding="utf8")

codeflash/context/code_context_extractor.py

Lines changed: 13 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -30,18 +30,25 @@ def get_code_optimization_context(
3030
) -> CodeOptimizationContext:
3131
# Get FunctionSource representation of helpers of FTO
3232
helpers_of_fto_dict, helpers_of_fto_list = get_function_sources_from_jedi({function_to_optimize.file_path: {function_to_optimize.qualified_name}}, project_root_path)
33+
34+
# Add function to optimize into helpers of FTO dict, as they'll be processed together
35+
fto_as_function_source = get_function_to_optimize_as_function_source(function_to_optimize, project_root_path)
36+
helpers_of_fto_dict[function_to_optimize.file_path].add(fto_as_function_source)
37+
38+
# Format data to search for helpers of helpers using get_function_sources_from_jedi
3339
helpers_of_fto_qualified_names_dict = {
3440
file_path: {source.qualified_name for source in sources}
3541
for file_path, sources in helpers_of_fto_dict.items()
3642
}
3743

44+
# __init__ functions are automatically considered as helpers of FTO, so we add them to the dict (regardless of whether they exist)
45+
# This helps us to search for helpers of __init__ functions of classes that contain helpers of FTO
46+
for qualified_names in helpers_of_fto_qualified_names_dict.values():
47+
qualified_names.update({f"{qn.rsplit('.', 1)[0]}.__init__" for qn in qualified_names if '.' in qn})
48+
3849
# Get FunctionSource representation of helpers of helpers of FTO
3950
helpers_of_helpers_dict, helpers_of_helpers_list = get_function_sources_from_jedi(helpers_of_fto_qualified_names_dict, project_root_path)
4051

41-
# Add function to optimize into helpers of FTO dict, as they'll be processed together
42-
fto_as_function_source = get_function_to_optimize_as_function_source(function_to_optimize, project_root_path)
43-
helpers_of_fto_dict[function_to_optimize.file_path].add(fto_as_function_source)
44-
4552
# Extract code context for optimization
4653
final_read_writable_code = extract_code_string_context_from_files(helpers_of_fto_dict,{}, project_root_path, remove_docstrings=False, code_context_type=CodeContextType.READ_WRITABLE).code
4754
read_only_code_markdown = extract_code_markdown_context_from_files(
@@ -58,8 +65,8 @@ def get_code_optimization_context(
5865
if final_read_writable_tokens > optim_token_limit:
5966
raise ValueError("Read-writable code has exceeded token limit, cannot proceed")
6067

61-
# Setup preexisting objects for code replacer TODO: should remove duplicates
62-
preexisting_objects = list(
68+
# Setup preexisting objects for code replacer
69+
preexisting_objects = set(
6370
chain(
6471
find_preexisting_objects(final_read_writable_code),
6572
*(find_preexisting_objects(codestring.code) for codestring in read_only_code_markdown.code_strings),

codeflash/discovery/discover_unit_tests.py

Lines changed: 13 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -186,15 +186,19 @@ def process_test_files(
186186
jedi_project = jedi.Project(path=project_root_path)
187187

188188
for test_file, functions in file_to_test_map.items():
189-
script = jedi.Script(path=test_file, project=jedi_project)
190-
test_functions = set()
191-
192-
all_names = script.get_names(all_scopes=True, references=True)
193-
all_defs = script.get_names(all_scopes=True, definitions=True)
194-
all_names_top = script.get_names(all_scopes=True)
195-
196-
top_level_functions = {name.name: name for name in all_names_top if name.type == "function"}
197-
top_level_classes = {name.name: name for name in all_names_top if name.type == "class"}
189+
try:
190+
script = jedi.Script(path=test_file, project=jedi_project)
191+
test_functions = set()
192+
193+
all_names = script.get_names(all_scopes=True, references=True)
194+
all_defs = script.get_names(all_scopes=True, definitions=True)
195+
all_names_top = script.get_names(all_scopes=True)
196+
197+
top_level_functions = {name.name: name for name in all_names_top if name.type == "function"}
198+
top_level_classes = {name.name: name for name in all_names_top if name.type == "class"}
199+
except Exception as e:
200+
logger.debug(f"Failed to get jedi script for {test_file}: {e}")
201+
continue
198202

199203
if test_framework == "pytest":
200204
for function in functions:

codeflash/models/models.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -99,7 +99,7 @@ class CodeOptimizationContext(BaseModel):
9999
read_writable_code: str = Field(min_length=1)
100100
read_only_context_code: str = ""
101101
helper_functions: list[FunctionSource]
102-
preexisting_objects: list[tuple[str, list[FunctionParent]]]
102+
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]]
103103

104104
class CodeContextType(str, Enum):
105105
READ_WRITABLE = "READ_WRITABLE"

0 commit comments

Comments
 (0)