Skip to content

Commit 3171bb8

Browse files
Merge pull request #553 from codeflash-ai/feat/markdown-read-writable-context
[FEAT] Multi-file context (CF-687) (CF-387) (CF-640)
2 parents 9e12e94 + 5573c46 commit 3171bb8

File tree

14 files changed

+539
-187
lines changed

14 files changed

+539
-187
lines changed
Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,2 @@
11
DEFAULT_API_URL = "https://api.galileo.ai/"
22
DEFAULT_APP_URL = "https://app.galileo.ai/"
3-
4-
5-
# function_names: GalileoApiClient.get_console_url
6-
# module_abs_path : /home/mohammed/Work/galileo-python/src/galileo/api_client.py
7-
# preexisting_objects: {('GalileoApiClient', ()), ('_set_destination', ()), ('get_console_url', (FunctionParent(name='GalileoApiClient', type='ClassDef'),))}
8-
# project_root_path: /home/mohammed/Work/galileo-python/src

codeflash/api/aiservice.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from codeflash.code_utils.env_utils import get_codeflash_api_key, is_LSP_enabled
1414
from codeflash.code_utils.git_utils import get_last_commit_author_if_pr_exists, get_repo_owner_and_name
1515
from codeflash.models.ExperimentMetadata import ExperimentMetadata
16-
from codeflash.models.models import AIServiceRefinerRequest, OptimizedCandidate
16+
from codeflash.models.models import AIServiceRefinerRequest, CodeStringsMarkdown, OptimizedCandidate
1717
from codeflash.telemetry.posthog_cf import ph
1818
from codeflash.version import __version__ as codeflash_version
1919

@@ -136,7 +136,7 @@ def optimize_python_code( # noqa: D417
136136
logger.debug(f"Generating optimizations took {end_time - start_time:.2f} seconds.")
137137
return [
138138
OptimizedCandidate(
139-
source_code=opt["source_code"],
139+
source_code=CodeStringsMarkdown.parse_markdown_code(opt["source_code"]),
140140
explanation=opt["explanation"],
141141
optimization_id=opt["optimization_id"],
142142
)
@@ -206,7 +206,7 @@ def optimize_python_code_line_profiler( # noqa: D417
206206
console.rule()
207207
return [
208208
OptimizedCandidate(
209-
source_code=opt["source_code"],
209+
source_code=CodeStringsMarkdown.parse_markdown_code(opt["source_code"]),
210210
explanation=opt["explanation"],
211211
optimization_id=opt["optimization_id"],
212212
)
@@ -263,7 +263,7 @@ def optimize_python_code_refinement(self, request: list[AIServiceRefinerRequest]
263263
console.rule()
264264
return [
265265
OptimizedCandidate(
266-
source_code=opt["source_code"],
266+
source_code=CodeStringsMarkdown.parse_markdown_code(opt["source_code"]),
267267
explanation=opt["explanation"],
268268
optimization_id=opt["optimization_id"][:-4] + "refi",
269269
)

codeflash/code_utils/code_replacer.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from pathlib import Path
2020

2121
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
22-
from codeflash.models.models import CodeOptimizationContext, OptimizedCandidate, ValidCode
22+
from codeflash.models.models import CodeOptimizationContext, CodeStringsMarkdown, OptimizedCandidate, ValidCode
2323

2424
ASTNodeT = TypeVar("ASTNodeT", bound=ast.AST)
2525

@@ -408,16 +408,17 @@ def replace_functions_and_add_imports(
408408

409409
def replace_function_definitions_in_module(
410410
function_names: list[str],
411-
optimized_code: str,
411+
optimized_code: CodeStringsMarkdown,
412412
module_abspath: Path,
413413
preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]],
414414
project_root_path: Path,
415415
) -> bool:
416416
source_code: str = module_abspath.read_text(encoding="utf8")
417+
code_to_apply = get_optimized_code_for_module(module_abspath.relative_to(project_root_path), optimized_code)
417418
new_code: str = replace_functions_and_add_imports(
418-
add_global_assignments(optimized_code, source_code),
419+
add_global_assignments(code_to_apply, source_code),
419420
function_names,
420-
optimized_code,
421+
code_to_apply,
421422
module_abspath,
422423
preexisting_objects,
423424
project_root_path,
@@ -428,6 +429,19 @@ def replace_function_definitions_in_module(
428429
return True
429430

430431

432+
def get_optimized_code_for_module(relative_path: Path, optimized_code: CodeStringsMarkdown) -> str:
433+
file_to_code_context = optimized_code.file_to_path()
434+
module_optimized_code = file_to_code_context.get(str(relative_path))
435+
if module_optimized_code is None:
436+
logger.warning(
437+
f"Optimized code not found for {relative_path} In the context\n-------\n{optimized_code}\n-------\n"
438+
"re-check your 'markdown code structure'"
439+
f"existing files are {file_to_code_context.keys()}"
440+
)
441+
module_optimized_code = ""
442+
return module_optimized_code
443+
444+
431445
def is_zero_diff(original_code: str, new_code: str) -> bool:
432446
return normalize_code(original_code) == normalize_code(new_code)
433447

codeflash/code_utils/formatter.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def is_diff_line(line: str) -> bool:
104104
def format_code(
105105
formatter_cmds: list[str],
106106
path: Union[str, Path],
107-
optimized_function: str = "",
107+
optimized_code: str = "",
108108
check_diff: bool = False, # noqa
109109
print_status: bool = True, # noqa
110110
exit_on_failure: bool = True, # noqa
@@ -121,7 +121,7 @@ def format_code(
121121

122122
if check_diff and original_code_lines > 50:
123123
# we dont' count the formatting diff for the optimized function as it should be well-formatted
124-
original_code_without_opfunc = original_code.replace(optimized_function, "")
124+
original_code_without_opfunc = original_code.replace(optimized_code, "")
125125

126126
original_temp = Path(test_dir_str) / "original_temp.py"
127127
original_temp.write_text(original_code_without_opfunc, encoding="utf8")

codeflash/context/code_context_extractor.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,13 +61,14 @@ def get_code_optimization_context(
6161
)
6262

6363
# Extract code context for optimization
64-
final_read_writable_code = extract_code_string_context_from_files(
64+
final_read_writable_code = extract_code_markdown_context_from_files(
6565
helpers_of_fto_dict,
6666
{},
6767
project_root_path,
6868
remove_docstrings=False,
6969
code_context_type=CodeContextType.READ_WRITABLE,
70-
).code
70+
)
71+
7172
read_only_code_markdown = extract_code_markdown_context_from_files(
7273
helpers_of_fto_dict,
7374
helpers_of_helpers_dict,
@@ -84,14 +85,14 @@ def get_code_optimization_context(
8485
)
8586

8687
# Handle token limits
87-
final_read_writable_tokens = encoded_tokens_len(final_read_writable_code)
88+
final_read_writable_tokens = encoded_tokens_len(final_read_writable_code.markdown)
8889
if final_read_writable_tokens > optim_token_limit:
8990
raise ValueError("Read-writable code has exceeded token limit, cannot proceed")
9091

9192
# Setup preexisting objects for code replacer
9293
preexisting_objects = set(
9394
chain(
94-
find_preexisting_objects(final_read_writable_code),
95+
*(find_preexisting_objects(codestring.code) for codestring in final_read_writable_code.code_strings),
9596
*(find_preexisting_objects(codestring.code) for codestring in read_only_code_markdown.code_strings),
9697
)
9798
)

codeflash/context/unused_definition_remover.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,15 @@
33
import ast
44
from collections import defaultdict
55
from dataclasses import dataclass, field
6-
from typing import TYPE_CHECKING
7-
8-
if TYPE_CHECKING:
9-
from pathlib import Path
6+
from itertools import chain
7+
from pathlib import Path
108
from typing import TYPE_CHECKING, Optional
119

1210
import libcst as cst
1311

1412
from codeflash.cli_cmds.console import logger
1513
from codeflash.code_utils.code_replacer import replace_function_definitions_in_module
14+
from codeflash.models.models import CodeString, CodeStringsMarkdown
1615

1716
if TYPE_CHECKING:
1817
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
@@ -530,7 +529,11 @@ def revert_unused_helper_functions(
530529
helper_names = [helper.qualified_name for helper in helpers_in_file]
531530
reverted_code = replace_function_definitions_in_module(
532531
function_names=helper_names,
533-
optimized_code=original_code, # Use original code as the "optimized" code to revert
532+
optimized_code=CodeStringsMarkdown(
533+
code_strings=[
534+
CodeString(code=original_code, file_path=Path(file_path).relative_to(project_root))
535+
]
536+
), # Use original code as the "optimized" code to revert
534537
module_abspath=file_path,
535538
preexisting_objects=set(), # Empty set since we're reverting
536539
project_root_path=project_root,
@@ -609,7 +612,9 @@ def _analyze_imports_in_optimized_code(
609612

610613

611614
def detect_unused_helper_functions(
612-
function_to_optimize: FunctionToOptimize, code_context: CodeOptimizationContext, optimized_code: str
615+
function_to_optimize: FunctionToOptimize,
616+
code_context: CodeOptimizationContext,
617+
optimized_code: str | CodeStringsMarkdown,
613618
) -> list[FunctionSource]:
614619
"""Detect helper functions that are no longer called by the optimized entrypoint function.
615620
@@ -622,6 +627,14 @@ def detect_unused_helper_functions(
622627
List of FunctionSource objects representing unused helper functions
623628
624629
"""
630+
if isinstance(optimized_code, CodeStringsMarkdown) and len(optimized_code.code_strings) > 0:
631+
return list(
632+
chain.from_iterable(
633+
detect_unused_helper_functions(function_to_optimize, code_context, code.code)
634+
for code in optimized_code.code_strings
635+
)
636+
)
637+
625638
try:
626639
# Parse the optimized code to analyze function calls and imports
627640
optimized_ast = ast.parse(optimized_code)

codeflash/lsp/beta.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ def generate_tests(server: CodeflashLanguageServer, params: FunctionOptimization
222222
generated_test.generated_original_test_source for generated_test in generated_tests_list.generated_tests
223223
]
224224
optimizations_dict = {
225-
candidate.optimization_id: {"source_code": candidate.source_code, "explanation": candidate.explanation}
225+
candidate.optimization_id: {"source_code": candidate.source_code.markdown, "explanation": candidate.explanation}
226226
for candidate in optimizations_set.control + optimizations_set.experiment
227227
}
228228

@@ -330,7 +330,7 @@ def perform_function_optimization( # noqa: PLR0911
330330
"message": f"No best optimizations found for function {function_to_optimize_qualified_name}",
331331
}
332332

333-
optimized_source = best_optimization.candidate.source_code
333+
optimized_source = best_optimization.candidate.source_code.markdown
334334
speedup = original_code_baseline.runtime / best_optimization.runtime
335335

336336
server.show_message_log(f"Optimization completed for {params.functionName} with {speedup:.2f}x speedup", "Info")

codeflash/models/models.py

Lines changed: 77 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from typing import Annotated, Optional, cast
2020

2121
from jedi.api.classes import Name
22-
from pydantic import AfterValidator, BaseModel, ConfigDict, Field
22+
from pydantic import AfterValidator, BaseModel, ConfigDict, PrivateAttr
2323
from pydantic.dataclasses import dataclass
2424

2525
from codeflash.cli_cmds.console import console, logger
@@ -157,23 +157,96 @@ class CodeString(BaseModel):
157157
file_path: Optional[Path] = None
158158

159159

160+
def get_code_block_splitter(file_path: Path) -> str:
161+
return f"# file: {file_path}"
162+
163+
164+
markdown_pattern = re.compile(r"```python:([^\n]+)\n(.*?)\n```", re.DOTALL)
165+
166+
160167
class CodeStringsMarkdown(BaseModel):
161168
code_strings: list[CodeString] = []
169+
_cache: dict = PrivateAttr(default_factory=dict)
170+
171+
@property
172+
def flat(self) -> str:
173+
"""Returns the combined Python module from all code blocks.
174+
175+
Each block is prefixed by a file path comment to indicate its origin.
176+
This representation is syntactically valid Python code.
177+
178+
Returns:
179+
str: The concatenated code of all blocks with file path annotations.
180+
181+
!! Important !!:
182+
Avoid parsing the flat code with multiple files,
183+
parsing may result in unexpected behavior.
184+
185+
186+
"""
187+
if self._cache.get("flat") is not None:
188+
return self._cache["flat"]
189+
self._cache["flat"] = "\n".join(
190+
get_code_block_splitter(block.file_path) + "\n" + block.code for block in self.code_strings
191+
)
192+
return self._cache["flat"]
162193

163194
@property
164195
def markdown(self) -> str:
165-
"""Returns the markdown representation of the code, including the file path where possible."""
196+
"""Returns a Markdown-formatted string containing all code blocks.
197+
198+
Each block is enclosed in a triple-backtick code block with an optional
199+
file path suffix (e.g., ```python:filename.py).
200+
201+
Returns:
202+
str: Markdown representation of the code blocks.
203+
204+
"""
166205
return "\n".join(
167206
[
168207
f"```python{':' + str(code_string.file_path) if code_string.file_path else ''}\n{code_string.code.strip()}\n```"
169208
for code_string in self.code_strings
170209
]
171210
)
172211

212+
def file_to_path(self) -> dict[str, str]:
213+
"""Return a dictionary mapping file paths to their corresponding code blocks.
214+
215+
Returns:
216+
dict[str, str]: Mapping from file path (as string) to code.
217+
218+
"""
219+
if self._cache.get("file_to_path") is not None:
220+
return self._cache["file_to_path"]
221+
self._cache["file_to_path"] = {
222+
str(code_string.file_path): code_string.code for code_string in self.code_strings
223+
}
224+
return self._cache["file_to_path"]
225+
226+
@staticmethod
227+
def parse_markdown_code(markdown_code: str) -> CodeStringsMarkdown:
228+
"""Parse a Markdown string into a CodeStringsMarkdown object.
229+
230+
Extracts code blocks and their associated file paths and constructs a new CodeStringsMarkdown instance.
231+
232+
Args:
233+
markdown_code (str): The Markdown-formatted string to parse.
234+
235+
Returns:
236+
CodeStringsMarkdown: Parsed object containing code blocks.
237+
238+
"""
239+
matches = markdown_pattern.findall(markdown_code)
240+
results = CodeStringsMarkdown()
241+
for file_path, code in matches:
242+
path = file_path.strip()
243+
results.code_strings.append(CodeString(code=code, file_path=Path(path)))
244+
return results
245+
173246

174247
class CodeOptimizationContext(BaseModel):
175248
testgen_context_code: str = ""
176-
read_writable_code: str = Field(min_length=1)
249+
read_writable_code: CodeStringsMarkdown
177250
read_only_context_code: str = ""
178251
hashing_code_context: str = ""
179252
hashing_code_context_hash: str = ""
@@ -272,7 +345,7 @@ class TestsInFile:
272345

273346
@dataclass(frozen=True)
274347
class OptimizedCandidate:
275-
source_code: str
348+
source_code: CodeStringsMarkdown
276349
explanation: str
277350
optimization_id: str
278351

0 commit comments

Comments
 (0)