|
19 | 19 | from typing import Annotated, Optional, cast |
20 | 20 |
|
21 | 21 | from jedi.api.classes import Name |
22 | | -from pydantic import AfterValidator, BaseModel, ConfigDict, Field |
| 22 | +from pydantic import AfterValidator, BaseModel, ConfigDict, PrivateAttr |
23 | 23 | from pydantic.dataclasses import dataclass |
24 | 24 |
|
25 | 25 | from codeflash.cli_cmds.console import console, logger |
@@ -157,23 +157,96 @@ class CodeString(BaseModel): |
157 | 157 | file_path: Optional[Path] = None |
158 | 158 |
|
159 | 159 |
|
| 160 | +def get_code_block_splitter(file_path: Path) -> str: |
| 161 | + return f"# file: {file_path}" |
| 162 | + |
| 163 | + |
| 164 | +markdown_pattern = re.compile(r"```python:([^\n]+)\n(.*?)\n```", re.DOTALL) |
| 165 | + |
| 166 | + |
160 | 167 | class CodeStringsMarkdown(BaseModel): |
161 | 168 | code_strings: list[CodeString] = [] |
| 169 | + _cache: dict = PrivateAttr(default_factory=dict) |
| 170 | + |
| 171 | + @property |
| 172 | + def flat(self) -> str: |
| 173 | + """Returns the combined Python module from all code blocks. |
| 174 | +
|
| 175 | + Each block is prefixed by a file path comment to indicate its origin. |
| 176 | + This representation is syntactically valid Python code. |
| 177 | +
|
| 178 | + Returns: |
| 179 | + str: The concatenated code of all blocks with file path annotations. |
| 180 | +
|
| 181 | + !! Important !!: |
| 182 | + Avoid parsing the flat code with multiple files, |
| 183 | + parsing may result in unexpected behavior. |
| 184 | +
|
| 185 | +
|
| 186 | + """ |
| 187 | + if self._cache.get("flat") is not None: |
| 188 | + return self._cache["flat"] |
| 189 | + self._cache["flat"] = "\n".join( |
| 190 | + get_code_block_splitter(block.file_path) + "\n" + block.code for block in self.code_strings |
| 191 | + ) |
| 192 | + return self._cache["flat"] |
162 | 193 |
|
163 | 194 | @property |
164 | 195 | def markdown(self) -> str: |
165 | | - """Returns the markdown representation of the code, including the file path where possible.""" |
| 196 | + """Returns a Markdown-formatted string containing all code blocks. |
| 197 | +
|
| 198 | + Each block is enclosed in a triple-backtick code block with an optional |
| 199 | + file path suffix (e.g., ```python:filename.py). |
| 200 | +
|
| 201 | + Returns: |
| 202 | + str: Markdown representation of the code blocks. |
| 203 | +
|
| 204 | + """ |
166 | 205 | return "\n".join( |
167 | 206 | [ |
168 | 207 | f"```python{':' + str(code_string.file_path) if code_string.file_path else ''}\n{code_string.code.strip()}\n```" |
169 | 208 | for code_string in self.code_strings |
170 | 209 | ] |
171 | 210 | ) |
172 | 211 |
|
| 212 | + def file_to_path(self) -> dict[str, str]: |
| 213 | + """Return a dictionary mapping file paths to their corresponding code blocks. |
| 214 | +
|
| 215 | + Returns: |
| 216 | + dict[str, str]: Mapping from file path (as string) to code. |
| 217 | +
|
| 218 | + """ |
| 219 | + if self._cache.get("file_to_path") is not None: |
| 220 | + return self._cache["file_to_path"] |
| 221 | + self._cache["file_to_path"] = { |
| 222 | + str(code_string.file_path): code_string.code for code_string in self.code_strings |
| 223 | + } |
| 224 | + return self._cache["file_to_path"] |
| 225 | + |
| 226 | + @staticmethod |
| 227 | + def parse_markdown_code(markdown_code: str) -> CodeStringsMarkdown: |
| 228 | + """Parse a Markdown string into a CodeStringsMarkdown object. |
| 229 | +
|
| 230 | + Extracts code blocks and their associated file paths and constructs a new CodeStringsMarkdown instance. |
| 231 | +
|
| 232 | + Args: |
| 233 | + markdown_code (str): The Markdown-formatted string to parse. |
| 234 | +
|
| 235 | + Returns: |
| 236 | + CodeStringsMarkdown: Parsed object containing code blocks. |
| 237 | +
|
| 238 | + """ |
| 239 | + matches = markdown_pattern.findall(markdown_code) |
| 240 | + results = CodeStringsMarkdown() |
| 241 | + for file_path, code in matches: |
| 242 | + path = file_path.strip() |
| 243 | + results.code_strings.append(CodeString(code=code, file_path=Path(path))) |
| 244 | + return results |
| 245 | + |
173 | 246 |
|
174 | 247 | class CodeOptimizationContext(BaseModel): |
175 | 248 | testgen_context_code: str = "" |
176 | | - read_writable_code: str = Field(min_length=1) |
| 249 | + read_writable_code: CodeStringsMarkdown |
177 | 250 | read_only_context_code: str = "" |
178 | 251 | hashing_code_context: str = "" |
179 | 252 | hashing_code_context_hash: str = "" |
@@ -272,7 +345,7 @@ class TestsInFile: |
272 | 345 |
|
273 | 346 | @dataclass(frozen=True) |
274 | 347 | class OptimizedCandidate: |
275 | | - source_code: str |
| 348 | + source_code: CodeStringsMarkdown |
276 | 349 | explanation: str |
277 | 350 | optimization_id: str |
278 | 351 |
|
|
0 commit comments