class LazyLLM:
    """Proxy that defers building the chat model until it is first needed.

    Only the configuration is stored at construction time. The underlying
    model is built through ``LLMServiceFactory.build`` the first time the
    ``llm`` property is read; attributes that are not found on the wrapper
    itself are looked up on that model.
    """

    # Names that belong to the wrapper itself and must never be delegated.
    _OWN_ATTRIBUTES = frozenset({"config", "_llm", "llm"})

    def __init__(self, config: LLMConfig):
        """Remember *config*; the model itself is not built here."""
        self.config = config
        self._llm: Optional[BaseChatModel] = None

    @property
    def llm(self) -> BaseChatModel:
        """Return the underlying chat model, building it on first access.

        Exceptions raised by ``LLMServiceFactory.build`` propagate to the
        caller and leave ``_llm`` unset, so the next access tries again.
        """
        if self._llm is None:
            log.info("Initializing LLM for the first time: provider=%s, model=%s",
                     self.config.name, self.config.model)
            self._llm = LLMServiceFactory.build(self.config)
        return self._llm

    def __getattr__(self, name):  # type: ignore[no-untyped-def]
        """Delegate unknown attribute lookups to the lazily built model.

        Raises:
            AttributeError: if *name* is one of the wrapper's own attributes
                or a dunder name, or if the model does not provide it.
        """
        # __getattr__ only runs after normal lookup has failed. Reaching it
        # for one of our own names means the instance is not (fully)
        # initialized or the ``llm`` property raised AttributeError;
        # delegating would re-enter the property and recurse without end.
        # Dunder probes (e.g. from copy/pickle) are refused as well so that
        # they cannot trigger an unintended model build.
        is_dunder = name.startswith("__") and name.endswith("__")
        if is_dunder or name in self._OWN_ATTRIBUTES:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            )
        return getattr(self.llm, name)
@classmethod
def build_lazy(cls, config: LLMConfig) -> LazyLLM:
    """Return a lazy wrapper for *config*, reusing the cached one when it matches.

    The factory keeps a single wrapper in ``_cached_lazy_llm``. It is handed
    back only when it was created for an equal configuration; a different
    configuration replaces it, so a caller never receives a wrapper bound to
    another provider or model.

    Args:
        config: LLM configuration the wrapper should build its model from.

    Returns:
        A ``LazyLLM`` whose model has not necessarily been built yet.
    """
    cached = cls._cached_lazy_llm
    if cached is not None and cached.config == config:
        log.debug("Reusing cached lazy LLM wrapper for: provider=%s, model=%s", config.name, config.model)
        return cached

    log.debug("Creating lazy LLM wrapper for: provider=%s, model=%s", config.name, config.model)
    fresh = LazyLLM(config)
    cls._cached_lazy_llm = fresh
    return fresh
@pytest.fixture
def example_doc():
    """Build the sample Document used by the tests in this module."""
    sample_fields = {
        "title": "Car of the Year",
        "description": "The Toyota Camry, the nation's most popular car has now been rated as its best new model.",
    }
    return Document(id="doc1", fields=sample_fields)


@pytest.fixture
def query():
    """Return the sample search query string."""
    question = "Is a Toyota the car of the year?"
    return question
@pytest.fixture(autouse=True)
def _fresh_lazy_llm_cache(monkeypatch):
    """Start every test with an empty factory cache.

    ``LLMServiceFactory`` keeps its lazy wrapper in a class attribute, so
    without this reset a wrapper created by one test would be handed to the
    next one regardless of the configuration that test passes in.
    monkeypatch restores the previous value when the test finishes.
    """
    monkeypatch.setattr(LLMServiceFactory, "_cached_lazy_llm", None)


def test_llm_factory_lazy__expected__llm_none():
    """A freshly created lazy wrapper has not built its model yet."""
    cfg = LLMConfig(
        name="openai",
        model="mock_model",
        max_tokens=1024,
        api_key_env="mock_api_key",
    )
    llm: LazyLLM = LLMServiceFactory.build_lazy(cfg)
    assert llm._llm is None


def test_llm_factory_invalid_model_name__expected__validation_error():
    """Constructing the config with these values is rejected by validation."""
    # NOTE(review): presumably the unsupported provider name "mock_provider"
    # triggers the error - confirm against the LLMConfig validators.
    with pytest.raises(ValidationError):
        _ = LLMConfig(
            name="mock_provider",
            model="mock_model",
            max_tokens=1024,
            api_key_env="mock_api_key",
        )


@pytest.mark.parametrize("provider, model", [
    ("openai", "gpt-5-nano-2025-08-07"),
    ("gemini", "gemini-3-pro-preview"),
])
def test_llm_factory_lazy_openai__expected__api_key_not_valid(example_doc, query, provider, model):
    """Scoring through a lazy wrapper with an unusable API key env raises ValueError."""
    cfg = LLMConfig(
        name=provider,
        model=model,
        max_tokens=1024,
        api_key_env="invalid_api_key",
    )
    llm: LazyLLM = LLMServiceFactory.build_lazy(cfg)

    service: LLMService = LLMService(chat_model=llm)
    with pytest.raises(ValueError):
        _ = service.generate_score(example_doc, query, relevance_scale='binary')
"WriterFactory", types.SimpleNamespace(build=lambda _cfg: DummyWriter())) # No-op the heavy flow functions to keep the test focused on wiring