Skip to content

Commit 84324f8

Browse files
markdown multi context
Signed-off-by: mohammed <[email protected]>
1 parent 4edea33 commit 84324f8

File tree

7 files changed

+43
-29
lines changed

7 files changed

+43
-29
lines changed

codeflash/api/aiservice.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from codeflash.code_utils.env_utils import get_codeflash_api_key, is_LSP_enabled
1414
from codeflash.code_utils.git_utils import get_last_commit_author_if_pr_exists, get_repo_owner_and_name
1515
from codeflash.models.ExperimentMetadata import ExperimentMetadata
16-
from codeflash.models.models import AIServiceRefinerRequest, OptimizedCandidate
16+
from codeflash.models.models import AIServiceRefinerRequest, CodeStringsMarkdown, OptimizedCandidate
1717
from codeflash.telemetry.posthog_cf import ph
1818
from codeflash.version import __version__ as codeflash_version
1919

@@ -73,6 +73,9 @@ def make_ai_service_request(
7373
url = f"{self.base_url}/ai{endpoint}"
7474
if method.upper() == "POST":
7575
json_payload = json.dumps(payload, indent=None, default=pydantic_encoder)
76+
print(f"------------------------JSON PAYLOAD for {url}--------------------")
77+
print(json_payload)
78+
print("-------------------END OF JSON PAYLOAD--------------------")
7679
headers = {**self.headers, "Content-Type": "application/json"}
7780
response = requests.post(url, data=json_payload, headers=headers, timeout=timeout)
7881
else:
@@ -136,7 +139,7 @@ def optimize_python_code( # noqa: D417
136139
logger.debug(f"Generating optimizations took {end_time - start_time:.2f} seconds.")
137140
return [
138141
OptimizedCandidate(
139-
source_code=opt["source_code"],
142+
source_code=CodeStringsMarkdown.parse_splitter_markers(opt["source_code"]),
140143
explanation=opt["explanation"],
141144
optimization_id=opt["optimization_id"],
142145
)
@@ -206,7 +209,7 @@ def optimize_python_code_line_profiler( # noqa: D417
206209
console.rule()
207210
return [
208211
OptimizedCandidate(
209-
source_code=opt["source_code"],
212+
source_code=CodeStringsMarkdown.parse_splitter_markers(opt["source_code"]),
210213
explanation=opt["explanation"],
211214
optimization_id=opt["optimization_id"],
212215
)
@@ -263,7 +266,7 @@ def optimize_python_code_refinement(self, request: list[AIServiceRefinerRequest]
263266
console.rule()
264267
return [
265268
OptimizedCandidate(
266-
source_code=opt["source_code"],
269+
source_code=CodeStringsMarkdown.parse_splitter_markers(opt["source_code"]),
267270
explanation=opt["explanation"],
268271
optimization_id=opt["optimization_id"][:-4] + "refi",
269272
)

codeflash/code_utils/code_replacer.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from collections import defaultdict
55
from functools import lru_cache
66
from typing import TYPE_CHECKING, Optional, TypeVar
7+
from warnings import deprecated
78

89
import isort
910
import libcst as cst
@@ -432,6 +433,7 @@ def is_zero_diff(original_code: str, new_code: str) -> bool:
432433
return normalize_code(original_code) == normalize_code(new_code)
433434

434435

436+
@deprecated("")
435437
def replace_optimized_code(
436438
callee_module_paths: set[Path],
437439
candidates: list[OptimizedCandidate],

codeflash/code_utils/formatter.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def is_diff_line(line: str) -> bool:
104104
def format_code(
105105
formatter_cmds: list[str],
106106
path: Union[str, Path],
107-
optimized_function: str = "",
107+
optimized_code: str = "",
108108
check_diff: bool = False, # noqa
109109
print_status: bool = True, # noqa
110110
exit_on_failure: bool = True, # noqa
@@ -121,7 +121,8 @@ def format_code(
121121

122122
if check_diff and original_code_lines > 50:
123123
# we dont' count the formatting diff for the optimized function as it should be well-formatted
124-
original_code_without_opfunc = original_code.replace(optimized_function, "")
124+
# TODO: This is not correct, optimized_code is not continuous, Think of a better way for doing this.
125+
original_code_without_opfunc = original_code.replace(optimized_code, "")
125126

126127
original_temp = Path(test_dir_str) / "original_temp.py"
127128
original_temp.write_text(original_code_without_opfunc, encoding="utf8")

codeflash/lsp/beta.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ def generate_tests(server: CodeflashLanguageServer, params: FunctionOptimization
168168
generated_test.generated_original_test_source for generated_test in generated_tests_list.generated_tests
169169
]
170170
optimizations_dict = {
171-
candidate.optimization_id: {"source_code": candidate.source_code, "explanation": candidate.explanation}
171+
candidate.optimization_id: {"source_code": candidate.source_code.flat, "explanation": candidate.explanation}
172172
for candidate in optimizations_set.control + optimizations_set.experiment
173173
}
174174

@@ -276,7 +276,7 @@ def perform_function_optimization( # noqa: PLR0911
276276
"message": f"No best optimizations found for function {function_to_optimize_qualified_name}",
277277
}
278278

279-
optimized_source = best_optimization.candidate.source_code
279+
optimized_source = best_optimization.candidate.source_code.flat
280280
speedup = original_code_baseline.runtime / best_optimization.runtime
281281

282282
server.show_message_log(f"Optimization completed for {params.functionName} with {speedup:.2f}x speedup", "Info")

codeflash/models/models.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,7 @@ class CodeString(BaseModel):
157157
file_path: Optional[Path] = None
158158

159159

160+
# Used to split files by adding a marker at the start of each file followed by the file path.
160161
LINE_SPLITTER_MARKER_PREFIX = "# codeflash-splitter__"
161162

162163

@@ -188,17 +189,17 @@ def markdown(self) -> str:
188189
)
189190

190191
@staticmethod
191-
def parse_splitter_markers(code_with_markers: str) -> dict[str, str]:
192+
def parse_splitter_markers(code_with_markers: str) -> CodeStringsMarkdown:
192193
pattern = rf"{LINE_SPLITTER_MARKER_PREFIX}([^\n]+)\n"
193194
matches = list(re.finditer(pattern, code_with_markers))
194195

195-
results = {}
196+
results = CodeStringsMarkdown()
196197
for i, match in enumerate(matches):
197198
start = match.end()
198199
end = matches[i + 1].start() if i + 1 < len(matches) else len(code_with_markers)
199200
file_path = match.group(1).strip()
200201
code = code_with_markers[start:end].lstrip("\n")
201-
results[file_path] = code
202+
results.code_strings.append(CodeString(code=code, file_path=Path(file_path)))
202203
return results
203204

204205

@@ -303,7 +304,7 @@ class TestsInFile:
303304

304305
@dataclass(frozen=True)
305306
class OptimizedCandidate:
306-
source_code: str
307+
source_code: CodeStringsMarkdown
307308
explanation: str
308309
optimization_id: str
309310

codeflash/optimization/function_optimizer.py

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@
6262
from codeflash.either import Failure, Success, is_successful
6363
from codeflash.models.ExperimentMetadata import ExperimentMetadata
6464
from codeflash.models.models import (
65+
LINE_SPLITTER_MARKER_PREFIX,
6566
BestOptimization,
6667
CodeOptimizationContext,
6768
CodeStringsMarkdown,
@@ -216,7 +217,7 @@ def generate_and_instrument_tests(
216217
revert_to_print=bool(get_pr_number()),
217218
):
218219
generated_results = self.generate_tests_and_optimizations(
219-
testgen_context_code=code_context.testgen_context_code,
220+
testgen_context_code=code_context.testgen_context_code, # TODO: should we send the markdow context for the testgen instead.
220221
read_writable_code=code_context.read_writable_code,
221222
read_only_context_code=code_context.read_only_context_code,
222223
helper_functions=code_context.helper_functions,
@@ -289,7 +290,7 @@ def optimize_function(self) -> Result[BestOptimization, str]:
289290

290291
should_run_experiment, code_context, original_helper_code = initialization_result.unwrap()
291292

292-
code_print(code_context.read_writable_code.flat)
293+
code_print(code_context.read_writable_code.flat) # Should we print the markdown or the flattened code?
293294

294295
test_setup_result = self.generate_and_instrument_tests( # also generates optimizations
295296
code_context, should_run_experiment=should_run_experiment
@@ -414,11 +415,11 @@ def determine_best_candidate(
414415
get_run_tmp_file(Path(f"test_return_values_{candidate_index}.bin")).unlink(missing_ok=True)
415416
get_run_tmp_file(Path(f"test_return_values_{candidate_index}.sqlite")).unlink(missing_ok=True)
416417
logger.info(f"Optimization candidate {candidate_index}/{original_len}:")
417-
code_print(candidate.source_code)
418+
code_print(candidate.source_code.flat)
418419
try:
419420
did_update = self.replace_function_and_helpers_with_optimized_code(
420421
code_context=code_context,
421-
optimized_code=candidate.source_code,
422+
optimized_code=candidate.source_code.flat,
422423
original_helper_code=original_helper_code,
423424
)
424425
if not did_update:
@@ -578,7 +579,7 @@ def determine_best_candidate(
578579
runtimes_list = []
579580
for valid_opt in self.valid_optimizations:
580581
diff_lens_list.append(
581-
diff_length(valid_opt.candidate.source_code, code_context.read_writable_code.flat)
582+
diff_length(valid_opt.candidate.source_code.flat, code_context.read_writable_code.flat)
582583
) # char level diff
583584
runtimes_list.append(valid_opt.runtime)
584585
diff_lens_ranking = create_rank_dictionary_compact(diff_lens_list)
@@ -613,7 +614,7 @@ def refine_optimizations(
613614
original_source_code=code_context.read_writable_code.flat,
614615
read_only_dependency_code=code_context.read_only_context_code,
615616
original_code_runtime=humanize_runtime(original_code_baseline.runtime),
616-
optimized_source_code=opt.candidate.source_code,
617+
optimized_source_code=opt.candidate.source_code.flat,
617618
optimized_explanation=opt.candidate.explanation,
618619
optimized_code_runtime=humanize_runtime(opt.runtime),
619620
speedup=f"{int(performance_gain(original_runtime_ns=original_code_baseline.runtime, optimized_runtime_ns=opt.runtime) * 100)}%",
@@ -679,13 +680,13 @@ def write_code_and_helpers(original_code: str, original_helper_code: dict[Path,
679680
f.write(helper_code)
680681

681682
def reformat_code_and_helpers(
682-
self, helper_functions: list[FunctionSource], path: Path, original_code: str, optimized_function: str
683+
self, helper_functions: list[FunctionSource], path: Path, original_code: str, optimized_code: str
683684
) -> tuple[str, dict[Path, str]]:
684685
should_sort_imports = not self.args.disable_imports_sorting
685686
if should_sort_imports and isort.code(original_code) != original_code:
686687
should_sort_imports = False
687688

688-
new_code = format_code(self.args.formatter_cmds, path, optimized_function=optimized_function, check_diff=True)
689+
new_code = format_code(self.args.formatter_cmds, path, optimized_code=optimized_code, check_diff=True)
689690
if should_sort_imports:
690691
new_code = sort_imports(new_code)
691692

@@ -694,7 +695,7 @@ def reformat_code_and_helpers(
694695
module_abspath = hp.file_path
695696
hp_source_code = hp.source_code
696697
formatted_helper_code = format_code(
697-
self.args.formatter_cmds, module_abspath, optimized_function=hp_source_code, check_diff=True
698+
self.args.formatter_cmds, module_abspath, optimized_code=hp_source_code, check_diff=True
698699
)
699700
if should_sort_imports:
700701
formatted_helper_code = sort_imports(formatted_helper_code)
@@ -711,7 +712,8 @@ def replace_function_and_helpers_with_optimized_code(
711712
self.function_to_optimize.qualified_name
712713
)
713714

714-
file_to_code_context = CodeStringsMarkdown.parse_splitter_markers(optimized_code)
715+
code_strings = CodeStringsMarkdown.parse_splitter_markers(optimized_code).code_strings
716+
file_to_code_context = {str(code_string.file_path): code_string.code for code_string in code_strings}
715717

716718
for helper_function in code_context.helper_functions:
717719
if helper_function.jedi_definition.type != "class":
@@ -721,11 +723,12 @@ def replace_function_and_helpers_with_optimized_code(
721723
relative_module_path = str(module_abspath.relative_to(self.project_root))
722724
logger.debug(f"applying optimized code to: {relative_module_path}")
723725

724-
scoped_optimized_code = file_to_code_context.get(relative_module_path, None)
726+
scoped_optimized_code = file_to_code_context.get(relative_module_path)
725727
if scoped_optimized_code is None:
726728
logger.warning(
727729
f"Optimized code not found for {relative_module_path} In the context\n-------\n{optimized_code}\n-------\n"
728730
"Existing files in the context are: {list(file_to_code_context.keys())}, re-check your 'split markers'"
731+
f"existing files are {file_to_code_context.keys()}"
729732
)
730733
scoped_optimized_code = ""
731734

@@ -1063,7 +1066,7 @@ def find_and_process_best_optimization(
10631066

10641067
if best_optimization:
10651068
logger.info("Best candidate:")
1066-
code_print(best_optimization.candidate.source_code)
1069+
code_print(best_optimization.candidate.source_code.flat)
10671070
console.print(
10681071
Panel(
10691072
best_optimization.candidate.explanation, title="Best Candidate Explanation", border_style="blue"
@@ -1089,15 +1092,15 @@ def find_and_process_best_optimization(
10891092

10901093
self.replace_function_and_helpers_with_optimized_code(
10911094
code_context=code_context,
1092-
optimized_code=best_optimization.candidate.source_code,
1095+
optimized_code=best_optimization.candidate.source_code.flat,
10931096
original_helper_code=original_helper_code,
10941097
)
10951098

10961099
new_code, new_helper_code = self.reformat_code_and_helpers(
10971100
code_context.helper_functions,
10981101
explanation.file_path,
10991102
self.function_to_optimize_source_code,
1100-
optimized_function=best_optimization.candidate.source_code,
1103+
optimized_code=best_optimization.candidate.source_code.flat,
11011104
)
11021105

11031106
original_code_combined = original_helper_code.copy()
@@ -1169,10 +1172,14 @@ def process_review(
11691172
optimized_runtimes_all=optimized_runtime_by_test,
11701173
)
11711174
new_explanation_raw_str = self.aiservice_client.get_new_explanation(
1172-
source_code=code_context.read_writable_code,
1175+
source_code=code_context.read_writable_code.flat.replace(
1176+
LINE_SPLITTER_MARKER_PREFIX, "# file: "
1177+
), # for better readability to the LLM
11731178
dependency_code=code_context.read_only_context_code,
11741179
trace_id=self.function_trace_id[:-4] + exp_type if self.experiment_id else self.function_trace_id,
1175-
optimized_code=best_optimization.candidate.source_code,
1180+
optimized_code=best_optimization.candidate.source_code.flat.replace(
1181+
LINE_SPLITTER_MARKER_PREFIX, "# file: "
1182+
),
11761183
original_line_profiler_results=original_code_baseline.line_profile_results["str_out"],
11771184
optimized_line_profiler_results=best_optimization.line_profiler_test_results["str_out"],
11781185
original_code_runtime=humanize_runtime(original_code_baseline.runtime),

tests/test_formatter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,7 @@ def _run_formatting_test(source_code: str, should_content_change: bool, expected
263263
helper_functions=[],
264264
path=target_path,
265265
original_code=optimizer.function_to_optimize_source_code,
266-
optimized_function=optimized_function,
266+
optimized_code=optimized_function,
267267
)
268268

269269
content = target_path.read_text(encoding="utf8")

0 commit comments

Comments
 (0)