Skip to content

Commit 302e919

Browse files
committed
Modified testgen context to be a codestring instead of markdown
1 parent 37769f9 commit 302e919

File tree

4 files changed

+100
-53
lines changed

4 files changed

+100
-53
lines changed

codeflash/context/code_context_extractor.py

Lines changed: 84 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def get_code_optimization_context(
3939
)
4040

4141
# Extract code context for optimization
42-
final_read_writable_code = extract_code_string_context_from_files(helpers_of_fto, helpers_of_fto_fqn, project_root_path).code
42+
final_read_writable_code = extract_code_string_context_from_files(helpers_of_fto, helpers_of_fto_fqn, {}, {}, project_root_path, remove_docstrings=False, code_context_type=CodeContextType.READ_WRITABLE).code
4343
read_only_code_markdown = extract_code_markdown_context_from_files(
4444
helpers_of_fto,
4545
helpers_of_fto_fqn,
@@ -85,7 +85,7 @@ def get_code_optimization_context(
8585
logger.debug("Code context has exceeded token limit, removing read-only code")
8686
read_only_context_code = ""
8787
# Extract code context for testgen
88-
testgen_code_markdown = extract_code_markdown_context_from_files(
88+
testgen_code_markdown = extract_code_string_context_from_files(
8989
helpers_of_fto,
9090
helpers_of_fto_fqn,
9191
helpers_of_helpers,
@@ -94,10 +94,10 @@ def get_code_optimization_context(
9494
remove_docstrings=False,
9595
code_context_type=CodeContextType.TESTGEN,
9696
)
97-
testgen_context_code = testgen_code_markdown.markdown
97+
testgen_context_code = testgen_code_markdown.code
9898
testgen_context_code_tokens = len(tokenizer.encode(testgen_context_code))
9999
if testgen_context_code_tokens > testgen_token_limit:
100-
testgen_code_markdown = extract_code_markdown_context_from_files(
100+
testgen_code_markdown = extract_code_string_context_from_files(
101101
helpers_of_fto,
102102
helpers_of_fto_fqn,
103103
helpers_of_helpers,
@@ -106,64 +106,117 @@ def get_code_optimization_context(
106106
remove_docstrings=True,
107107
code_context_type=CodeContextType.TESTGEN,
108108
)
109-
testgen_context_code = testgen_code_markdown.markdown
109+
testgen_context_code = testgen_code_markdown.code
110110
testgen_context_code_tokens = len(tokenizer.encode(testgen_context_code))
111111
if testgen_context_code_tokens > testgen_token_limit:
112112
raise ValueError("Testgen code context has exceeded token limit, cannot proceed")
113113

114114
return CodeOptimizationContext(
115115
testgen_context_code = testgen_context_code,
116-
read_writable_code=CodeString(code=final_read_writable_code).code,
116+
read_writable_code=final_read_writable_code,
117117
read_only_context_code=read_only_context_code,
118118
helper_functions=helpers_of_fto_obj_list,
119119
preexisting_objects=preexisting_objects,
120120
)
121121

122-
123122
def extract_code_string_context_from_files(
124-
helpers_of_fto: dict[Path, set[str]], helpers_of_fto_fqn: dict[Path, set[str]], project_root_path: Path
123+
helpers_of_fto: dict[Path, set[str]],
124+
helpers_of_fto_fqn: dict[Path, set[str]],
125+
helpers_of_helpers: dict[Path, set[str]],
126+
helpers_of_helpers_fqn: dict[Path, set[str]],
127+
project_root_path: Path,
128+
remove_docstrings: bool = False,
129+
code_context_type: CodeContextType = CodeContextType.READ_ONLY,
125130
) -> CodeString:
126-
"""Extract read-writable code context from files containing target functions and their helpers.
131+
"""Extract code context from files containing target functions and their helpers, combining it into a single code string.
127132
128-
This function iterates through each file path that contains functions to optimize (fto) or
129-
their first-degree helpers, reads the original code, extracts relevant parts using CST parsing,
130-
and adds necessary imports from the original modules.
133+
This function processes two sets of files:
134+
1. Files containing the function to optimize (fto) and their first-degree helpers
135+
2. Files containing only helpers of helpers (with no overlap with the first set)
136+
137+
For each file, it extracts relevant code based on the specified context type, adds necessary
138+
imports, and combines them into a single code string.
131139
132140
Args:
133-
helpers_of_fto: Dictionary mapping file paths to sets of qualified function names
134-
helpers_of_fto_fqn: Dictionary mapping file paths to sets of fully qualified names of functions
135-
project_root_path: Root path of the project for resolving relative imports
141+
helpers_of_fto: Dictionary mapping file paths to sets of qualified names of the functions to optimize and their first-degree helpers
142+
helpers_of_fto_fqn: Dictionary mapping file paths to sets of fully qualified names of the functions to optimize and their first-degree helpers
143+
helpers_of_helpers: Dictionary mapping file paths to sets of qualified names of helpers of helpers
144+
helpers_of_helpers_fqn: Dictionary mapping file paths to sets of fully qualified names of helpers of helpers
145+
project_root_path: Root path of the project
146+
remove_docstrings: Whether to remove docstrings from the extracted code
147+
code_context_type: Type of code context to extract (READ_ONLY, READ_WRITABLE, or TESTGEN)
136148
137149
Returns:
138-
CodeString object containing the consolidated read-writable code with all necessary
139-
imports for the target functions and their helpers
150+
CodeString containing the extracted code context with necessary imports
140151
141152
"""
142-
final_read_writable_code = ""
143-
# Extract code from file paths that contain fto and first degree helpers
153+
# Rearrange to remove overlaps, so we only access each file path once
154+
helpers_of_helpers_no_overlap = defaultdict(set)
155+
helpers_of_helpers_no_overlap_fqn = defaultdict(set)
156+
for file_path in helpers_of_helpers:
157+
if file_path in helpers_of_fto:
158+
# Remove duplicates, in case a helper of helper is also a helper of fto
159+
helpers_of_helpers[file_path] -= helpers_of_fto[file_path]
160+
helpers_of_helpers_fqn[file_path] -= helpers_of_fto_fqn[file_path]
161+
else:
162+
helpers_of_helpers_no_overlap[file_path] = helpers_of_helpers[file_path]
163+
helpers_of_helpers_no_overlap_fqn[file_path] = helpers_of_helpers_fqn[file_path]
164+
165+
final_code_string_context = ""
166+
# Extract code from file paths that contain fto and first-degree helpers. Helpers of helpers may also be included if they are in the same files.
144167
for file_path, qualified_function_names in helpers_of_fto.items():
145168
try:
146169
original_code = file_path.read_text("utf8")
147170
except Exception as e:
148171
logger.exception(f"Error while parsing {file_path}: {e}")
149172
continue
150173
try:
151-
read_writable_code = parse_code_and_prune_cst(original_code, CodeContextType.READ_WRITABLE, qualified_function_names)
174+
code_context = parse_code_and_prune_cst(
175+
original_code, code_context_type, qualified_function_names, helpers_of_helpers.get(file_path, set()), remove_docstrings
176+
)
177+
152178
except ValueError as e:
153-
logger.debug(f"Error while getting read-writable code: {e}")
179+
logger.debug(f"Error while getting read-only code: {e}")
180+
continue
181+
if code_context.strip():
182+
final_code_string_context += f"\n{code_context}"
183+
final_code_string_context = add_needed_imports_from_module(
184+
src_module_code=original_code,
185+
dst_module_code=final_code_string_context,
186+
src_path=file_path,
187+
dst_path=file_path,
188+
project_root=project_root_path,
189+
helper_functions_fqn=helpers_of_fto_fqn.get(file_path, set()) | helpers_of_helpers_fqn.get(file_path, set()),
190+
)
191+
if code_context_type == CodeContextType.READ_WRITABLE:
192+
return CodeString(code=final_code_string_context)
193+
# Extract code from file paths containing helpers of helpers
194+
for file_path, qualified_helper_function_names in helpers_of_helpers_no_overlap.items():
195+
try:
196+
original_code = file_path.read_text("utf8")
197+
except Exception as e:
198+
logger.exception(f"Error while parsing {file_path}: {e}")
199+
continue
200+
try:
201+
code_context = parse_code_and_prune_cst(
202+
original_code, code_context_type, set(), qualified_helper_function_names, remove_docstrings
203+
)
204+
except ValueError as e:
205+
logger.debug(f"Error while getting read-only code: {e}")
154206
continue
155207

156-
if read_writable_code:
157-
final_read_writable_code += f"\n{read_writable_code}"
158-
final_read_writable_code = add_needed_imports_from_module(
159-
src_module_code=original_code,
160-
dst_module_code=final_read_writable_code,
161-
src_path=file_path,
162-
dst_path=file_path,
163-
project_root=project_root_path,
164-
helper_functions_fqn=helpers_of_fto_fqn[file_path],
208+
if code_context.strip():
209+
final_code_string_context += f"\n{code_context}"
210+
final_code_string_context = add_needed_imports_from_module(
211+
src_module_code=original_code,
212+
dst_module_code=final_code_string_context,
213+
src_path=file_path,
214+
dst_path=file_path,
215+
project_root=project_root_path,
216+
helper_functions_fqn=helpers_of_helpers_no_overlap_fqn.get(file_path, set()),
165217
)
166-
return CodeString(code=final_read_writable_code)
218+
return CodeString(code=final_code_string_context)
219+
167220
def extract_code_markdown_context_from_files(
168221
helpers_of_fto: dict[Path, set[str]],
169222
helpers_of_fto_fqn: dict[Path, set[str]],

tests/test_code_replacement.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -747,8 +747,7 @@ def main_method(self):
747747

748748

749749
def test_code_replacement10() -> None:
750-
get_code_output = """```python:test_code_replacement.py
751-
from __future__ import annotations
750+
get_code_output = """from __future__ import annotations
752751
import os
753752
754753
os.environ["CODEFLASH_API_KEY"] = "cf-test-key"
@@ -768,7 +767,7 @@ def __init__(self, name):
768767
769768
def main_method(self):
770769
return HelperClass(self.name).helper_method()
771-
```"""
770+
"""
772771
file_path = Path(__file__).resolve()
773772
func_top_optimize = FunctionToOptimize(
774773
function_name="main_method", file_path=file_path, parents=[FunctionParent("MainClass", "ClassDef")]

tests/test_function_dependencies.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -162,8 +162,7 @@ def test_class_method_dependencies() -> None:
162162
assert code_context.helper_functions[0].qualified_name == "Graph.topologicalSortUtil"
163163
assert (
164164
code_context.testgen_context_code
165-
== """```python:test_function_dependencies.py
166-
from collections import defaultdict
165+
== """from collections import defaultdict
167166
168167
class Graph:
169168
def __init__(self, vertices):
@@ -188,8 +187,7 @@ def topologicalSort(self):
188187
self.topologicalSortUtil(i, visited, stack)
189188
190189
# Print contents of stack
191-
return stack
192-
```"""
190+
return stack"""
193191
)
194192

195193
def test_recursive_function_context() -> None:
@@ -224,15 +222,13 @@ def test_recursive_function_context() -> None:
224222
assert code_context.helper_functions[1].fully_qualified_name == "test_function_dependencies.C.recursive"
225223
assert (
226224
code_context.testgen_context_code
227-
== """```python:test_function_dependencies.py
228-
class C:
225+
== """class C:
229226
def calculate_something_3(self, num):
230227
return num + 1
231228
232229
def recursive(self, num):
233230
if num == 0:
234231
return 0
235232
num_1 = self.calculate_something_3(num)
236-
return self.recursive(num) + num_1
237-
```"""
233+
return self.recursive(num) + num_1"""
238234
)

tests/test_get_helper_code.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -242,8 +242,7 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R:
242242
assert code_context.helper_functions[0].qualified_name == "AbstractCacheBackend.get_cache_or_call"
243243
assert (
244244
code_context.testgen_context_code
245-
== f'''```python:{file_path.relative_to(project_root_path)}
246-
_P = ParamSpec("_P")
245+
== f'''_P = ParamSpec("_P")
247246
_KEY_T = TypeVar("_KEY_T")
248247
_STORE_T = TypeVar("_STORE_T")
249248
class AbstractCacheBackend(CacheBackend, Protocol[_KEY_T, _STORE_T]):
@@ -385,7 +384,7 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R:
385384
kwargs=kwargs,
386385
lifespan=self.__duration__,
387386
)
388-
```'''
387+
'''
389388
)
390389

391390

@@ -411,27 +410,27 @@ def test_bubble_sort_deps() -> None:
411410
code_context = ctx_result.unwrap()
412411
assert (
413412
code_context.testgen_context_code
414-
== """```python:code_to_optimize/bubble_sort_dep1_helper.py
413+
== """from code_to_optimize.bubble_sort_dep1_helper import dep1_comparer
414+
from code_to_optimize.bubble_sort_dep2_swap import dep2_swap
415+
415416
def dep1_comparer(arr, j: int) -> bool:
416417
return arr[j] > arr[j + 1]
417-
```
418-
```python:code_to_optimize/bubble_sort_dep2_swap.py
418+
419419
def dep2_swap(arr, j):
420420
temp = arr[j]
421421
arr[j] = arr[j + 1]
422422
arr[j + 1] = temp
423-
```
424-
```python:code_to_optimize/bubble_sort_deps.py
425-
from code_to_optimize.bubble_sort_dep1_helper import dep1_comparer
426-
from code_to_optimize.bubble_sort_dep2_swap import dep2_swap
423+
424+
427425
428426
def sorter_deps(arr):
429427
for i in range(len(arr)):
430428
for j in range(len(arr) - 1):
431429
if dep1_comparer(arr, j):
432430
dep2_swap(arr, j)
433431
return arr
434-
```"""
432+
433+
"""
435434
)
436435
assert len(code_context.helper_functions) == 2
437436
assert (

0 commit comments

Comments
 (0)