Skip to content

Commit 5c60353

Browse files
committed
safe changes
1 parent 0983ae4 commit 5c60353

File tree

10 files changed

+67
-40
lines changed

10 files changed

+67
-40
lines changed

codeflash/cli_cmds/cmd_init.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -141,12 +141,11 @@ def should_modify_pyproject_toml() -> bool:
141141
if "tests_root" not in config or config["tests_root"] is None or not Path(config["tests_root"]).is_dir():
142142
return True
143143

144-
create_toml = Confirm.ask(
144+
return Confirm.ask(
145145
"✅ A valid Codeflash config already exists in this project. Do you want to re-configure it?",
146146
default=False,
147147
show_default=True,
148148
)
149-
return create_toml
150149

151150

152151
def collect_setup_info() -> SetupInfo:

codeflash/code_utils/code_extractor.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from __future__ import annotations
22

33
import ast
4-
from pathlib import Path
54
from typing import TYPE_CHECKING
65

76
import libcst as cst
@@ -11,12 +10,15 @@
1110
from libcst.helpers import calculate_module_and_package
1211

1312
from codeflash.cli_cmds.console import logger
14-
from codeflash.models.models import FunctionParent, FunctionSource
13+
from codeflash.models.models import FunctionParent
1514

1615
if TYPE_CHECKING:
16+
from pathlib import Path
17+
1718
from libcst.helpers import ModuleNameAndPackage
1819

1920
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
21+
from codeflash.models.models import FunctionSource
2022

2123

2224
class FutureAliasedImportTransformer(cst.CSTTransformer):

codeflash/context/code_context_extractor.py

Lines changed: 31 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,15 @@
33
import os
44
from collections import defaultdict
55
from itertools import chain
6-
from pathlib import Path
6+
from typing import TYPE_CHECKING, Optional
77

88
import jedi
99
import libcst as cst
1010
import tiktoken
11-
from jedi.api.classes import Name
12-
from libcst import CSTNode
1311

1412
from codeflash.cli_cmds.console import logger
1513
from codeflash.code_utils.code_extractor import add_needed_imports_from_module, find_preexisting_objects
1614
from codeflash.code_utils.code_utils import get_qualified_name, path_belongs_to_site_packages
17-
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
1815
from codeflash.models.models import (
1916
CodeContextType,
2017
CodeOptimizationContext,
@@ -24,6 +21,14 @@
2421
)
2522
from codeflash.optimization.function_context import belongs_to_function_qualified
2623

24+
if TYPE_CHECKING:
25+
from pathlib import Path
26+
27+
from jedi.api.classes import Name
28+
from libcst import CSTNode
29+
30+
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
31+
2732

2833
def get_code_optimization_context(
2934
function_to_optimize: FunctionToOptimize,
@@ -75,7 +80,8 @@ def get_code_optimization_context(
7580
tokenizer = tiktoken.encoding_for_model("gpt-4o")
7681
final_read_writable_tokens = len(tokenizer.encode(final_read_writable_code))
7782
if final_read_writable_tokens > optim_token_limit:
78-
raise ValueError("Read-writable code has exceeded token limit, cannot proceed")
83+
msg = "Read-writable code has exceeded token limit, cannot proceed"
84+
raise ValueError(msg)
7985

8086
# Setup preexisting objects for code replacer
8187
preexisting_objects = set(
@@ -122,7 +128,8 @@ def get_code_optimization_context(
122128
testgen_context_code = testgen_code_markdown.code
123129
testgen_context_code_tokens = len(tokenizer.encode(testgen_context_code))
124130
if testgen_context_code_tokens > testgen_token_limit:
125-
raise ValueError("Testgen code context has exceeded token limit, cannot proceed")
131+
msg = "Testgen code context has exceeded token limit, cannot proceed"
132+
raise ValueError(msg)
126133

127134
return CodeOptimizationContext(
128135
testgen_context_code=testgen_context_code,
@@ -143,7 +150,7 @@ def extract_code_string_context_from_files(
143150
"""Extract code context from files containing target functions and their helpers.
144151
This function processes two sets of files:
145152
1. Files containing the function to optimize (fto) and their first-degree helpers
146-
2. Files containing only helpers of helpers (with no overlap with the first set)
153+
2. Files containing only helpers of helpers (with no overlap with the first set).
147154
148155
For each file, it extracts relevant code based on the specified context type, adds necessary
149156
imports, and combines them.
@@ -358,18 +365,18 @@ def get_function_to_optimize_as_function_source(
358365
and name.name == function_to_optimize.function_name
359366
and get_qualified_name(name.module_name, name.full_name) == function_to_optimize.qualified_name
360367
):
361-
function_source = FunctionSource(
368+
return FunctionSource(
362369
file_path=function_to_optimize.file_path,
363370
qualified_name=function_to_optimize.qualified_name,
364371
fully_qualified_name=name.full_name,
365372
only_function_name=name.name,
366373
source_code=name.get_line_code(),
367374
jedi_definition=name,
368375
)
369-
return function_source
370376

377+
msg = f"Could not find function {function_to_optimize.function_name} in {function_to_optimize.file_path}"
371378
raise ValueError(
372-
f"Could not find function {function_to_optimize.function_name} in {function_to_optimize.file_path}"
379+
msg
373380
)
374381

375382

@@ -436,7 +443,7 @@ def get_section_names(node: cst.CSTNode) -> list[str]:
436443

437444

438445
def remove_docstring_from_body(indented_block: cst.IndentedBlock) -> cst.CSTNode:
439-
"""Removes the docstring from an indented block if it exists"""
446+
"""Removes the docstring from an indented block if it exists."""
440447
if not isinstance(indented_block.body[0], cst.SimpleStatementLine):
441448
return indented_block
442449
first_stmt = indented_block.body[0].body[0]
@@ -449,10 +456,12 @@ def parse_code_and_prune_cst(
449456
code: str,
450457
code_context_type: CodeContextType,
451458
target_functions: set[str],
452-
helpers_of_helper_functions: set[str] = set(),
459+
helpers_of_helper_functions: Optional[set[str]] = None,
453460
remove_docstrings: bool = False,
454461
) -> str:
455462
"""Create a read-only version of the code by parsing and filtering the code to keep only class contextual information, and other module scoped variables."""
463+
if helpers_of_helper_functions is None:
464+
helpers_of_helper_functions = set()
456465
module = cst.parse_module(code)
457466
if code_context_type == CodeContextType.READ_WRITABLE:
458467
filtered_node, found_target = prune_cst_for_read_writable_code(module, target_functions)
@@ -465,10 +474,12 @@ def parse_code_and_prune_cst(
465474
module, target_functions, helpers_of_helper_functions, remove_docstrings=remove_docstrings
466475
)
467476
else:
468-
raise ValueError(f"Unknown code_context_type: {code_context_type}")
477+
msg = f"Unknown code_context_type: {code_context_type}"
478+
raise ValueError(msg)
469479

470480
if not found_target:
471-
raise ValueError("No target functions found in the provided code")
481+
msg = "No target functions found in the provided code"
482+
raise ValueError(msg)
472483
if filtered_node and isinstance(filtered_node, cst.Module):
473484
return str(filtered_node.code)
474485
return ""
@@ -500,7 +511,8 @@ def prune_cst_for_read_writable_code(
500511
return None, False
501512
# Assuming always an IndentedBlock
502513
if not isinstance(node.body, cst.IndentedBlock):
503-
raise ValueError("ClassDef body is not an IndentedBlock")
514+
msg = "ClassDef body is not an IndentedBlock"
515+
raise ValueError(msg)
504516
class_prefix = f"{prefix}.{node.name.value}" if prefix else node.name.value
505517
new_body = []
506518
found_target = False
@@ -593,7 +605,8 @@ def prune_cst_for_read_only_code(
593605
return None, False
594606
# Assuming always an IndentedBlock
595607
if not isinstance(node.body, cst.IndentedBlock):
596-
raise ValueError("ClassDef body is not an IndentedBlock")
608+
msg = "ClassDef body is not an IndentedBlock"
609+
raise ValueError(msg)
597610

598611
class_prefix = f"{prefix}.{node.name.value}" if prefix else node.name.value
599612

@@ -698,7 +711,8 @@ def prune_cst_for_testgen_code(
698711
return None, False
699712
# Assuming always an IndentedBlock
700713
if not isinstance(node.body, cst.IndentedBlock):
701-
raise ValueError("ClassDef body is not an IndentedBlock")
714+
msg = "ClassDef body is not an IndentedBlock"
715+
raise ValueError(msg)
702716

703717
class_prefix = f"{prefix}.{node.name.value}" if prefix else node.name.value
704718

codeflash/discovery/pytest_new_process_discovery.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ def pytest_collection_finish(self, session) -> None:
1616
collected_tests.extend(session.items)
1717
pytest_rootdir = session.config.rootdir
1818

19-
def pytest_collection_modifyitems(config, items):
19+
def pytest_collection_modifyitems(self, items) -> None:
2020
skip_benchmark = pytest.mark.skip(reason="Skipping benchmark tests")
2121
for item in items:
2222
if "benchmark" in item.fixturenames:

codeflash/models/models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import enum
1212
import re
1313
import sys
14-
from collections.abc import Collection, Iterator
14+
from collections.abc import Collection
1515
from enum import Enum, IntEnum
1616
from pathlib import Path
1717
from re import Pattern

codeflash/optimization/function_context.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
from __future__ import annotations
22

3-
from jedi.api.classes import Name
3+
from typing import TYPE_CHECKING
44

55
from codeflash.code_utils.code_utils import get_qualified_name
66

7+
if TYPE_CHECKING:
8+
from jedi.api.classes import Name
9+
710

811
def belongs_to_method(name: Name, class_name: str, method_name: str) -> bool:
912
"""Check if the given name belongs to the specified method."""
@@ -14,9 +17,8 @@ def belongs_to_function(name: Name, function_name: str) -> bool:
1417
"""Check if the given jedi Name is a direct child of the specified function."""
1518
if name.name == function_name: # Handles function definition and recursive function calls
1619
return False
17-
if name := name.parent():
18-
if name.type == "function":
19-
return name.name == function_name
20+
if (name := name.parent()) and name.type == "function":
21+
return name.name == function_name
2022
return False
2123

2224

@@ -34,9 +36,8 @@ def belongs_to_function_qualified(name: Name, qualified_function_name: str) -> b
3436
if get_qualified_name(name.module_name, name.full_name) == qualified_function_name:
3537
# Handles function definition and recursive function calls
3638
return False
37-
if name := name.parent():
38-
if name.type == "function":
39-
return get_qualified_name(name.module_name, name.full_name) == qualified_function_name
39+
if (name := name.parent()) and name.type == "function":
40+
return get_qualified_name(name.module_name, name.full_name) == qualified_function_name
4041
return False
4142
except ValueError:
4243
return False

codeflash/optimization/function_optimizer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@
4141
from codeflash.code_utils.static_analysis import get_first_top_level_function_or_method_ast
4242
from codeflash.code_utils.time_utils import humanize_runtime
4343
from codeflash.context import code_context_extractor
44-
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
4544
from codeflash.either import Failure, Success, is_successful
4645
from codeflash.models.ExperimentMetadata import ExperimentMetadata
4746
from codeflash.models.models import (
@@ -74,6 +73,7 @@
7473
if TYPE_CHECKING:
7574
from argparse import Namespace
7675

76+
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
7777
from codeflash.either import Result
7878
from codeflash.models.models import CoverageData, FunctionSource, OptimizedCandidate
7979
from codeflash.verification.verification_utils import TestConfig

codeflash/tracing/replay_test.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,15 @@
22

33
import sqlite3
44
import textwrap
5-
from collections.abc import Generator
6-
from typing import Any, Optional
5+
from typing import TYPE_CHECKING, Any, Optional
76

8-
from codeflash.discovery.functions_to_optimize import FunctionProperties, inspect_top_level_functions_or_methods
9-
from codeflash.tracing.tracing_utils import FunctionModules
7+
from codeflash.discovery.functions_to_optimize import inspect_top_level_functions_or_methods
8+
9+
if TYPE_CHECKING:
10+
from collections.abc import Generator
11+
12+
from codeflash.discovery.functions_to_optimize import FunctionProperties
13+
from codeflash.tracing.tracing_utils import FunctionModules
1014

1115

1216
def get_next_arg_and_return(

codeflash/verification/concolic_testing.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,23 @@
33
import ast
44
import subprocess
55
import tempfile
6-
from argparse import Namespace
76
from pathlib import Path
7+
from typing import TYPE_CHECKING
88

99
from codeflash.cli_cmds.console import console, logger
1010
from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE
1111
from codeflash.code_utils.concolic_utils import clean_concolic_tests
1212
from codeflash.code_utils.static_analysis import has_typed_parameters
1313
from codeflash.discovery.discover_unit_tests import discover_unit_tests
14-
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
15-
from codeflash.models.models import FunctionCalledInTest
1614
from codeflash.telemetry.posthog_cf import ph
1715
from codeflash.verification.verification_utils import TestConfig
1816

17+
if TYPE_CHECKING:
18+
from argparse import Namespace
19+
20+
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
21+
from codeflash.models.models import FunctionCalledInTest
22+
1923

2024
def generate_concolic_tests(
2125
test_cfg: TestConfig, args: Namespace, function_to_optimize: FunctionToOptimize, function_to_optimize_ast: ast.AST

codeflash/verification/instrument_codeflash_capture.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,14 @@
22

33
import ast
44
from pathlib import Path
5+
from typing import TYPE_CHECKING
56

67
import isort
78

89
from codeflash.code_utils.code_utils import get_run_tmp_file
9-
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
10+
11+
if TYPE_CHECKING:
12+
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
1013

1114

1215
def instrument_codeflash_capture(
@@ -109,7 +112,7 @@ def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
109112
func=ast.Name(id="codeflash_capture", ctx=ast.Load()),
110113
args=[],
111114
keywords=[
112-
ast.keyword(arg="function_name", value=ast.Constant(value=".".join([node.name, "__init__"]))),
115+
ast.keyword(arg="function_name", value=ast.Constant(value=f"{node.name}.__init__")),
113116
ast.keyword(arg="tmp_dir_path", value=ast.Constant(value=self.tmp_dir_path)),
114117
ast.keyword(arg="tests_root", value=ast.Constant(value=str(self.tests_root))),
115118
ast.keyword(arg="is_fto", value=ast.Constant(value=self.is_fto)),

0 commit comments

Comments
 (0)