Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 27 additions & 22 deletions codeflash/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
generate_candidates,
)
from codeflash.code_utils.env_utils import is_end_to_end
from codeflash.code_utils.time_utils import humanize_runtime
from codeflash.verification.test_results import TestResults, TestType

# If the method spam is in the class Ham, which is at the top level of the module eggs in the package foo, the fully
Expand Down Expand Up @@ -58,15 +57,19 @@ class FunctionSource:
def __eq__(self, other: object) -> bool:
if not isinstance(other, FunctionSource):
return False
return (self.file_path == other.file_path and
self.qualified_name == other.qualified_name and
self.fully_qualified_name == other.fully_qualified_name and
self.only_function_name == other.only_function_name and
self.source_code == other.source_code)
return (
self.file_path == other.file_path
and self.qualified_name == other.qualified_name
and self.fully_qualified_name == other.fully_qualified_name
and self.only_function_name == other.only_function_name
and self.source_code == other.source_code
)

def __hash__(self) -> int:
return hash((self.file_path, self.qualified_name, self.fully_qualified_name,
self.only_function_name, self.source_code))
return hash(
(self.file_path, self.qualified_name, self.fully_qualified_name, self.only_function_name, self.source_code)
)


class BestOptimization(BaseModel):
candidate: OptimizedCandidate
Expand All @@ -76,7 +79,8 @@ class BestOptimization(BaseModel):
replay_performance_gain: Optional[float] = None
winning_behavioral_test_results: TestResults
winning_benchmarking_test_results: TestResults
winning_replay_benchmarking_test_results : Optional[TestResults] = None
winning_replay_benchmarking_test_results: Optional[TestResults] = None


@dataclass
class BenchmarkDetail:
Expand All @@ -94,13 +98,9 @@ def to_string(self) -> str:
)

def to_dict(self) -> dict[str, any]:
return {
"benchmark_name": self.benchmark_name,
"test_function": self.test_function,
"original_timing": self.original_timing,
"expected_new_timing": self.expected_new_timing,
"speedup_percent": self.speedup_percent
}
# Utilizing Pydantic's built-in `.dict()` method for efficient serialization
return self.__dict__


@dataclass
class ProcessedBenchmarkInfo:
Expand All @@ -116,9 +116,9 @@ def to_string(self) -> str:
return result

def to_dict(self) -> dict[str, list[dict[str, any]]]:
return {
"benchmark_details": [detail.to_dict() for detail in self.benchmark_details]
}
return {"benchmark_details": [detail.to_dict() for detail in self.benchmark_details]}


class CodeString(BaseModel):
code: Annotated[str, AfterValidator(validate_python_code)]
file_path: Optional[Path] = None
Expand All @@ -143,7 +143,8 @@ class CodeOptimizationContext(BaseModel):
read_writable_code: str = Field(min_length=1)
read_only_context_code: str = ""
helper_functions: list[FunctionSource]
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]]
preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]]


class CodeContextType(str, Enum):
READ_WRITABLE = "READ_WRITABLE"
Expand Down Expand Up @@ -287,13 +288,17 @@ class CoverageData:

@staticmethod
def load_from_sqlite_database(
database_path: Path, config_path: Path, function_name: str, code_context: CodeOptimizationContext, source_code_path: Path
database_path: Path,
config_path: Path,
function_name: str,
code_context: CodeOptimizationContext,
source_code_path: Path,
) -> CoverageData:
"""Load coverage data from an SQLite database, mimicking the behavior of load_from_coverage_file."""
from coverage import Coverage
from coverage.jsonreport import JsonReporter

cov = Coverage(data_file=database_path,config_file=config_path, data_suffix=True, auto_data=True, branch=True)
cov = Coverage(data_file=database_path, config_file=config_path, data_suffix=True, auto_data=True, branch=True)

if not database_path.stat().st_size or not database_path.exists():
logger.debug(f"Coverage database {database_path} is empty or does not exist")
Expand Down
Loading