Skip to content

Commit 3159233

Browse files
committed
fix test failures
1 parent 1296da2 commit 3159233

File tree

2 files changed

+10
-11
lines changed

2 files changed

+10
-11
lines changed

codeflash/benchmarking/plugin/plugin.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,24 +5,18 @@
55
import sqlite3
66
import sys
77
import time
8-
from dataclasses import dataclass
98
from pathlib import Path
9+
from typing import TYPE_CHECKING
1010

1111
import pytest
1212

1313
from codeflash.benchmarking.codeflash_trace import codeflash_trace
1414
from codeflash.code_utils.code_utils import module_name_from_file_path
1515

16-
IS_PYTEST_BENCHMARK_INSTALLED = importlib.util.find_spec("pytest_benchmark") is not None
17-
18-
19-
@dataclass(frozen=True)
20-
class BenchmarkKey:
21-
module_path: str
22-
function_name: str
16+
if TYPE_CHECKING:
17+
from codeflash.models.models import BenchmarkKey
2318

24-
def __str__(self) -> str:
25-
return f"{self.module_path}::{self.function_name}"
19+
IS_PYTEST_BENCHMARK_INSTALLED = importlib.util.find_spec("pytest_benchmark") is not None
2620

2721

2822
class CodeFlashBenchmarkPlugin:
@@ -83,6 +77,8 @@ def close(self) -> None:
8377

8478
@staticmethod
8579
def get_function_benchmark_timings(trace_path: Path) -> dict[str, dict[BenchmarkKey, int]]:
80+
from codeflash.models.models import BenchmarkKey
81+
8682
"""Process the trace file and extract timing data for all functions.
8783
8884
Args:
@@ -143,6 +139,8 @@ def get_function_benchmark_timings(trace_path: Path) -> dict[str, dict[Benchmark
143139

144140
@staticmethod
145141
def get_benchmark_timings(trace_path: Path) -> dict[BenchmarkKey, int]:
142+
from codeflash.models.models import BenchmarkKey
143+
146144
"""Extract total benchmark timings from trace files.
147145
148146
Args:

codeflash/models/models.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import re
1414
import sys
1515
from collections.abc import Collection
16+
from dataclasses import dataclass as dcdataclass
1617
from enum import Enum, IntEnum
1718
from pathlib import Path
1819
from re import Pattern
@@ -83,7 +84,7 @@ class BestOptimization(BaseModel):
8384
winning_replay_benchmarking_test_results: Optional[TestResults] = None
8485

8586

86-
@dataclass(frozen=True)
87+
@dcdataclass(frozen=True)
8788
class BenchmarkKey:
8889
module_path: str
8990
function_name: str

0 commit comments

Comments
 (0)