Skip to content

Commit 5981d75

Browse files
use markdown context for the testgen
1 parent 32e85ee commit 5981d75

File tree

7 files changed

+23
-23
lines changed

7 files changed

+23
-23
lines changed

codeflash/code_utils/coverage_utils.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,10 @@
1212

1313
def extract_dependent_function(main_function: str, code_context: CodeOptimizationContext) -> str | Literal[False]:
1414
"""Extract the single dependent function from the code context excluding the main function."""
15-
ast_tree = ast.parse(code_context.testgen_context_code)
16-
17-
dependent_functions = {node.name for node in ast_tree.body if isinstance(node, ast.FunctionDef)}
15+
dependent_functions = set()
16+
for code_string in code_context.testgen_context.code_strings:
17+
ast_tree = ast.parse(code_string.code)
18+
dependent_functions.update({node.name for node in ast_tree.body if isinstance(node, ast.FunctionDef)})
1819

1920
if main_function in dependent_functions:
2021
dependent_functions.discard(main_function)

codeflash/context/code_context_extractor.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -114,32 +114,32 @@ def get_code_optimization_context(
114114
read_only_context_code = ""
115115

116116
# Extract code context for testgen
117-
testgen_code_markdown = extract_code_string_context_from_files(
117+
testgen_context = extract_code_markdown_context_from_files(
118118
helpers_of_fto_dict,
119119
helpers_of_helpers_dict,
120120
project_root_path,
121121
remove_docstrings=False,
122122
code_context_type=CodeContextType.TESTGEN,
123123
)
124-
testgen_context_code = testgen_code_markdown.code
125-
testgen_context_code_tokens = encoded_tokens_len(testgen_context_code)
126-
if testgen_context_code_tokens > testgen_token_limit:
127-
testgen_code_markdown = extract_code_string_context_from_files(
124+
testgen_markdown_code = testgen_context.markdown
125+
testgen_code_token_length = encoded_tokens_len(testgen_markdown_code)
126+
if testgen_code_token_length > testgen_token_limit:
127+
testgen_context = extract_code_markdown_context_from_files(
128128
helpers_of_fto_dict,
129129
helpers_of_helpers_dict,
130130
project_root_path,
131131
remove_docstrings=True,
132132
code_context_type=CodeContextType.TESTGEN,
133133
)
134-
testgen_context_code = testgen_code_markdown.code
135-
testgen_context_code_tokens = encoded_tokens_len(testgen_context_code)
136-
if testgen_context_code_tokens > testgen_token_limit:
134+
testgen_markdown_code = testgen_context.markdown
135+
testgen_code_token_length = encoded_tokens_len(testgen_markdown_code)
136+
if testgen_code_token_length > testgen_token_limit:
137137
raise ValueError("Testgen code context has exceeded token limit, cannot proceed")
138138
code_hash_context = hashing_code_context.markdown
139139
code_hash = hashlib.sha256(code_hash_context.encode("utf-8")).hexdigest()
140140

141141
return CodeOptimizationContext(
142-
testgen_context_code=testgen_context_code,
142+
testgen_context=testgen_context,
143143
read_writable_code=final_read_writable_code,
144144
read_only_context_code=read_only_context_code,
145145
hashing_code_context=code_hash_context,

codeflash/models/models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,7 @@ def parse_markdown_code(markdown_code: str) -> CodeStringsMarkdown:
253253

254254

255255
class CodeOptimizationContext(BaseModel):
256-
testgen_context_code: str = ""
256+
testgen_context: CodeStringsMarkdown
257257
read_writable_code: CodeStringsMarkdown
258258
read_only_context_code: str = ""
259259
hashing_code_context: str = ""

codeflash/optimization/function_optimizer.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,7 @@ def generate_and_instrument_tests(
309309
revert_to_print=bool(get_pr_number()),
310310
):
311311
generated_results = self.generate_tests_and_optimizations(
312-
testgen_context_code=code_context.testgen_context_code,
312+
testgen_context=code_context.testgen_context,
313313
read_writable_code=code_context.read_writable_code,
314314
read_only_context_code=code_context.read_only_context_code,
315315
helper_functions=code_context.helper_functions,
@@ -345,7 +345,6 @@ def generate_and_instrument_tests(
345345
logger.info(f"Generated test {i + 1}/{count_tests}:")
346346
code_print(generated_test.generated_original_test_source, file_name=f"test_{i + 1}.py")
347347
if concolic_test_str:
348-
# no concolic tests in lsp mode
349348
logger.info(f"Generated test {count_tests}/{count_tests}:")
350349
code_print(concolic_test_str)
351350

@@ -946,7 +945,7 @@ def get_code_optimization_context(self) -> Result[CodeOptimizationContext, str]:
946945

947946
return Success(
948947
CodeOptimizationContext(
949-
testgen_context_code=new_code_ctx.testgen_context_code,
948+
testgen_context=new_code_ctx.testgen_context,
950949
read_writable_code=new_code_ctx.read_writable_code,
951950
read_only_context_code=new_code_ctx.read_only_context_code,
952951
hashing_code_context=new_code_ctx.hashing_code_context,
@@ -1053,7 +1052,7 @@ def instrument_existing_tests(self, function_to_all_tests: dict[str, set[Functio
10531052

10541053
def generate_tests_and_optimizations(
10551054
self,
1056-
testgen_context_code: str,
1055+
testgen_context: CodeStringsMarkdown,
10571056
read_writable_code: CodeStringsMarkdown,
10581057
read_only_context_code: str,
10591058
helper_functions: list[FunctionSource],
@@ -1067,7 +1066,7 @@ def generate_tests_and_optimizations(
10671066
# Submit the test generation task as future
10681067
future_tests = self.submit_test_generation_tasks(
10691068
self.executor,
1070-
testgen_context_code,
1069+
testgen_context.markdown,
10711070
[definition.fully_qualified_name for definition in helper_functions],
10721071
generated_test_paths,
10731072
generated_perf_test_paths,

tests/test_code_replacement.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -827,7 +827,7 @@ def main_method(self):
827827
)
828828
func_optimizer = FunctionOptimizer(function_to_optimize=func_top_optimize, test_cfg=test_config)
829829
code_context = func_optimizer.get_code_optimization_context().unwrap()
830-
assert code_context.testgen_context_code.rstrip() == get_code_output.rstrip()
830+
assert code_context.testgen_context.rstrip() == get_code_output.rstrip()
831831

832832

833833
def test_code_replacement11() -> None:

tests/test_function_dependencies.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ def test_class_method_dependencies() -> None:
160160
)
161161
assert code_context.helper_functions[0].qualified_name == "Graph.topologicalSortUtil"
162162
assert (
163-
code_context.testgen_context_code
163+
code_context.testgen_context
164164
== """from collections import defaultdict
165165
166166
class Graph:
@@ -220,7 +220,7 @@ def test_recursive_function_context() -> None:
220220
assert code_context.helper_functions[0].fully_qualified_name == "test_function_dependencies.C.calculate_something_3"
221221
assert code_context.helper_functions[1].fully_qualified_name == "test_function_dependencies.C.recursive"
222222
assert (
223-
code_context.testgen_context_code
223+
code_context.testgen_context
224224
== """class C:
225225
def calculate_something_3(self, num):
226226
return num + 1

tests/test_get_helper_code.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R:
241241
code_context = ctx_result.unwrap()
242242
assert code_context.helper_functions[0].qualified_name == "AbstractCacheBackend.get_cache_or_call"
243243
assert (
244-
code_context.testgen_context_code
244+
code_context.testgen_context
245245
== f'''_P = ParamSpec("_P")
246246
_KEY_T = TypeVar("_KEY_T")
247247
_STORE_T = TypeVar("_STORE_T")
@@ -409,7 +409,7 @@ def test_bubble_sort_deps() -> None:
409409
pytest.fail()
410410
code_context = ctx_result.unwrap()
411411
assert (
412-
code_context.testgen_context_code
412+
code_context.testgen_context
413413
== """from code_to_optimize.bubble_sort_dep1_helper import dep1_comparer
414414
from code_to_optimize.bubble_sort_dep2_swap import dep2_swap
415415

0 commit comments

Comments
 (0)