diff --git a/src/llm_search_quality_evaluation/dataset_generator/config.py b/src/llm_search_quality_evaluation/dataset_generator/config.py index 3cd84ed..055cce7 100644 --- a/src/llm_search_quality_evaluation/dataset_generator/config.py +++ b/src/llm_search_quality_evaluation/dataset_generator/config.py @@ -6,6 +6,7 @@ from urllib.parse import urljoin from llm_search_quality_evaluation.shared.writers import WriterConfig +from llm_search_quality_evaluation.shared.models.output_format import OutputFormat log = logging.getLogger(__name__) @@ -32,7 +33,7 @@ class Config(BaseModel): llm_configuration_file: FilePath = Field(..., description="Path to the LLM configuration file.") max_query_terms: Optional[int] = Field(None, gt=0, description="Max number of query terms in the LLM-generated " "query") - output_format: Literal['quepid', 'rre', 'mteb'] + output_format: OutputFormat output_destination: Path = Field(..., description="Path to save the output dataset.") save_llm_explanation: bool = False llm_explanation_destination: Optional[Path] = Field(None, description="Path to save the LLM rating explanation") @@ -126,11 +127,13 @@ def search_engine_collection_endpoint(self) -> HttpUrl: @model_validator(mode="after") def check_rre_fields_required(self) -> "Config": - if self.output_format == "rre" and not self.id_field: + if self.output_format != OutputFormat.RRE: + return self + if not self.id_field: raise ValueError("id_field is required when output_format='rre'") - elif self.output_format == "rre" and not self.rre_query_placeholder: + elif not self.rre_query_placeholder: raise ValueError("rre_query_placeholder is required when output_format='rre'") - elif self.output_format == "rre" and not self.rre_query_template and not self.query_template: + elif not self.rre_query_template and not self.query_template: raise ValueError("At least one query template is required when output_format='rre'") return self diff --git a/src/llm_search_quality_evaluation/dataset_generator/main.py b/src/llm_search_quality_evaluation/dataset_generator/main.py index 155e6b3..e5bad50 100644 --- a/src/llm_search_quality_evaluation/dataset_generator/main.py +++ b/src/llm_search_quality_evaluation/dataset_generator/main.py @@ -17,6 +17,7 @@ from llm_search_quality_evaluation.dataset_generator.llm import LLMConfig, LLMService, LLMServiceFactory from llm_search_quality_evaluation.shared.models import Document, Query from llm_search_quality_evaluation.shared.writers import WriterFactory, AbstractWriter, WriterConfig +from llm_search_quality_evaluation.shared.models.output_format import OutputFormat from llm_search_quality_evaluation.shared.search_engines import SearchEngineFactory, BaseSearchEngine from llm_search_quality_evaluation.shared.data_store import DataStore from llm_search_quality_evaluation.shared.utils import join_fields_as_text @@ -171,7 +172,7 @@ def main() -> None: # TODO: # work on a better solution, instead of overwriting the corpus.json file, and maybe modify the MtebWriter with the # fetch from the search engine - if config.output_format == "mteb": + if config.output_format == OutputFormat.MTEB: # copy pasted from MtebWriter corpus_path = Path(output_destination) / "corpus.jsonl" corpus_path.unlink(missing_ok=True) diff --git a/src/llm_search_quality_evaluation/shared/models/output_format.py b/src/llm_search_quality_evaluation/shared/models/output_format.py new file mode 100644 index 0000000..1dadff0 --- /dev/null +++ b/src/llm_search_quality_evaluation/shared/models/output_format.py @@ -0,0 +1,12 @@ +from enum import Enum + + +class OutputFormat(str, Enum): + """Supported output formats for dataset generation.""" + QUEPID = "quepid" + RRE = "rre" + MTEB = "mteb" + + def __str__(self) -> str: + return self.value + diff --git a/src/llm_search_quality_evaluation/shared/writers/__init__.py b/src/llm_search_quality_evaluation/shared/writers/__init__.py index 851560c..cbb1cac 100644 --- a/src/llm_search_quality_evaluation/shared/writers/__init__.py +++ b/src/llm_search_quality_evaluation/shared/writers/__init__.py @@ -4,6 +4,7 @@ from llm_search_quality_evaluation.shared.writers.rre_writer import RreWriter from llm_search_quality_evaluation.shared.writers.mteb_writer import MtebWriter from llm_search_quality_evaluation.shared.writers.writer_config import WriterConfig +from llm_search_quality_evaluation.shared.models.output_format import OutputFormat __all__ = [ "WriterFactory", @@ -11,5 +12,6 @@ "QuepidWriter", "RreWriter", "MtebWriter", - "WriterConfig" + "WriterConfig", + "OutputFormat" ] diff --git a/src/llm_search_quality_evaluation/shared/writers/writer_config.py b/src/llm_search_quality_evaluation/shared/writers/writer_config.py index 613d9ee..6b3332e 100644 --- a/src/llm_search_quality_evaluation/shared/writers/writer_config.py +++ b/src/llm_search_quality_evaluation/shared/writers/writer_config.py @@ -1,12 +1,13 @@ -from typing import Optional, Literal +from typing import Optional import logging from pydantic import BaseModel, Field +from llm_search_quality_evaluation.shared.models.output_format import OutputFormat log = logging.getLogger(__name__) class WriterConfig(BaseModel): - output_format: Literal['quepid', 'rre', 'mteb'] + output_format: OutputFormat index: str = Field(..., description="Name of the index/collection of the search engine") id_field: Optional[str] = Field(None, description="ID field for the unique key.") query_template: Optional[str] = Field(None, description="Query template for rre evaluator.") diff --git a/src/llm_search_quality_evaluation/shared/writers/writer_factory.py b/src/llm_search_quality_evaluation/shared/writers/writer_factory.py index f6627f6..1560612 100644 --- a/src/llm_search_quality_evaluation/shared/writers/writer_factory.py +++ b/src/llm_search_quality_evaluation/shared/writers/writer_factory.py @@ -3,6 +3,7 @@ from llm_search_quality_evaluation.shared.writers.quepid_writer import QuepidWriter from llm_search_quality_evaluation.shared.writers.rre_writer import RreWriter from llm_search_quality_evaluation.shared.writers.writer_config import WriterConfig +from llm_search_quality_evaluation.shared.models.output_format import OutputFormat from typing import Mapping, Type, TypeAlias import logging @@ -12,15 +13,15 @@ WriterType: TypeAlias = Type[AbstractWriter] class WriterFactory: - OUTPUT_FORMAT_REGISTRY: Mapping[str, WriterType] = { - "quepid": QuepidWriter, - "rre": RreWriter, - "mteb": MtebWriter, + OUTPUT_FORMAT_REGISTRY: Mapping[OutputFormat, WriterType] = { + OutputFormat.QUEPID: QuepidWriter, + OutputFormat.RRE: RreWriter, + OutputFormat.MTEB: MtebWriter, } @classmethod def build(cls, writer_config: WriterConfig) -> AbstractWriter: - output_format: str = writer_config.output_format + output_format: OutputFormat = writer_config.output_format if output_format not in cls.OUTPUT_FORMAT_REGISTRY: log.error(f"Unsupported output format requested: {output_format}") raise ValueError(f"Unsupported output format: {output_format}") diff --git a/src/llm_search_quality_evaluation/vector_search_doctor/approximate_search_evaluator/main.py b/src/llm_search_quality_evaluation/vector_search_doctor/approximate_search_evaluator/main.py index 61a247f..51d3521 100644 --- a/src/llm_search_quality_evaluation/vector_search_doctor/approximate_search_evaluator/main.py +++ b/src/llm_search_quality_evaluation/vector_search_doctor/approximate_search_evaluator/main.py @@ -10,6 +10,7 @@ from llm_search_quality_evaluation.shared.data_store import DataStore from llm_search_quality_evaluation.shared.logger import setup_logging from llm_search_quality_evaluation.shared.writers import WriterConfig +from llm_search_quality_evaluation.shared.models.output_format import OutputFormat from llm_search_quality_evaluation.vector_search_doctor.approximate_search_evaluator.config import Config log = logging.getLogger(__name__) @@ -157,7 +158,7 @@ def main() -> None: id_field=config.id_field, query_template=config.query_template.name, query_placeholder=config.query_placeholder if config.query_placeholder is not None else "$query", - output_format='rre' + output_format=OutputFormat.RRE ) ) writer.write(ratings_folder, data_store) diff --git a/tests/llm_search_quality_evaluation/dataset_generator/test_config_dataset_generator.py b/tests/llm_search_quality_evaluation/dataset_generator/test_config_dataset_generator.py index ab89b88..1e3ba95 100644 --- a/tests/llm_search_quality_evaluation/dataset_generator/test_config_dataset_generator.py +++ b/tests/llm_search_quality_evaluation/dataset_generator/test_config_dataset_generator.py @@ -4,6 +4,7 @@ import pytest from llm_search_quality_evaluation.dataset_generator.config import Config +from llm_search_quality_evaluation.shared.models.output_format import OutputFormat @pytest.fixture @@ -28,7 +29,7 @@ def test_good_config__expects__all_parameters_read(config): assert config.num_queries_needed == 10 assert config.relevance_scale == "graded" assert config.llm_configuration_file == FilePath("tests/resources/llm_config.yaml") - assert config.output_format == "quepid" + assert config.output_format == OutputFormat.QUEPID assert config.output_destination == Path("output") assert config.save_llm_explanation is True assert config.llm_explanation_destination == Path("output/rating_explanation.json") @@ -74,7 +75,7 @@ def test__expects__raises_file_not_found_error(resource_folder): def test_mteb_config__expects__successful_load(resource_folder): file_name = "mteb_config.yaml" mteb_config = Config.load(resource_folder / file_name) - assert mteb_config.output_format == "mteb" + assert mteb_config.output_format == OutputFormat.MTEB assert mteb_config.output_destination == Path("output") def test_missing_both_templates_with_rre__expects__raises_validation_error(resource_folder): diff --git a/tests/llm_search_quality_evaluation/dataset_generator/test_main_autosave.py b/tests/llm_search_quality_evaluation/dataset_generator/test_main_autosave.py index de0bf5c..ddf268c 100644 --- a/tests/llm_search_quality_evaluation/dataset_generator/test_main_autosave.py +++ b/tests/llm_search_quality_evaluation/dataset_generator/test_main_autosave.py @@ -4,6 +4,7 @@ from llm_search_quality_evaluation.dataset_generator import main as main_mod from llm_search_quality_evaluation.dataset_generator.config import Config +from llm_search_quality_evaluation.shared.models.output_format import OutputFormat class DummyWriter: @@ -35,7 +36,7 @@ def test_main_passes_autosave_option_to_datastore(monkeypatch, tmp_path: Path): num_queries_needed=1, relevance_scale="graded", llm_configuration_file=llm_cfg, - output_format="quepid", + output_format=OutputFormat.QUEPID, output_destination=tmp_path, save_llm_explanation=False, llm_explanation_destination=None, diff --git a/tests/llm_search_quality_evaluation/shared/writers/test_mteb_writer.py b/tests/llm_search_quality_evaluation/shared/writers/test_mteb_writer.py index ca610f4..406b966 100644 --- a/tests/llm_search_quality_evaluation/shared/writers/test_mteb_writer.py +++ b/tests/llm_search_quality_evaluation/shared/writers/test_mteb_writer.py @@ -6,13 +6,14 @@ from llm_search_quality_evaluation.shared.data_store import DataStore from llm_search_quality_evaluation.shared.models import Document from llm_search_quality_evaluation.shared.writers.writer_config import WriterConfig +from llm_search_quality_evaluation.shared.models.output_format import OutputFormat from llm_search_quality_evaluation.shared.writers.mteb_writer import MtebWriter @pytest.fixture def writer_config(): return WriterConfig( - output_format='mteb', + output_format=OutputFormat.MTEB, index='testcore' ) diff --git a/tests/llm_search_quality_evaluation/shared/writers/test_quepid_writer.py b/tests/llm_search_quality_evaluation/shared/writers/test_quepid_writer.py index 96f1450..0a076ed 100644 --- a/tests/llm_search_quality_evaluation/shared/writers/test_quepid_writer.py +++ b/tests/llm_search_quality_evaluation/shared/writers/test_quepid_writer.py @@ -4,6 +4,7 @@ from llm_search_quality_evaluation.shared.data_store import DataStore from llm_search_quality_evaluation.shared.writers.writer_config import WriterConfig +from llm_search_quality_evaluation.shared.models.output_format import OutputFormat from llm_search_quality_evaluation.shared.writers.quepid_writer import QuepidWriter, QUEPID_OUTPUT_FILENAME from llm_search_quality_evaluation.shared.models import Document @@ -11,7 +12,7 @@ @pytest.fixture def writer_config(): return WriterConfig( - output_format='quepid', + output_format=OutputFormat.QUEPID, index='testcore' ) diff --git a/tests/llm_search_quality_evaluation/shared/writers/test_rre_writer.py b/tests/llm_search_quality_evaluation/shared/writers/test_rre_writer.py index 1d41414..4d21d63 100644 --- a/tests/llm_search_quality_evaluation/shared/writers/test_rre_writer.py +++ b/tests/llm_search_quality_evaluation/shared/writers/test_rre_writer.py @@ -6,13 +6,14 @@ from llm_search_quality_evaluation.shared.data_store import DataStore from llm_search_quality_evaluation.shared.models import Query, Document from llm_search_quality_evaluation.shared.writers.writer_config import WriterConfig +from llm_search_quality_evaluation.shared.models.output_format import OutputFormat from llm_search_quality_evaluation.shared.writers.rre_writer import RreWriter, RRE_OUTPUT_FILENAME @pytest.fixture def writer_config(): return WriterConfig( - output_format='rre', + output_format=OutputFormat.RRE, index='testcore', id_field='id', query_template='only_q.json', diff --git a/tests/llm_search_quality_evaluation/test_cross_plataform.py b/tests/llm_search_quality_evaluation/test_cross_plataform.py index 13314df..0ce06b5 100644 --- a/tests/llm_search_quality_evaluation/test_cross_plataform.py +++ b/tests/llm_search_quality_evaluation/test_cross_plataform.py @@ -6,6 +6,7 @@ from llm_search_quality_evaluation.shared.data_store import DataStore from llm_search_quality_evaluation.shared.models import Document from llm_search_quality_evaluation.shared.writers.writer_config import WriterConfig +from llm_search_quality_evaluation.shared.models.output_format import OutputFormat from llm_search_quality_evaluation.shared.writers.quepid_writer import QuepidWriter from llm_search_quality_evaluation.vector_search_doctor.embedding_model_evaluator.embedding_writer import EmbeddingWriter from llm_search_quality_evaluation.vector_search_doctor.embedding_model_evaluator.constants import TASKS_NAME_MAPPING @@ -146,7 +147,7 @@ def test_writer_with_special_chars__expects__correctly_handles_specials(tmp_path # Using QuepidWriter as a representative example # TODO: expand tests to the rest of the writers - writer_cfg = WriterConfig(output_format="quepid", index="test") + writer_cfg = WriterConfig(output_format=OutputFormat.QUEPID, index="test") writer = QuepidWriter(writer_cfg) writer.write(tmp_path, ds)