Skip to content

Commit 3eee162

Browse files
fix markdown context for formatting and more refactoring
1 parent a1c10a0 commit 3eee162

File tree

8 files changed

+77
-56
lines changed

8 files changed

+77
-56
lines changed

codeflash/api/aiservice.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -73,9 +73,9 @@ def make_ai_service_request(
7373
url = f"{self.base_url}/ai{endpoint}"
7474
if method.upper() == "POST":
7575
json_payload = json.dumps(payload, indent=None, default=pydantic_encoder)
76-
print(f"------------------------JSON PAYLOAD for {url}--------------------")
77-
print(json_payload)
78-
print("-------------------END OF JSON PAYLOAD--------------------")
76+
# print(f"------------------------JSON PAYLOAD for {url}--------------------")
77+
# print(json_payload)
78+
# print("-------------------END OF JSON PAYLOAD--------------------")
7979
headers = {**self.headers, "Content-Type": "application/json"}
8080
response = requests.post(url, data=json_payload, headers=headers, timeout=timeout)
8181
else:
@@ -139,7 +139,7 @@ def optimize_python_code( # noqa: D417
139139
logger.debug(f"Generating optimizations took {end_time - start_time:.2f} seconds.")
140140
return [
141141
OptimizedCandidate(
142-
source_code=CodeStringsMarkdown.parse_splitter_markers(opt["source_code"]),
142+
source_code=CodeStringsMarkdown.parse_flattened_code(opt["source_code"]),
143143
explanation=opt["explanation"],
144144
optimization_id=opt["optimization_id"],
145145
)
@@ -209,7 +209,7 @@ def optimize_python_code_line_profiler( # noqa: D417
209209
console.rule()
210210
return [
211211
OptimizedCandidate(
212-
source_code=CodeStringsMarkdown.parse_splitter_markers(opt["source_code"]),
212+
source_code=CodeStringsMarkdown.parse_flattened_code(opt["source_code"]),
213213
explanation=opt["explanation"],
214214
optimization_id=opt["optimization_id"],
215215
)
@@ -266,7 +266,7 @@ def optimize_python_code_refinement(self, request: list[AIServiceRefinerRequest]
266266
console.rule()
267267
return [
268268
OptimizedCandidate(
269-
source_code=CodeStringsMarkdown.parse_splitter_markers(opt["source_code"]),
269+
source_code=CodeStringsMarkdown.parse_flattened_code(opt["source_code"]),
270270
explanation=opt["explanation"],
271271
optimization_id=opt["optimization_id"][:-4] + "refi",
272272
)

codeflash/code_utils/formatter.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,6 @@ def format_code(
121121

122122
if check_diff and original_code_lines > 50:
123123
# we don't count the formatting diff for the optimized function as it should be well-formatted
124-
# TODO: This is not correct, optimized_code is not continuous, Think of a better way for doing this.
125124
original_code_without_opfunc = original_code.replace(optimized_code, "")
126125

127126
original_temp = Path(test_dir_str) / "original_temp.py"

codeflash/models/models.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from typing import Annotated, Optional, cast
2020

2121
from jedi.api.classes import Name
22-
from pydantic import AfterValidator, BaseModel, ConfigDict
22+
from pydantic import AfterValidator, BaseModel, ConfigDict, PrivateAttr
2323
from pydantic.dataclasses import dataclass
2424

2525
from codeflash.cli_cmds.console import console, logger
@@ -167,16 +167,16 @@ def get_code_block_splitter(file_path: Path) -> str:
167167

168168
class CodeStringsMarkdown(BaseModel):
169169
code_strings: list[CodeString] = []
170-
cached_code: Optional[str] = None
170+
_cache: dict = PrivateAttr(default_factory=dict)
171171

172172
@property
173173
def flat(self) -> str:
174-
if self.cached_code is not None:
175-
return self.cached_code
176-
self.cached_code = "\n".join(
174+
if self._cache.get("flat") is not None:
175+
return self._cache["flat"]
176+
self._cache["flat"] = "\n".join(
177177
get_code_block_splitter(block.file_path) + "\n" + block.code for block in self.code_strings
178178
)
179-
return self.cached_code
179+
return self._cache["flat"]
180180

181181
@property
182182
def markdown(self) -> str:
@@ -188,17 +188,25 @@ def markdown(self) -> str:
188188
]
189189
)
190190

191+
def file_to_path(self) -> dict[str, str]:
192+
if self._cache.get("file_to_path") is not None:
193+
return self._cache["file_to_path"]
194+
self._cache["file_to_path"] = {
195+
str(code_string.file_path): code_string.code for code_string in self.code_strings
196+
}
197+
return self._cache["file_to_path"]
198+
191199
@staticmethod
192-
def parse_splitter_markers(code_with_markers: str) -> CodeStringsMarkdown:
200+
def parse_flattened_code(flat_code: str) -> CodeStringsMarkdown:
193201
pattern = rf"{LINE_SPLITTER_MARKER_PREFIX}([^\n]+)\n"
194-
matches = list(re.finditer(pattern, code_with_markers))
202+
matches = list(re.finditer(pattern, flat_code))
195203

196204
results = CodeStringsMarkdown()
197205
for i, match in enumerate(matches):
198206
start = match.end()
199-
end = matches[i + 1].start() if i + 1 < len(matches) else len(code_with_markers)
207+
end = matches[i + 1].start() if i + 1 < len(matches) else len(flat_code)
200208
file_path = match.group(1).strip()
201-
code = code_with_markers[start:end].lstrip("\n")
209+
code = flat_code[start:end].lstrip("\n")
202210
results.code_strings.append(CodeString(code=code, file_path=Path(file_path)))
203211
return results
204212

codeflash/optimization/function_optimizer.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,6 @@
6565
LINE_SPLITTER_MARKER_PREFIX,
6666
BestOptimization,
6767
CodeOptimizationContext,
68-
CodeStringsMarkdown,
6968
GeneratedTests,
7069
GeneratedTestsList,
7170
OptimizationSet,
@@ -97,6 +96,7 @@
9796
from codeflash.either import Result
9897
from codeflash.models.models import (
9998
BenchmarkKey,
99+
CodeStringsMarkdown,
100100
CoverageData,
101101
FunctionCalledInTest,
102102
FunctionSource,
@@ -419,7 +419,7 @@ def determine_best_candidate(
419419
try:
420420
did_update = self.replace_function_and_helpers_with_optimized_code(
421421
code_context=code_context,
422-
optimized_code=candidate.source_code.flat,
422+
optimized_code=candidate.source_code,
423423
original_helper_code=original_helper_code,
424424
)
425425
if not did_update:
@@ -680,12 +680,21 @@ def write_code_and_helpers(original_code: str, original_helper_code: dict[Path,
680680
f.write(helper_code)
681681

682682
def reformat_code_and_helpers(
683-
self, helper_functions: list[FunctionSource], path: Path, original_code: str, optimized_code: str
683+
self,
684+
helper_functions: list[FunctionSource],
685+
path: Path,
686+
original_code: str,
687+
optimized_context: CodeStringsMarkdown,
684688
) -> tuple[str, dict[Path, str]]:
685689
should_sort_imports = not self.args.disable_imports_sorting
686690
if should_sort_imports and isort.code(original_code) != original_code:
687691
should_sort_imports = False
688692

693+
optimized_code = ""
694+
if optimized_context is not None:
695+
file_to_code_context = optimized_context.file_to_path()
696+
optimized_code = file_to_code_context.get(str(path.relative_to(self.project_root)), "")
697+
689698
new_code = format_code(self.args.formatter_cmds, path, optimized_code=optimized_code, check_diff=True)
690699
if should_sort_imports:
691700
new_code = sort_imports(new_code)
@@ -704,16 +713,15 @@ def reformat_code_and_helpers(
704713
return new_code, new_helper_code
705714

706715
def replace_function_and_helpers_with_optimized_code(
707-
self, code_context: CodeOptimizationContext, optimized_code: str, original_helper_code: str
716+
self, code_context: CodeOptimizationContext, optimized_code: CodeStringsMarkdown, original_helper_code: str
708717
) -> bool:
709718
did_update = False
710719
read_writable_functions_by_file_path = defaultdict(set)
711720
read_writable_functions_by_file_path[self.function_to_optimize.file_path].add(
712721
self.function_to_optimize.qualified_name
713722
)
714723

715-
code_strings = CodeStringsMarkdown.parse_splitter_markers(optimized_code).code_strings
716-
file_to_code_context = {str(code_string.file_path): code_string.code for code_string in code_strings}
724+
file_to_code_context = optimized_code.file_to_path()
717725

718726
for helper_function in code_context.helper_functions:
719727
if helper_function.jedi_definition.type != "class":
@@ -739,7 +747,7 @@ def replace_function_and_helpers_with_optimized_code(
739747
preexisting_objects=code_context.preexisting_objects,
740748
project_root_path=self.project_root,
741749
)
742-
unused_helpers = detect_unused_helper_functions(self.function_to_optimize, code_context, optimized_code)
750+
unused_helpers = detect_unused_helper_functions(self.function_to_optimize, code_context, optimized_code.flat)
743751

744752
# Revert unused helper functions to their original definitions
745753
if unused_helpers:
@@ -1092,15 +1100,15 @@ def find_and_process_best_optimization(
10921100

10931101
self.replace_function_and_helpers_with_optimized_code(
10941102
code_context=code_context,
1095-
optimized_code=best_optimization.candidate.source_code.flat,
1103+
optimized_code=best_optimization.candidate.source_code,
10961104
original_helper_code=original_helper_code,
10971105
)
10981106

10991107
new_code, new_helper_code = self.reformat_code_and_helpers(
11001108
code_context.helper_functions,
11011109
explanation.file_path,
11021110
self.function_to_optimize_source_code,
1103-
optimized_code=best_optimization.candidate.source_code.flat,
1111+
optimized_context=best_optimization.candidate.source_code,
11041112
)
11051113

11061114
original_code_combined = original_helper_code.copy()
@@ -1173,8 +1181,9 @@ def process_review(
11731181
)
11741182
new_explanation_raw_str = self.aiservice_client.get_new_explanation(
11751183
source_code=code_context.read_writable_code.flat.replace(
1176-
LINE_SPLITTER_MARKER_PREFIX, "# file: "
1177-
), # for better readability to the LLM
1184+
LINE_SPLITTER_MARKER_PREFIX,
1185+
"# file: ", # for better readability
1186+
),
11781187
dependency_code=code_context.read_only_context_code,
11791188
trace_id=self.function_trace_id[:-4] + exp_type if self.experiment_id else self.function_trace_id,
11801189
optimized_code=best_optimization.candidate.source_code.flat.replace(

tests/test_code_replacement.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
replace_functions_in_file,
1414
)
1515
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
16-
from codeflash.models.models import CodeOptimizationContext, FunctionParent, get_code_block_splitter
16+
from codeflash.models.models import CodeOptimizationContext, CodeStringsMarkdown, FunctionParent, get_code_block_splitter
1717
from codeflash.optimization.function_optimizer import FunctionOptimizer
1818
from codeflash.verification.verification_utils import TestConfig
1919

@@ -73,7 +73,7 @@ def sorter(arr):
7373
original_helper_code[helper_function_path] = helper_code
7474
func_optimizer.args = Args()
7575
func_optimizer.replace_function_and_helpers_with_optimized_code(
76-
code_context=code_context, optimized_code=optimized_code, original_helper_code=original_helper_code
76+
code_context=code_context, optimized_code=CodeStringsMarkdown.parse_flattened_code(optimized_code), original_helper_code=original_helper_code
7777
)
7878
final_output = code_path.read_text(encoding="utf-8")
7979
assert "inconsequential_var = '123'" in final_output
@@ -1742,7 +1742,7 @@ def new_function2(value):
17421742
original_helper_code[helper_function_path] = helper_code
17431743
func_optimizer.args = Args()
17441744
func_optimizer.replace_function_and_helpers_with_optimized_code(
1745-
code_context=code_context, optimized_code=optimized_code, original_helper_code=original_helper_code
1745+
code_context=code_context, optimized_code=CodeStringsMarkdown.parse_flattened_code(optimized_code), original_helper_code=original_helper_code
17461746
)
17471747
new_code = code_path.read_text(encoding="utf-8")
17481748
code_path.unlink(missing_ok=True)
@@ -1819,7 +1819,7 @@ def new_function2(value):
18191819
original_helper_code[helper_function_path] = helper_code
18201820
func_optimizer.args = Args()
18211821
func_optimizer.replace_function_and_helpers_with_optimized_code(
1822-
code_context=code_context, optimized_code=optimized_code, original_helper_code=original_helper_code
1822+
code_context=code_context, optimized_code=CodeStringsMarkdown.parse_flattened_code(optimized_code), original_helper_code=original_helper_code
18231823
)
18241824
new_code = code_path.read_text(encoding="utf-8")
18251825
code_path.unlink(missing_ok=True)
@@ -1897,7 +1897,7 @@ def new_function2(value):
18971897
original_helper_code[helper_function_path] = helper_code
18981898
func_optimizer.args = Args()
18991899
func_optimizer.replace_function_and_helpers_with_optimized_code(
1900-
code_context=code_context, optimized_code=optimized_code, original_helper_code=original_helper_code
1900+
code_context=code_context, optimized_code=CodeStringsMarkdown.parse_flattened_code(optimized_code), original_helper_code=original_helper_code
19011901
)
19021902
new_code = code_path.read_text(encoding="utf-8")
19031903
code_path.unlink(missing_ok=True)
@@ -1974,7 +1974,7 @@ def new_function2(value):
19741974
original_helper_code[helper_function_path] = helper_code
19751975
func_optimizer.args = Args()
19761976
func_optimizer.replace_function_and_helpers_with_optimized_code(
1977-
code_context=code_context, optimized_code=optimized_code, original_helper_code=original_helper_code
1977+
code_context=code_context, optimized_code=CodeStringsMarkdown.parse_flattened_code(optimized_code), original_helper_code=original_helper_code
19781978
)
19791979
new_code = code_path.read_text(encoding="utf-8")
19801980
code_path.unlink(missing_ok=True)
@@ -2052,7 +2052,7 @@ def new_function2(value):
20522052
original_helper_code[helper_function_path] = helper_code
20532053
func_optimizer.args = Args()
20542054
func_optimizer.replace_function_and_helpers_with_optimized_code(
2055-
code_context=code_context, optimized_code=optimized_code, original_helper_code=original_helper_code
2055+
code_context=code_context, optimized_code=CodeStringsMarkdown.parse_flattened_code(optimized_code), original_helper_code=original_helper_code
20562056
)
20572057
new_code = code_path.read_text(encoding="utf-8")
20582058
code_path.unlink(missing_ok=True)
@@ -2141,7 +2141,7 @@ def new_function2(value):
21412141
original_helper_code[helper_function_path] = helper_code
21422142
func_optimizer.args = Args()
21432143
func_optimizer.replace_function_and_helpers_with_optimized_code(
2144-
code_context=code_context, optimized_code=optimized_code, original_helper_code=original_helper_code
2144+
code_context=code_context, optimized_code=CodeStringsMarkdown.parse_flattened_code(optimized_code), original_helper_code=original_helper_code
21452145
)
21462146
new_code = code_path.read_text(encoding="utf-8")
21472147
code_path.unlink(missing_ok=True)

tests/test_formatter.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from codeflash.code_utils.formatter import format_code, sort_imports
1111

1212
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
13+
from codeflash.models.models import CodeString, CodeStringsMarkdown
1314
from codeflash.optimization.function_optimizer import FunctionOptimizer
1415
from codeflash.verification.verification_utils import TestConfig
1516

@@ -263,7 +264,12 @@ def _run_formatting_test(source_code: str, should_content_change: bool, expected
263264
helper_functions=[],
264265
path=target_path,
265266
original_code=optimizer.function_to_optimize_source_code,
266-
optimized_code=optimized_function,
267+
optimized_context=CodeStringsMarkdown(code_strings=[
268+
CodeString(
269+
code=optimized_function,
270+
file_path=target_path.relative_to(test_dir)
271+
)
272+
]),
267273
)
268274

269275
content = target_path.read_text(encoding="utf8")
@@ -796,7 +802,7 @@ def _is_valid(self, item):
796802
return isinstance(item, dict) and "id" in item
797803
'''
798804

799-
optimization_function = """ def process(self,data):
805+
optimization_function = """def process(self,data):
800806
'''Single quote docstring with formatting issues.'''
801807
return{'result':[item for item in data if self._is_valid(item)]}"""
802808
_run_formatting_test(source_code, True, optimized_function=optimization_function, expected=expected)

tests/test_multi_file_code_replacement.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from pathlib import Path
22
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
3-
from codeflash.models.models import CodeOptimizationContext, get_code_block_splitter
3+
from codeflash.models.models import CodeOptimizationContext, CodeStringsMarkdown, get_code_block_splitter
44
from codeflash.optimization.function_optimizer import FunctionOptimizer
55
from codeflash.verification.verification_utils import TestConfig
66

@@ -117,7 +117,7 @@ def _get_string_usage(text: str) -> Usage:
117117

118118
func_optimizer.args = Args()
119119
func_optimizer.replace_function_and_helpers_with_optimized_code(
120-
code_context=code_context, optimized_code=optimized_code, original_helper_code=original_helper_code
120+
code_context=code_context, optimized_code=CodeStringsMarkdown.parse_flattened_code(optimized_code), original_helper_code=original_helper_code
121121
)
122122
new_code = main_file.read_text(encoding="utf-8")
123123
new_helper_code = helper_file.read_text(encoding="utf-8")

0 commit comments

Comments
 (0)