Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
llm_provider_factory.py
Provides a simple Factory for creating LangChain ChatModel instances
and currently only 2 LLMs - openai and gemini are supported in the factory.
with lazy initialization for the 2 currently supported LLMs - openai and gemini.
"""

import logging
import os
from typing import Optional

from dotenv import load_dotenv
from langchain_core.language_models import BaseChatModel
Expand Down Expand Up @@ -60,11 +61,29 @@ def build_gemini(config: LLMConfig) -> BaseChatModel:
)


class LazyLLM:
    """Proxy that defers construction of the underlying chat model.

    Holds on to an ``LLMConfig`` and only asks ``LLMServiceFactory.build``
    for a real ``BaseChatModel`` the first time the model is actually
    needed. Attribute access is forwarded to the real model, so callers
    can use a ``LazyLLM`` wherever a chat model is expected.
    """

    def __init__(self, config: LLMConfig):
        # Configuration is kept; the expensive model build is postponed.
        self.config = config
        self._llm: Optional[BaseChatModel] = None

    @property
    def llm(self) -> BaseChatModel:
        """Return the real chat model, building it on first access."""
        if self._llm is not None:
            return self._llm
        log.info("Initializing LLM for the first time: provider=%s, model=%s",
                 self.config.name, self.config.model)
        self._llm = LLMServiceFactory.build(self.config)
        return self._llm

    def __getattr__(self, name):  # type: ignore[no-untyped-def]
        # Any attribute not found on the wrapper is delegated to the
        # (lazily built) model — this triggers initialization on demand.
        return getattr(self.llm, name)


class LLMServiceFactory:
PROVIDER_REGISTRY = {
"openai": build_openai,
"gemini": build_gemini,
}
_cached_lazy_llm: Optional[LazyLLM] = None

@classmethod
def build(cls, config: LLMConfig) -> BaseChatModel:
Expand All @@ -73,5 +92,15 @@ def build(cls, config: LLMConfig) -> BaseChatModel:
if provider_name not in cls.PROVIDER_REGISTRY:
log.error("Unsupported LLM provider requested: %s", provider_name)
raise ValueError(f"Unsupported provider: {provider_name}")
log.info("Selected LLM provider=%s, model=%s", provider_name, provider_model)
log.info("Building LLM provider=%s, model=%s", provider_name, provider_model)
return cls.PROVIDER_REGISTRY[provider_name](config)

@classmethod
def build_lazy(cls, config: LLMConfig) -> LazyLLM:
if cls._cached_lazy_llm is None:
log.debug("Creating lazy LLM wrapper for: provider=%s, model=%s", config.name, config.model)
cls._cached_lazy_llm = LazyLLM(config)
else:
log.debug("Reusing cached lazy LLM wrapper for: provider=%s, model=%s", config.name, config.model)

return cls._cached_lazy_llm
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
import logging
from typing import Optional

from langchain_core.language_models import BaseChatModel
from langchain_core.messages import HumanMessage, SystemMessage
from pydantic import BaseModel, ValidationError

from llm_search_quality_evaluation.dataset_generator.llm.llm_provider_factory import LazyLLM
from llm_search_quality_evaluation.dataset_generator.models.query_response import LLMQueryResponse
from llm_search_quality_evaluation.dataset_generator.models.score_response import LLMScoreResponse
from llm_search_quality_evaluation.shared.models.document import Document
Expand All @@ -16,7 +16,7 @@


class LLMService:
def __init__(self, chat_model: BaseChatModel):
def __init__(self, chat_model: LazyLLM):
self.chat_model = chat_model

@staticmethod
Expand Down
5 changes: 3 additions & 2 deletions src/llm_search_quality_evaluation/dataset_generator/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@
# ------ temporary import for corpus.json bug workaround ------
import json
from pathlib import Path

from llm_search_quality_evaluation.dataset_generator.llm.llm_provider_factory import LazyLLM
from llm_search_quality_evaluation.shared.utils import _to_string
import argparse
# -------------------------------------------------------------

from typing import List
from langchain_core.language_models import BaseChatModel
from logging import Logger, getLogger

# project imports
Expand Down Expand Up @@ -138,7 +139,7 @@ def main() -> None:
search_engine_type=config.search_engine_type,
endpoint=config.search_engine_collection_endpoint
)
llm: BaseChatModel = LLMServiceFactory.build(LLMConfig.load(config.llm_configuration_file))
llm: LazyLLM = LLMServiceFactory.build_lazy(LLMConfig.load(config.llm_configuration_file))
service: LLMService = LLMService(chat_model=llm)
writer: AbstractWriter = WriterFactory.build(writer_config)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import pytest
from pydantic_core import ValidationError

from llm_search_quality_evaluation.dataset_generator.llm import LLMConfig, LLMService
from llm_search_quality_evaluation.dataset_generator.llm.llm_provider_factory import LazyLLM, LLMServiceFactory
from llm_search_quality_evaluation.shared.models import Document


@pytest.fixture
def example_doc():
    """Provides a sample Document object for testing."""
    fields = {
        "title": "Car of the Year",
        "description": "The Toyota Camry, the nation's most popular car has now been rated as its best new model.",
    }
    return Document(id="doc1", fields=fields)


@pytest.fixture
def query():
    """Provides a sample search query string for testing."""
    return "Is a Toyota the car of the year?"


def test_llm_factory_lazy__expected__llm_none():
    """build_lazy returns a wrapper whose underlying model is not yet built."""
    # build_lazy caches a single wrapper on the factory class; reset the
    # cache so this assertion does not depend on test execution order
    # (another test may already have built — and initialized — the wrapper).
    LLMServiceFactory._cached_lazy_llm = None
    cfg = LLMConfig(
        name="openai",
        model="mock_model",
        max_tokens=1024,
        api_key_env="mock_api_key",
    )
    llm: LazyLLM = LLMServiceFactory.build_lazy(cfg)
    assert llm._llm is None

def test_llm_factory_invalid_model_name__expected__validation_error():
    """An unsupported provider name must be rejected by LLMConfig validation."""
    bad_config = dict(
        name="mock_provider",
        model="mock_model",
        max_tokens=1024,
        api_key_env="mock_api_key",
    )
    with pytest.raises(ValidationError):
        _ = LLMConfig(**bad_config)



@pytest.mark.parametrize("provider, model", [
    ("openai", "gpt-5-nano-2025-08-07"),
    ("gemini", "gemini-3-pro-preview"),
])
def test_llm_factory_lazy_openai__expected__api_key_not_valid(example_doc, query, provider, model):
    """generate_score must fail when the configured API key env var is invalid."""
    # build_lazy caches a single wrapper on the factory class; without this
    # reset the second parametrization (gemini) would silently reuse the
    # wrapper built for the first (openai) and never exercise its own config.
    LLMServiceFactory._cached_lazy_llm = None
    cfg = LLMConfig(
        name=provider,
        model=model,
        max_tokens=1024,
        api_key_env="invalid_api_key",
    )
    llm: LazyLLM = LLMServiceFactory.build_lazy(cfg)

    service: LLMService = LLMService(chat_model=llm)
    with pytest.raises(ValueError):
        _ = service.generate_score(example_doc, query, relevance_scale='binary')
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def test_main_passes_autosave_option_to_datastore(monkeypatch, tmp_path: Path):
# Patch factories to avoid network / heavy dependencies
monkeypatch.setattr(main_mod, "SearchEngineFactory", types.SimpleNamespace(build=lambda **kwargs: object()))
monkeypatch.setattr(main_mod, "LLMConfig", types.SimpleNamespace(load=lambda _path: object()))
monkeypatch.setattr(main_mod, "LLMServiceFactory", types.SimpleNamespace(build=lambda _cfg: object()))
monkeypatch.setattr(main_mod, "LLMServiceFactory", types.SimpleNamespace(build_lazy=lambda _cfg: object()))
monkeypatch.setattr(main_mod, "WriterFactory", types.SimpleNamespace(build=lambda _cfg: DummyWriter()))

# No-op the heavy flow functions to keep the test focused on wiring
Expand Down