|
4 | 4 | import re |
5 | 5 | from collections import defaultdict |
6 | 6 | from functools import lru_cache |
7 | | -from typing import TYPE_CHECKING, Optional, TypeVar |
| 7 | +from typing import TYPE_CHECKING, List, Optional, TypeVar |
8 | 8 |
|
9 | 9 | import libcst as cst |
10 | 10 |
|
@@ -342,33 +342,35 @@ def function_to_optimize_original_worktree_fqn( |
342 | 342 | class AssertCleanup: |
343 | 343 | def transform_asserts(self, code: str) -> str: |
344 | 344 | lines = code.splitlines() |
345 | | - result_lines = [] |
| 345 | + result_lines: List[str] = [] |
| 346 | + |
| 347 | + append_result = result_lines.append |
| 348 | + transform_line = self._transform_assert_line |
346 | 349 |
|
347 | 350 | for line in lines: |
348 | | - transformed = self._transform_assert_line(line) |
| 351 | + transformed = transform_line(line) |
349 | 352 | if transformed is not None: |
350 | | - result_lines.append(transformed) |
| 353 | + append_result(transformed) |
351 | 354 | else: |
352 | | - result_lines.append(line) |
| 355 | + append_result(line) |
353 | 356 |
|
354 | 357 | return "\n".join(result_lines) |
355 | 358 |
|
356 | 359 | def _transform_assert_line(self, line: str) -> Optional[str]: |
357 | 360 | indent = line[: len(line) - len(line.lstrip())] |
358 | 361 |
|
359 | | - assert_match = re.match(r"\s*assert\s+(.*?)(?:\s*==\s*.*)?$", line) |
| 362 | + assert_match = self.assert_pattern.match(line) |
360 | 363 | if assert_match: |
361 | 364 | expression = assert_match.group(1).strip() |
362 | 365 | if expression.startswith("not "): |
363 | 366 | return f"{indent}{expression}" |
364 | 367 |
|
365 | | - expression = re.sub(r"[,;]\s*$", "", expression) |
| 368 | + expression = expression.rstrip(",;") |
366 | 369 | return f"{indent}{expression}" |
367 | 370 |
|
368 | | - unittest_match = re.match(r"(\s*)self\.assert([A-Za-z]+)\((.*)\)$", line) |
| 371 | + unittest_match = self.unittest_pattern.match(line) |
369 | 372 | if unittest_match: |
370 | | - indent, assert_method, args = unittest_match.groups() |
371 | | - |
| 373 | + indent, _, args = unittest_match.groups() |
372 | 374 | if args: |
373 | 375 | arg_parts = self._split_top_level_args(args) |
374 | 376 | if arg_parts and arg_parts[0]: |
@@ -399,6 +401,11 @@ def _split_top_level_args(self, args_str: str) -> list[str]: |
399 | 401 |
|
400 | 402 | return result |
401 | 403 |
|
| 404 | + def __init__(self): |
| 405 | + # Compile the regular expressions once |
| 406 | + self.assert_pattern = re.compile(r"\s*assert\s+(.*?)(?:\s*==\s*.*)?$") |
| 407 | + self.unittest_pattern = re.compile(r"(\s*)self\.assert([A-Za-z]+)\((.*)\)$") |
| 408 | + |
402 | 409 |
|
403 | 410 | def clean_concolic_tests(test_suite_code: str) -> str: |
404 | 411 | try: |
|
0 commit comments