Skip to content

Commit 5c0a028

Browse files
committed
Add a first version of hashing code context
1 parent 5989b26 commit 5c0a028

File tree

3 files changed

+145
-52
lines changed

3 files changed

+145
-52
lines changed

codeflash/context/code_context_extractor.py

Lines changed: 118 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,13 @@ def get_code_optimization_context(
7373
remove_docstrings=False,
7474
code_context_type=CodeContextType.READ_ONLY,
7575
)
76+
hashing_code_context = extract_code_markdown_context_from_files(
77+
helpers_of_fto_dict,
78+
helpers_of_helpers_dict,
79+
project_root_path,
80+
remove_docstrings=True,
81+
code_context_type=CodeContextType.HASHING,
82+
)
7683

7784
# Handle token limits
7885
final_read_writable_tokens = encoded_tokens_len(final_read_writable_code)
@@ -130,6 +137,7 @@ def get_code_optimization_context(
130137
testgen_context_code=testgen_context_code,
131138
read_writable_code=final_read_writable_code,
132139
read_only_context_code=read_only_context_code,
140+
hashing_code_context=hashing_code_context.markdown,
133141
helper_functions=helpers_of_fto_list,
134142
preexisting_objects=preexisting_objects,
135143
)
@@ -309,20 +317,21 @@ def extract_code_markdown_context_from_files(
309317
logger.debug(f"Error while getting read-only code: {e}")
310318
continue
311319
if code_context.strip():
312-
code_context_with_imports = CodeString(
313-
code=add_needed_imports_from_module(
314-
src_module_code=original_code,
315-
dst_module_code=code_context,
316-
src_path=file_path,
317-
dst_path=file_path,
318-
project_root=project_root_path,
319-
helper_functions=list(
320-
helpers_of_fto.get(file_path, set()) | helpers_of_helpers.get(file_path, set())
320+
if code_context_type != CodeContextType.HASHING:
321+
code_context = (
322+
add_needed_imports_from_module(
323+
src_module_code=original_code,
324+
dst_module_code=code_context,
325+
src_path=file_path,
326+
dst_path=file_path,
327+
project_root=project_root_path,
328+
helper_functions=list(
329+
helpers_of_fto.get(file_path, set()) | helpers_of_helpers.get(file_path, set())
330+
),
321331
),
322-
),
323-
file_path=file_path.relative_to(project_root_path),
324-
)
325-
code_context_markdown.code_strings.append(code_context_with_imports)
332+
)
333+
code_string_context = CodeString(code=code_context, file_path=file_path.relative_to(project_root_path))
334+
code_context_markdown.code_strings.append(code_string_context)
326335
# Extract code from file paths containing helpers of helpers
327336
for file_path, helper_function_sources in helpers_of_helpers_no_overlap.items():
328337
try:
@@ -343,18 +352,19 @@ def extract_code_markdown_context_from_files(
343352
continue
344353

345354
if code_context.strip():
346-
code_context_with_imports = CodeString(
347-
code=add_needed_imports_from_module(
348-
src_module_code=original_code,
349-
dst_module_code=code_context,
350-
src_path=file_path,
351-
dst_path=file_path,
352-
project_root=project_root_path,
353-
helper_functions=list(helpers_of_helpers_no_overlap.get(file_path, set())),
354-
),
355-
file_path=file_path.relative_to(project_root_path),
356-
)
357-
code_context_markdown.code_strings.append(code_context_with_imports)
355+
if code_context_type != CodeContextType.HASHING:
356+
code_context = (
357+
add_needed_imports_from_module(
358+
src_module_code=original_code,
359+
dst_module_code=code_context,
360+
src_path=file_path,
361+
dst_path=file_path,
362+
project_root=project_root_path,
363+
helper_functions=list(helpers_of_helpers_no_overlap.get(file_path, set())),
364+
),
365+
)
366+
code_string_context = CodeString(code=code_context, file_path=file_path.relative_to(project_root_path))
367+
code_context_markdown.code_strings.append(code_string_context)
358368
return code_context_markdown
359369

360370

@@ -492,6 +502,8 @@ def parse_code_and_prune_cst(
492502
filtered_node, found_target = prune_cst_for_testgen_code(
493503
module, target_functions, helpers_of_helper_functions, remove_docstrings=remove_docstrings
494504
)
505+
elif code_context_type == CodeContextType.HASHING:
506+
filtered_node, found_target = prune_cst_for_code_hashing(module, target_functions)
495507
else:
496508
raise ValueError(f"Unknown code_context_type: {code_context_type}") # noqa: EM102
497509

@@ -583,6 +595,87 @@ def prune_cst_for_read_writable_code( # noqa: PLR0911
583595
return (node.with_changes(**updates) if updates else node), True
584596

585597

598+
def prune_cst_for_code_hashing( # noqa: PLR0911
599+
node: cst.CSTNode, target_functions: set[str], prefix: str = ""
600+
) -> tuple[cst.CSTNode | None, bool]:
601+
"""Recursively filter the node and its children to build the read-writable codeblock. This contains nodes that lead to target functions.
602+
603+
Returns
604+
-------
605+
(filtered_node, found_target):
606+
filtered_node: The modified CST node or None if it should be removed.
607+
found_target: True if a target function was found in this node's subtree.
608+
609+
"""
610+
if isinstance(node, (cst.Import, cst.ImportFrom)):
611+
return None, False
612+
613+
if isinstance(node, cst.FunctionDef):
614+
qualified_name = f"{prefix}.{node.name.value}" if prefix else node.name.value
615+
if qualified_name in target_functions:
616+
new_body = remove_docstring_from_body(node.body)
617+
return node.with_changes(body=new_body), True
618+
return None, False
619+
620+
if isinstance(node, cst.ClassDef):
621+
# Do not recurse into nested classes
622+
if prefix:
623+
return None, False
624+
# Assuming always an IndentedBlock
625+
if not isinstance(node.body, cst.IndentedBlock):
626+
raise ValueError("ClassDef body is not an IndentedBlock") # noqa: TRY004
627+
class_prefix = f"{prefix}.{node.name.value}" if prefix else node.name.value
628+
new_body = []
629+
found_target = False
630+
631+
for stmt in node.body.body:
632+
if isinstance(stmt, cst.FunctionDef):
633+
qualified_name = f"{class_prefix}.{stmt.name.value}"
634+
if qualified_name in target_functions:
635+
new_body.append(stmt)
636+
found_target = True
637+
# If no target functions found, remove the class entirely
638+
if not new_body or not found_target:
639+
return None, False
640+
return node.with_changes(
641+
body=remove_docstring_from_body(node.body.with_changes(body=new_body))
642+
) if new_body else None, True
643+
644+
# For other nodes, we preserve them only if they contain target functions in their children.
645+
section_names = get_section_names(node)
646+
if not section_names:
647+
return node, False
648+
649+
updates: dict[str, list[cst.CSTNode] | cst.CSTNode] = {}
650+
found_any_target = False
651+
652+
for section in section_names:
653+
original_content = getattr(node, section, None)
654+
if isinstance(original_content, (list, tuple)):
655+
new_children = []
656+
section_found_target = False
657+
for child in original_content:
658+
filtered, found_target = prune_cst_for_code_hashing(child, target_functions, prefix)
659+
if filtered:
660+
new_children.append(filtered)
661+
section_found_target |= found_target
662+
663+
if section_found_target:
664+
found_any_target = True
665+
updates[section] = new_children
666+
elif original_content is not None:
667+
filtered, found_target = prune_cst_for_code_hashing(original_content, target_functions, prefix)
668+
if found_target:
669+
found_any_target = True
670+
if filtered:
671+
updates[section] = filtered
672+
673+
if not found_any_target:
674+
return None, False
675+
676+
return (node.with_changes(**updates) if updates else node), True
677+
678+
586679
def prune_cst_for_read_only_code( # noqa: PLR0911
587680
node: cst.CSTNode,
588681
target_functions: set[str],

codeflash/models/models.py

Lines changed: 24 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from enum import Enum, IntEnum
1717
from pathlib import Path
1818
from re import Pattern
19-
from typing import Annotated, Optional, cast
19+
from typing import Annotated, cast
2020

2121
from jedi.api.classes import Name
2222
from pydantic import AfterValidator, BaseModel, ConfigDict, Field
@@ -77,10 +77,10 @@ class BestOptimization(BaseModel):
7777
candidate: OptimizedCandidate
7878
helper_functions: list[FunctionSource]
7979
runtime: int
80-
replay_performance_gain: Optional[dict[BenchmarkKey, float]] = None
80+
replay_performance_gain: dict[BenchmarkKey, float] | None = None
8181
winning_behavioral_test_results: TestResults
8282
winning_benchmarking_test_results: TestResults
83-
winning_replay_benchmarking_test_results: Optional[TestResults] = None
83+
winning_replay_benchmarking_test_results: TestResults | None = None
8484

8585

8686
@dataclass(frozen=True)
@@ -136,7 +136,7 @@ def to_dict(self) -> dict[str, list[dict[str, any]]]:
136136

137137
class CodeString(BaseModel):
138138
code: Annotated[str, AfterValidator(validate_python_code)]
139-
file_path: Optional[Path] = None
139+
file_path: Path | None = None
140140

141141

142142
class CodeStringsMarkdown(BaseModel):
@@ -157,6 +157,7 @@ class CodeOptimizationContext(BaseModel):
157157
testgen_context_code: str = ""
158158
read_writable_code: str = Field(min_length=1)
159159
read_only_context_code: str = ""
160+
hashing_code_context: str = ""
160161
helper_functions: list[FunctionSource]
161162
preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]]
162163

@@ -165,14 +166,15 @@ class CodeContextType(str, Enum):
165166
READ_WRITABLE = "READ_WRITABLE"
166167
READ_ONLY = "READ_ONLY"
167168
TESTGEN = "TESTGEN"
169+
HASHING = "HASHING"
168170

169171

170172
class OptimizedCandidateResult(BaseModel):
171173
max_loop_count: int
172174
best_test_runtime: int
173175
behavior_test_results: TestResults
174176
benchmarking_test_results: TestResults
175-
replay_benchmarking_test_results: Optional[dict[BenchmarkKey, TestResults]] = None
177+
replay_benchmarking_test_results: dict[BenchmarkKey, TestResults] | None = None
176178
optimization_candidate_index: int
177179
total_candidate_timing: int
178180

@@ -192,10 +194,10 @@ class GeneratedTestsList(BaseModel):
192194
class TestFile(BaseModel):
193195
instrumented_behavior_file_path: Path
194196
benchmarking_file_path: Path = None
195-
original_file_path: Optional[Path] = None
196-
original_source: Optional[str] = None
197+
original_file_path: Path | None = None
198+
original_source: str | None = None
197199
test_type: TestType
198-
tests_in_file: Optional[list[TestsInFile]] = None
200+
tests_in_file: list[TestsInFile] | None = None
199201

200202

201203
class TestFiles(BaseModel):
@@ -238,13 +240,13 @@ def __len__(self) -> int:
238240

239241
class OptimizationSet(BaseModel):
240242
control: list[OptimizedCandidate]
241-
experiment: Optional[list[OptimizedCandidate]]
243+
experiment: list[OptimizedCandidate] | None
242244

243245

244246
@dataclass(frozen=True)
245247
class TestsInFile:
246248
test_file: Path
247-
test_class: Optional[str]
249+
test_class: str | None
248250
test_function: str
249251
test_type: TestType
250252

@@ -277,10 +279,10 @@ class FunctionParent:
277279
class OriginalCodeBaseline(BaseModel):
278280
behavioral_test_results: TestResults
279281
benchmarking_test_results: TestResults
280-
replay_benchmarking_test_results: Optional[dict[BenchmarkKey, TestResults]] = None
282+
replay_benchmarking_test_results: dict[BenchmarkKey, TestResults] | None = None
281283
line_profile_results: dict
282284
runtime: int
283-
coverage_results: Optional[CoverageData]
285+
coverage_results: CoverageData | None
284286

285287

286288
class CoverageStatus(Enum):
@@ -299,7 +301,7 @@ class CoverageData:
299301
graph: dict[str, dict[str, Collection[object]]]
300302
code_context: CodeOptimizationContext
301303
main_func_coverage: FunctionCoverage
302-
dependent_func_coverage: Optional[FunctionCoverage]
304+
dependent_func_coverage: FunctionCoverage | None
303305
status: CoverageStatus
304306
blank_re: Pattern[str] = re.compile(r"\s*(#|$)")
305307
else_re: Pattern[str] = re.compile(r"\s*else\s*:\s*(#|$)")
@@ -407,10 +409,10 @@ def to_name(self) -> str:
407409
@dataclass(frozen=True)
408410
class InvocationId:
409411
test_module_path: str # The fully qualified name of the test module
410-
test_class_name: Optional[str] # The name of the class where the test is defined
411-
test_function_name: Optional[str] # The name of the test_function. Does not include the components of the file_name
412+
test_class_name: str | None # The name of the class where the test is defined
413+
test_function_name: str | None # The name of the test_function. Does not include the components of the file_name
412414
function_getting_tested: str
413-
iteration_id: Optional[str]
415+
iteration_id: str | None
414416

415417
# test_module_path:TestSuiteClass.test_function_name:function_tested:iteration_id
416418
def id(self) -> str:
@@ -421,7 +423,7 @@ def id(self) -> str:
421423
)
422424

423425
@staticmethod
424-
def from_str_id(string_id: str, iteration_id: Optional[str] = None) -> InvocationId:
426+
def from_str_id(string_id: str, iteration_id: str | None = None) -> InvocationId:
425427
components = string_id.split(":")
426428
assert len(components) == 4
427429
second_components = components[1].split(".")
@@ -446,13 +448,13 @@ class FunctionTestInvocation:
446448
id: InvocationId # The fully qualified name of the function invocation (id)
447449
file_name: Path # The file where the test is defined
448450
did_pass: bool # Whether the test this function invocation was part of, passed or failed
449-
runtime: Optional[int] # Time in nanoseconds
451+
runtime: int | None # Time in nanoseconds
450452
test_framework: str # unittest or pytest
451453
test_type: TestType
452-
return_value: Optional[object] # The return value of the function invocation
453-
timed_out: Optional[bool]
454-
verification_type: Optional[str] = VerificationType.FUNCTION_CALL
455-
stdout: Optional[str] = None
454+
return_value: object | None # The return value of the function invocation
455+
timed_out: bool | None
456+
verification_type: str | None = VerificationType.FUNCTION_CALL
457+
stdout: str | None = None
456458

457459
@property
458460
def unique_invocation_loop_id(self) -> str:

codeflash/optimization/function_optimizer.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from __future__ import annotations
22

33
import ast
4-
import git
54
import concurrent.futures
65
import os
76
import subprocess
@@ -52,8 +51,6 @@
5251
from codeflash.code_utils.line_profile_utils import add_decorator_imports
5352
from codeflash.code_utils.static_analysis import get_first_top_level_function_or_method_ast
5453
from codeflash.code_utils.time_utils import humanize_runtime
55-
from codeflash.code_utils.env_utils import get_pr_number
56-
from codeflash.code_utils.git_utils import get_repo_owner_and_name
5754
from codeflash.context import code_context_extractor
5855
from codeflash.context.unused_definition_remover import detect_unused_helper_functions, revert_unused_helper_functions
5956
from codeflash.either import Failure, Success, is_successful
@@ -265,7 +262,7 @@ def optimize_function(self) -> Result[BestOptimization, str]: # noqa: PLR0911
265262
# adding to control and experiment set but with same traceid
266263
best_optimization = None
267264
for _u, (candidates, exp_type) in enumerate(
268-
zip([optimizations_set.control, optimizations_set.experiment], ["EXP0", "EXP1"])
265+
zip([optimizations_set.control, optimizations_set.experiment], ["EXP0", "EXP1"], strict=False)
269266
):
270267
if candidates is None:
271268
continue
@@ -687,6 +684,7 @@ def get_code_optimization_context(self) -> Result[CodeOptimizationContext, str]:
687684
testgen_context_code=new_code_ctx.testgen_context_code,
688685
read_writable_code=new_code_ctx.read_writable_code,
689686
read_only_context_code=new_code_ctx.read_only_context_code,
687+
hashing_code_context=new_code_ctx.hashing_code_context,
690688
helper_functions=new_code_ctx.helper_functions, # only functions that are read writable
691689
preexisting_objects=new_code_ctx.preexisting_objects,
692690
)
@@ -1283,7 +1281,7 @@ def generate_and_instrument_tests(
12831281
test_perf_path,
12841282
)
12851283
for test_index, (test_path, test_perf_path) in enumerate(
1286-
zip(generated_test_paths, generated_perf_test_paths)
1284+
zip(generated_test_paths, generated_perf_test_paths, strict=False)
12871285
)
12881286
]
12891287

0 commit comments

Comments
 (0)