Skip to content

Commit f1ba78e

Browse files
fix for tests
1 parent ef6d6ab commit f1ba78e

File tree

3 files changed

+43
-12
lines changed

3 files changed

+43
-12
lines changed

codeflash/api/aiservice.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,21 @@
3232

3333
class AiServiceClient:
3434
def __init__(self) -> None:
35-
# Validate API key before initializing the client
35+
# API key validation is deferred to first use (in the headers property)
36+
self.base_url = self.get_aiservice_base_url()
37+
38+
def get_aiservice_base_url(self) -> str:
39+
if os.environ.get("CODEFLASH_AIS_SERVER", default="prod").lower() == "local":
40+
logger.info("Using local AI Service at http://localhost:8000")
41+
console.rule()
42+
return "http://localhost:8000"
43+
return "https://app.codeflash.ai"
44+
45+
def validate_api_key(self) -> None:
46+
"""Validate API key. Raises OSError if invalid."""
47+
if hasattr(self, '_api_key_validated') and self._api_key_validated:
48+
return
49+
3650
try:
3751
from codeflash.api.cfapi import get_user_id # noqa: PLC0415
3852

@@ -53,15 +67,14 @@ def __init__(self) -> None:
5367
# If cfapi is not available, skip validation
5468
pass
5569

56-
self.base_url = self.get_aiservice_base_url()
57-
self.headers = {"Authorization": f"Bearer {get_codeflash_api_key()}", "Connection": "close"}
58-
59-
def get_aiservice_base_url(self) -> str:
60-
if os.environ.get("CODEFLASH_AIS_SERVER", default="prod").lower() == "local":
61-
logger.info("Using local AI Service at http://localhost:8000")
62-
console.rule()
63-
return "http://localhost:8000"
64-
return "https://app.codeflash.ai"
70+
self._api_key_validated = True
71+
72+
@property
73+
def headers(self) -> dict[str, str]:
74+
"""Get headers with API key. Validates API key on first access."""
75+
# Lazily validate API key on first use
76+
self.validate_api_key()
77+
return {"Authorization": f"Bearer {get_codeflash_api_key()}", "Connection": "close"}
6578

6679
def make_ai_service_request(
6780
self,

codeflash/optimization/function_optimizer.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ def __init__(
213213
) -> None:
214214
self.project_root = test_cfg.project_root_path
215215
self.test_cfg = test_cfg
216-
self.aiservice_client = aiservice_client if aiservice_client else AiServiceClient()
216+
self._aiservice_client = aiservice_client # Can be None for lazy initialization
217217
self.function_to_optimize = function_to_optimize
218218
self.function_to_optimize_source_code = (
219219
function_to_optimize_source_code
@@ -247,6 +247,13 @@ def __init__(
247247
max_workers=n_tests + 2 if self.experiment_id is None else n_tests + 3
248248
)
249249

250+
@property
251+
def aiservice_client(self) -> AiServiceClient:
252+
"""Lazy initialization of AiServiceClient to delay API key validation."""
253+
if self._aiservice_client is None:
254+
self._aiservice_client = AiServiceClient()
255+
return self._aiservice_client
256+
250257
def can_be_optimized(self) -> Result[tuple[bool, CodeOptimizationContext, dict[Path, str]], str]:
251258
should_run_experiment = self.experiment_id is not None
252259
logger.debug(f"Function Trace ID: {self.function_trace_id}")

codeflash/optimization/optimizer.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def __init__(self, args: Namespace) -> None:
5050
benchmark_tests_root=args.benchmarks_root if "benchmark" in args and "benchmarks_root" in args else None,
5151
)
5252

53-
self.aiservice_client = AiServiceClient()
53+
self._aiservice_client: AiServiceClient | None = None
5454
self.experiment_id = os.getenv("CODEFLASH_EXPERIMENT_ID", None)
5555
self.local_aiservice_client = LocalAiServiceClient() if self.experiment_id else None
5656
self.replay_tests_dir = None
@@ -61,6 +61,13 @@ def __init__(self, args: Namespace) -> None:
6161
self.original_args_and_test_cfg: tuple[Namespace, TestConfig] | None = None
6262
self.patch_files: list[Path] = []
6363

64+
@property
65+
def aiservice_client(self) -> AiServiceClient:
66+
"""Lazy initialization of AiServiceClient to delay API key validation."""
67+
if self._aiservice_client is None:
68+
self._aiservice_client = AiServiceClient()
69+
return self._aiservice_client
70+
6471
def run_benchmarks(
6572
self, file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]], num_optimizable_functions: int
6673
) -> tuple[dict[str, dict[BenchmarkKey, float]], dict[BenchmarkKey, float]]:
@@ -261,6 +268,10 @@ def run(self) -> None:
261268
console.rule()
262269
if not env_utils.ensure_codeflash_api_key():
263270
return
271+
272+
# Validate API key before starting optimization
273+
self.aiservice_client.validate_api_key()
274+
264275
if self.args.no_draft and is_pr_draft():
265276
logger.warning("PR is in draft mode, skipping optimization")
266277
return

0 commit comments

Comments
 (0)