Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions codeflash/code_utils/coverage_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@

def extract_dependent_function(main_function: str, code_context: CodeOptimizationContext) -> str | Literal[False]:
"""Extract the single dependent function from the code context excluding the main function."""
ast_tree = ast.parse(code_context.testgen_context_code)

dependent_functions = {node.name for node in ast_tree.body if isinstance(node, ast.FunctionDef)}
dependent_functions = set()
for code_string in code_context.testgen_context.code_strings:
ast_tree = ast.parse(code_string.code)
dependent_functions.update({node.name for node in ast_tree.body if isinstance(node, ast.FunctionDef)})

if main_function in dependent_functions:
dependent_functions.discard(main_function)
Expand Down
18 changes: 9 additions & 9 deletions codeflash/context/code_context_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,32 +114,32 @@ def get_code_optimization_context(
read_only_context_code = ""

# Extract code context for testgen
testgen_code_markdown = extract_code_string_context_from_files(
testgen_context = extract_code_markdown_context_from_files(
helpers_of_fto_dict,
helpers_of_helpers_dict,
project_root_path,
remove_docstrings=False,
code_context_type=CodeContextType.TESTGEN,
)
testgen_context_code = testgen_code_markdown.code
testgen_context_code_tokens = encoded_tokens_len(testgen_context_code)
if testgen_context_code_tokens > testgen_token_limit:
testgen_code_markdown = extract_code_string_context_from_files(
testgen_markdown_code = testgen_context.markdown
testgen_code_token_length = encoded_tokens_len(testgen_markdown_code)
if testgen_code_token_length > testgen_token_limit:
testgen_context = extract_code_markdown_context_from_files(
helpers_of_fto_dict,
helpers_of_helpers_dict,
project_root_path,
remove_docstrings=True,
code_context_type=CodeContextType.TESTGEN,
)
testgen_context_code = testgen_code_markdown.code
testgen_context_code_tokens = encoded_tokens_len(testgen_context_code)
if testgen_context_code_tokens > testgen_token_limit:
testgen_markdown_code = testgen_context.markdown
testgen_code_token_length = encoded_tokens_len(testgen_markdown_code)
if testgen_code_token_length > testgen_token_limit:
raise ValueError("Testgen code context has exceeded token limit, cannot proceed")
code_hash_context = hashing_code_context.markdown
code_hash = hashlib.sha256(code_hash_context.encode("utf-8")).hexdigest()

return CodeOptimizationContext(
testgen_context_code=testgen_context_code,
testgen_context=testgen_context,
read_writable_code=final_read_writable_code,
read_only_context_code=read_only_context_code,
hashing_code_context=code_hash_context,
Expand Down
2 changes: 1 addition & 1 deletion codeflash/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ def parse_markdown_code(markdown_code: str) -> CodeStringsMarkdown:


class CodeOptimizationContext(BaseModel):
testgen_context_code: str = ""
testgen_context: CodeStringsMarkdown
read_writable_code: CodeStringsMarkdown
read_only_context_code: str = ""
hashing_code_context: str = ""
Expand Down
9 changes: 4 additions & 5 deletions codeflash/optimization/function_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ def generate_and_instrument_tests(
revert_to_print=bool(get_pr_number()),
):
generated_results = self.generate_tests_and_optimizations(
testgen_context_code=code_context.testgen_context_code,
testgen_context=code_context.testgen_context,
read_writable_code=code_context.read_writable_code,
read_only_context_code=code_context.read_only_context_code,
helper_functions=code_context.helper_functions,
Expand Down Expand Up @@ -345,7 +345,6 @@ def generate_and_instrument_tests(
logger.info(f"Generated test {i + 1}/{count_tests}:")
code_print(generated_test.generated_original_test_source, file_name=f"test_{i + 1}.py")
if concolic_test_str:
# no concolic tests in lsp mode
logger.info(f"Generated test {count_tests}/{count_tests}:")
code_print(concolic_test_str)

Expand Down Expand Up @@ -946,7 +945,7 @@ def get_code_optimization_context(self) -> Result[CodeOptimizationContext, str]:

return Success(
CodeOptimizationContext(
testgen_context_code=new_code_ctx.testgen_context_code,
testgen_context=new_code_ctx.testgen_context,
read_writable_code=new_code_ctx.read_writable_code,
read_only_context_code=new_code_ctx.read_only_context_code,
hashing_code_context=new_code_ctx.hashing_code_context,
Expand Down Expand Up @@ -1053,7 +1052,7 @@ def instrument_existing_tests(self, function_to_all_tests: dict[str, set[Functio

def generate_tests_and_optimizations(
self,
testgen_context_code: str,
testgen_context: CodeStringsMarkdown,
read_writable_code: CodeStringsMarkdown,
read_only_context_code: str,
helper_functions: list[FunctionSource],
Expand All @@ -1067,7 +1066,7 @@ def generate_tests_and_optimizations(
# Submit the test generation task as future
future_tests = self.submit_test_generation_tasks(
self.executor,
testgen_context_code,
testgen_context.markdown,
[definition.fully_qualified_name for definition in helper_functions],
generated_test_paths,
generated_perf_test_paths,
Expand Down
5 changes: 3 additions & 2 deletions tests/test_code_replacement.py
Original file line number Diff line number Diff line change
Expand Up @@ -797,7 +797,8 @@ def main_method(self):


def test_code_replacement10() -> None:
get_code_output = """from __future__ import annotations
get_code_output = """# file: test_code_replacement.py
from __future__ import annotations

class HelperClass:
def __init__(self, name):
Expand Down Expand Up @@ -827,7 +828,7 @@ def main_method(self):
)
func_optimizer = FunctionOptimizer(function_to_optimize=func_top_optimize, test_cfg=test_config)
code_context = func_optimizer.get_code_optimization_context().unwrap()
assert code_context.testgen_context_code.rstrip() == get_code_output.rstrip()
assert code_context.testgen_context.flat.rstrip() == get_code_output.rstrip()


def test_code_replacement11() -> None:
Expand Down
10 changes: 6 additions & 4 deletions tests/test_function_dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,8 +160,9 @@ def test_class_method_dependencies() -> None:
)
assert code_context.helper_functions[0].qualified_name == "Graph.topologicalSortUtil"
assert (
code_context.testgen_context_code
== """from collections import defaultdict
code_context.testgen_context.flat
== """# file: test_function_dependencies.py
from collections import defaultdict

class Graph:
def __init__(self, vertices):
Expand Down Expand Up @@ -220,8 +221,9 @@ def test_recursive_function_context() -> None:
assert code_context.helper_functions[0].fully_qualified_name == "test_function_dependencies.C.calculate_something_3"
assert code_context.helper_functions[1].fully_qualified_name == "test_function_dependencies.C.recursive"
assert (
code_context.testgen_context_code
== """class C:
code_context.testgen_context.flat
== """# file: test_function_dependencies.py
class C:
def calculate_something_3(self, num):
return num + 1

Expand Down
19 changes: 11 additions & 8 deletions tests/test_get_helper_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,8 +241,9 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R:
code_context = ctx_result.unwrap()
assert code_context.helper_functions[0].qualified_name == "AbstractCacheBackend.get_cache_or_call"
assert (
code_context.testgen_context_code
== f'''_P = ParamSpec("_P")
code_context.testgen_context.flat
== f'''# file: {file_path.relative_to(project_root_path)}
_P = ParamSpec("_P")
_KEY_T = TypeVar("_KEY_T")
_STORE_T = TypeVar("_STORE_T")
class AbstractCacheBackend(CacheBackend, Protocol[_KEY_T, _STORE_T]):
Expand Down Expand Up @@ -394,10 +395,11 @@ def test_bubble_sort_deps() -> None:
function_to_optimize = FunctionToOptimize(
function_name="sorter_deps", file_path=file_path, parents=[], starting_line=None, ending_line=None
)
project_root = file_path.parent.parent.resolve()
test_config = TestConfig(
tests_root=str(file_path.parent / "tests"),
tests_project_rootdir=file_path.parent.resolve(),
project_root_path=file_path.parent.parent.resolve(),
project_root_path=project_root,
test_framework="pytest",
pytest_cmd="pytest",
)
Expand All @@ -409,19 +411,20 @@ def test_bubble_sort_deps() -> None:
pytest.fail()
code_context = ctx_result.unwrap()
assert (
code_context.testgen_context_code
== """from code_to_optimize.bubble_sort_dep1_helper import dep1_comparer
from code_to_optimize.bubble_sort_dep2_swap import dep2_swap

code_context.testgen_context.flat
== f"""# file: code_to_optimize/bubble_sort_dep1_helper.py
def dep1_comparer(arr, j: int) -> bool:
return arr[j] > arr[j + 1]

# file: code_to_optimize/bubble_sort_dep2_swap.py
def dep2_swap(arr, j):
temp = arr[j]
arr[j] = arr[j + 1]
arr[j + 1] = temp


# file: code_to_optimize/bubble_sort_deps.py
from code_to_optimize.bubble_sort_dep1_helper import dep1_comparer
from code_to_optimize.bubble_sort_dep2_swap import dep2_swap

def sorter_deps(arr):
for i in range(len(arr)):
Expand Down
Loading