From d217f864d32eae02784d09aaf02c573b38e0d126 Mon Sep 17 00:00:00 2001 From: Kush Gupta Date: Tue, 10 Jun 2025 20:42:42 -0400 Subject: [PATCH 1/6] Add provider options to CLI instead of hardcoded WatsonX (#1) * starting to fix cli for multiple providers Signed-off-by: Kush Gupta * ensure file processing happens before temp is removed Signed-off-by: Kush Gupta * remove verbose comments Signed-off-by: Kush Gupta --------- Signed-off-by: Kush Gupta --- docling_sdg/cli/qa.py | 293 +++++++++++++++++++++++++---------------- docling_sdg/qa/base.py | 19 ++- 2 files changed, 192 insertions(+), 120 deletions(-) diff --git a/docling_sdg/cli/qa.py b/docling_sdg/cli/qa.py index ecf47da..5b35d1d 100644 --- a/docling_sdg/cli/qa.py +++ b/docling_sdg/cli/qa.py @@ -2,16 +2,16 @@ import os import tempfile from pathlib import Path -from typing import Annotated, Any, Optional, Type, Union +from typing import Annotated, Any, Iterable, Optional, Type, Union import typer from dotenv import load_dotenv from llama_index.llms.ibm.base import GenTextParamsMetaNames -from pydantic import AnyUrl, TypeAdapter +from pydantic import AnyUrl, SecretStr, TypeAdapter from rich.console import Console from docling.datamodel.base_models import FormatToExtensions, InputFormat -from docling_core.types.doc import DocItemLabel +from docling_core.types.doc.labels import DocItemLabel from docling_core.types.io import DocumentStream from docling_core.utils.file import resolve_source_to_path @@ -37,7 +37,7 @@ console = Console() err_console = Console(stderr=True) -QaOption = Union[SampleOptions, CritiqueOptions, CritiqueOptions] +QaOption = Union[SampleOptions, GenerateOptions, CritiqueOptions] def get_option_def(field: str, option: Type[QaOption]) -> Any: @@ -56,44 +56,103 @@ def get_option_desc(field: str, option: Type[QaOption]) -> Optional[str]: return field_info.description -def set_watsonx_options(options: LlmOptions) -> None: - if "WATSONX_URL" in os.environ: - options.url = TypeAdapter(AnyUrl).validate_python(os.environ.get("WATSONX_URL")) - if "WATSONX_MODEL_ID" in os.environ: +def set_llm_options_from_env(options: LlmOptions, provider: LlmProvider) -> None: + """Sets LLM options from environment variables based on the provider. + + This function uses the provider to determine the correct environment + variable prefix (e.g., `WATSONX_`, `OPENAI_`, `OPENAI_LIKE_`). + + Args: + options: The options object to modify. + provider: The LLM provider to use. + """ + prefix = provider.name.upper() + + # Generic options applicable to most providers + if f"{prefix}_URL" in os.environ: + options.url = TypeAdapter(AnyUrl).validate_python( + os.environ.get(f"{prefix}_URL") + ) + if f"{prefix}_MODEL_ID" in os.environ: options.model_id = TypeAdapter(str).validate_python( - os.environ.get("WATSONX_MODEL_ID") + os.environ.get(f"{prefix}_MODEL_ID") ) - if "WATSONX_MAX_NEW_TOKENS" in os.environ: + if f"{prefix}_MAX_NEW_TOKENS" in os.environ: options.max_new_tokens = TypeAdapter(int).validate_python( - os.environ.get("WATSONX_MAX_NEW_TOKENS") + os.environ.get(f"{prefix}_MAX_NEW_TOKENS") ) - if "WATSONX_DECODING_METHOD" in os.environ and options.additional_params: - options.additional_params[GenTextParamsMetaNames.DECODING_METHOD] = TypeAdapter( - str - ).validate_python(os.environ.get("WATSONX_DECODING_METHOD")) - if "WATSONX_MIN_NEW_TOKENS" in os.environ and options.additional_params: - options.additional_params[GenTextParamsMetaNames.MIN_NEW_TOKENS] = TypeAdapter( - int - ).validate_python((os.environ.get("WATSONX_MIN_NEW_TOKENS"))) - if "WATSONX_TEMPERATURE" in os.environ and options.additional_params: - options.additional_params[GenTextParamsMetaNames.TEMPERATURE] = TypeAdapter( - float - ).validate_python(os.environ.get("WATSONX_TEMPERATURE")) - if "WATSONX_TOP_K" in os.environ and options.additional_params: - options.additional_params[GenTextParamsMetaNames.TOP_K] = TypeAdapter( - int - ).validate_python((os.environ.get("WATSONX_TOP_K"))) - if "WATSONX_TOP_P" in os.environ and options.additional_params: - options.additional_params[GenTextParamsMetaNames.TOP_P] = TypeAdapter( - float - ).validate_python(os.environ.get("WATSONX_TOP_P")) + + # Provider-specific options (for watsonx) + if provider == LlmProvider.WATSONX and options.additional_params: + if f"{prefix}_DECODING_METHOD" in os.environ: + options.additional_params[ + GenTextParamsMetaNames.DECODING_METHOD + ] = TypeAdapter(str).validate_python( + os.environ.get(f"{prefix}_DECODING_METHOD") + ) + if f"{prefix}_MIN_NEW_TOKENS" in os.environ: + options.additional_params[ + GenTextParamsMetaNames.MIN_NEW_TOKENS + ] = TypeAdapter(int).validate_python( + (os.environ.get(f"{prefix}_MIN_NEW_TOKENS")) + ) + if f"{prefix}_TEMPERATURE" in os.environ: + options.additional_params[ + GenTextParamsMetaNames.TEMPERATURE + ] = TypeAdapter(float).validate_python( + os.environ.get(f"{prefix}_TEMPERATURE") + ) + if f"{prefix}_TOP_K" in os.environ: + options.additional_params[GenTextParamsMetaNames.TOP_K] = TypeAdapter( + int + ).validate_python((os.environ.get(f"{prefix}_TOP_K"))) + if f"{prefix}_TOP_P" in os.environ: + options.additional_params[GenTextParamsMetaNames.TOP_P] = TypeAdapter( + float + ).validate_python(os.environ.get(f"{prefix}_TOP_P")) + + +def _resolve_input_paths( + input_sources: Iterable[str], workdir: Path +) -> list[Union[Path, str, DocumentStream]]: + """Resolves a list of source strings to a list of paths.""" + resolved_paths: list[Union[Path, str, DocumentStream]] = [] + for src in input_sources: + try: + source = resolve_source_to_path(source=src, workdir=workdir) + resolved_paths.append(source) + except FileNotFoundError as err: + err_console.print(f"[red]Error: The input file {src} does not exist.[/red]") + raise typer.Abort() from err + except IsADirectoryError: + try: + local_path = TypeAdapter(Path).validate_python(src) + if local_path.is_dir(): + for fmt in list(InputFormat): + for ext in FormatToExtensions[fmt]: + resolved_paths.extend(list(local_path.glob(f"**/*.{ext}"))) + resolved_paths.extend( + list(local_path.glob(f"**/*.{ext.upper()}")) + ) + elif local_path.exists(): + resolved_paths.append(local_path) + else: + err_console.print( + f"[red]Error: The input file {src} does not exist.[/red]" + ) + raise typer.Abort() + except Exception as err: + err_console.print(f"[red]Error: Cannot read the input {src}.[/red]") + _log.info(err) + raise typer.Abort() from err + return resolved_paths @app.command( no_args_is_help=True, help=( "Prepare the data for SDG: parse and chunk documents to create a file " - "with document passsages." + "with document passages." ), ) def sample( @@ -103,8 +162,8 @@ def sample( ..., metavar="source", help=( - "PDF files to convert, chunk, and sample. Can be local file / " - "directory paths or URL." + "PDF files to convert, chunk, and sample. Can be a local file, " + "directory path, or URL." ), ), ], @@ -114,9 +173,7 @@ def sample( "--verbose", "-v", count=True, - help=( - "Set the verbosity level. -v for info logging, -vv for debug logging." - ), + help="Set the verbosity level. -v for info logging, -vv for debug logging.", ), ] = 0, sample_file: Annotated[ @@ -133,6 +190,7 @@ def sample( "--chunker", "-c", help=get_option_desc("chunker", SampleOptions), + case_sensitive=False, ), ] = get_option_def("chunker", SampleOptions), min_token_count: Annotated[ @@ -157,6 +215,7 @@ def sample( "--doc-items", "-d", help=get_option_desc("doc_items", SampleOptions), + case_sensitive=False, ), ] = get_option_def("doc_items", SampleOptions), seed: Annotated[ @@ -176,61 +235,33 @@ def sample( logging.basicConfig(level=logging.DEBUG) with tempfile.TemporaryDirectory() as tempdir: - input_doc_paths: list[Path | str | DocumentStream] = [] - for src in input_sources: - try: - # check if we can fetch some remote url - source = resolve_source_to_path(source=src, workdir=Path(tempdir)) - input_doc_paths.append(source) - except FileNotFoundError as err: - err_console.print( - f"[red]Error: The input file {src} does not exist.[/red]" - ) - raise typer.Abort() from err - except IsADirectoryError: - # if the input matches to a file or a folder - try: - local_path = TypeAdapter(Path).validate_python(src) - if local_path.exists() and local_path.is_dir(): - for fmt in list(InputFormat): - for ext in FormatToExtensions[fmt]: - input_doc_paths.extend( - list(local_path.glob(f"**/*.{ext}")) - ) - input_doc_paths.extend( - list(local_path.glob(f"**/*.{ext.upper()}")) - ) - elif local_path.exists(): - input_doc_paths.append(local_path) - else: - err_console.print( - f"[red]Error: The input file {src} does not exist.[/red]" - ) - raise typer.Abort() - except Exception as err: - err_console.print(f"[red]Error: Cannot read the input {src}.[/red]") - _log.info(err) # will print more details if verbose is activated - raise typer.Abort() from err - - options: SampleOptions = SampleOptions( - sample_file=sample_file, - chunker=chunker, - min_token_count=min_token_count, - max_passages=max_passages, - doc_items=doc_items, - seed=seed, - ) - sample = PassageSampler(sample_options=options) - result: SampleResult = sample.sample(input_doc_paths) + input_doc_paths = _resolve_input_paths(input_sources, Path(tempdir)) + + # Build the options dictionary conditionally to handle optional CLI args + options_dict: dict[str, Any] = {} + if sample_file is not None: + options_dict["sample_file"] = sample_file + if chunker is not None: + options_dict["chunker"] = chunker + if min_token_count is not None: + options_dict["min_token_count"] = min_token_count + if max_passages is not None: + options_dict["max_passages"] = max_passages + if doc_items is not None: + options_dict["doc_items"] = doc_items + if seed is not None: + options_dict["seed"] = seed + + options = SampleOptions(**options_dict) + + passage_sampler = PassageSampler(sample_options=options) + result: SampleResult = passage_sampler.sample(input_doc_paths) typer.echo(f"Q&A Sample finished: {result}") @app.command( no_args_is_help=True, - help=( - "Run SDG on a set of document passages and create question-answering items " - "of different types." - ), + help="Run SDG on a set of document passages and create Q&A items.", ) def generate( input_source: Annotated[ @@ -255,23 +286,32 @@ def generate( typer.Option( "--generated-file", "-f", - help=get_option_desc("generated_file", CritiqueOptions), + help=get_option_desc("generated_file", GenerateOptions), ), - ] = get_option_def("generated_file", CritiqueOptions), + ] = get_option_def("generated_file", GenerateOptions), max_qac: Annotated[ Optional[int], typer.Option( "--max-qac", "-q", - help=get_option_desc("max_qac", CritiqueOptions), + help=get_option_desc("max_qac", GenerateOptions), ), - ] = get_option_def("max_qac", CritiqueOptions), - watsonx: Annotated[ + ] = get_option_def("max_qac", GenerateOptions), + provider: Annotated[ + LlmProvider, + typer.Option( + "--provider", + "-p", + help="The LLM provider to use for generation.", + case_sensitive=False, + ), + ] = LlmProvider.WATSONX, + env_file: Annotated[ Optional[Path], typer.Option( - "--watsonx", - "-w", - help="Path to a file with the parameters for watsonx.ai.", + "--env-file", + "-e", + help="Path to a file with environment variables for the LLM provider.", ), ] = Path("./.env"), ) -> None: @@ -288,21 +328,30 @@ def generate( ) raise typer.Abort() - if not watsonx or not os.path.isfile(watsonx): + if not env_file or not os.path.isfile(env_file): err_console.print( - f"[red]Error: The watsonx.ai file {watsonx} does not exist.[/red]" + f"[red]Error: The environment file {env_file} does not exist.[/red]" ) raise typer.Abort() - load_dotenv(watsonx) + load_dotenv(env_file) + + prefix = provider.name.upper() + api_key_env_var = f"{prefix}_APIKEY" + project_id_env_var = f"{prefix}_PROJECT_ID" + + api_key_str = os.environ.get(api_key_env_var) + project_id_str = os.environ.get(project_id_env_var) options = GenerateOptions( - provider=LlmProvider.WATSONX, - project_id=os.environ.get("WATSONX_PROJECT_ID"), - api_key=os.environ.get("WATSONX_APIKEY"), + provider=provider, + project_id=SecretStr(project_id_str) + if provider == LlmProvider.WATSONX and project_id_str + else None, + api_key=SecretStr(api_key_str) if api_key_str else None, ) - set_watsonx_options(options) + set_llm_options_from_env(options, provider) if generated_file: options.generated_file = generated_file if max_qac: @@ -351,12 +400,21 @@ def critique( help=get_option_desc("max_qac", CritiqueOptions), ), ] = get_option_def("max_qac", CritiqueOptions), - watsonx: Annotated[ + provider: Annotated[ + LlmProvider, + typer.Option( + "--provider", + "-P", + help="The LLM provider to use for critique.", + case_sensitive=False, + ), + ] = LlmProvider.WATSONX, + env_file: Annotated[ Optional[Path], typer.Option( - "--watsonx", - "-w", - help="Path to a file with the parameters for watsonx.ai.", + "--env-file", + "-e", + help="Path to a file with environment variables for the LLM provider.", ), ] = Path("./.env"), ) -> None: @@ -373,21 +431,30 @@ def critique( ) raise typer.Abort() - if not watsonx or not os.path.isfile(watsonx): + if not env_file or not os.path.isfile(env_file): err_console.print( - f"[red]Error: The watsonx.ai file {watsonx} does not exist.[/red]" + f"[red]Error: The environment file {env_file} does not exist.[/red]" ) raise typer.Abort() - load_dotenv(watsonx) + load_dotenv(env_file) + + prefix = provider.name.upper() + api_key_env_var = f"{prefix}_APIKEY" + project_id_env_var = f"{prefix}_PROJECT_ID" + + api_key_str = os.environ.get(api_key_env_var) + project_id_str = os.environ.get(project_id_env_var) options = CritiqueOptions( - provider=LlmProvider.WATSONX, - project_id=os.environ.get("WATSONX_PROJECT_ID"), - api_key=os.environ.get("WATSONX_APIKEY"), + provider=provider, + project_id=SecretStr(project_id_str) + if provider == LlmProvider.WATSONX and project_id_str + else None, + api_key=SecretStr(api_key_str) if api_key_str else None, ) - set_watsonx_options(options) + set_llm_options_from_env(options, provider) if critiqued_file: options.critiqued_file = critiqued_file if max_qac: diff --git a/docling_sdg/qa/base.py b/docling_sdg/qa/base.py index f83ecf8..50bb878 100644 --- a/docling_sdg/qa/base.py +++ b/docling_sdg/qa/base.py @@ -13,8 +13,8 @@ SecretStr, ) -from docling_core.transforms.chunker import DocChunk, DocMeta -from docling_core.types.doc import DocItemLabel +from docling_core.transforms.chunker.hierarchical_chunker import DocChunk, DocMeta +from docling_core.types.doc.labels import DocItemLabel from docling_core.types.nlp.qa import QAPair from docling_sdg.qa.prompts.critique_prompts import ( @@ -85,18 +85,21 @@ class LlmOptions(BaseModel): ) url: AnyUrl = Field( default=AnyUrl("http://127.0.0.1:11434/v1"), - description="Url to LLM API endpoint", + description="URL to the LLM API endpoint.", ) project_id: Optional[SecretStr] = Field( - default=None, description="ID of the Watson Studio project." + default=None, + description=( + "Project ID for the LLM provider (if applicable, e.g., watsonx.ai)." + ), ) api_key: Optional[SecretStr] = Field( default=None, - description="API key to Watson Machine Learning or CPD instance.", + description="API key for the LLM provider.", ) model_id: str = Field( default="mistralai/mixtral-8x7b-instruct-v01", - description="Which model to use.", + description="The model ID to use for generation.", ) max_new_tokens: int = Field( default=512, ge=0, description="The maximum number of tokens to generate." @@ -109,7 +112,9 @@ class LlmOptions(BaseModel): GenTextParamsMetaNames.TOP_K: 50, GenTextParamsMetaNames.TOP_P: 0.95, }, - description="Additional generation params for the watsonx.ai models.", + description=( + "Additional generation parameters for the LLM (e.g., for watsonx.ai)." + ), ) From 971d997060a9b275e3054f76af2c20040fa6281b Mon Sep 17 00:00:00 2001 From: Kush Gupta Date: Tue, 10 Jun 2025 21:09:56 -0400 Subject: [PATCH 2/6] ruff-format Signed-off-by: Kush Gupta --- docling_sdg/cli/qa.py | 24 +++++++++++------------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/docling_sdg/cli/qa.py b/docling_sdg/cli/qa.py index 5b35d1d..c606c2e 100644 --- a/docling_sdg/cli/qa.py +++ b/docling_sdg/cli/qa.py @@ -85,23 +85,21 @@ def set_llm_options_from_env(options: LlmOptions, provider: LlmProvider) -> None # Provider-specific options (for watsonx) if provider == LlmProvider.WATSONX and options.additional_params: if f"{prefix}_DECODING_METHOD" in os.environ: - options.additional_params[ - GenTextParamsMetaNames.DECODING_METHOD - ] = TypeAdapter(str).validate_python( - os.environ.get(f"{prefix}_DECODING_METHOD") + options.additional_params[GenTextParamsMetaNames.DECODING_METHOD] = ( + TypeAdapter(str).validate_python( + os.environ.get(f"{prefix}_DECODING_METHOD") + ) ) if f"{prefix}_MIN_NEW_TOKENS" in os.environ: - options.additional_params[ - GenTextParamsMetaNames.MIN_NEW_TOKENS - ] = TypeAdapter(int).validate_python( - (os.environ.get(f"{prefix}_MIN_NEW_TOKENS")) + options.additional_params[GenTextParamsMetaNames.MIN_NEW_TOKENS] = ( + TypeAdapter(int).validate_python( + (os.environ.get(f"{prefix}_MIN_NEW_TOKENS")) + ) ) if f"{prefix}_TEMPERATURE" in os.environ: - options.additional_params[ - GenTextParamsMetaNames.TEMPERATURE - ] = TypeAdapter(float).validate_python( - os.environ.get(f"{prefix}_TEMPERATURE") - ) + options.additional_params[GenTextParamsMetaNames.TEMPERATURE] = TypeAdapter( + float + ).validate_python(os.environ.get(f"{prefix}_TEMPERATURE")) if f"{prefix}_TOP_K" in os.environ: options.additional_params[GenTextParamsMetaNames.TOP_K] = TypeAdapter( int From 61aac9050472f8f56df2d6489a45aa7b2ed985bd Mon Sep 17 00:00:00 2001 From: Kush Gupta Date: Fri, 13 Jun 2025 10:59:07 -0400 Subject: [PATCH 3/6] Q&A consistency Signed-off-by: Kush Gupta --- docling_sdg/cli/main.py | 2 +- docling_sdg/cli/qa.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docling_sdg/cli/main.py b/docling_sdg/cli/main.py index 7107812..93948a4 100644 --- a/docling_sdg/cli/main.py +++ b/docling_sdg/cli/main.py @@ -35,4 +35,4 @@ def get_version() -> None: typer.echo(f"Platform: {platform_str}") -app.add_typer(qa_app, name="qa", help="Interact with SDG for question-answering.") +app.add_typer(qa_app, name="qa", help="Interact with SDG for Q&A.") diff --git a/docling_sdg/cli/qa.py b/docling_sdg/cli/qa.py index c606c2e..7f0017c 100644 --- a/docling_sdg/cli/qa.py +++ b/docling_sdg/cli/qa.py @@ -362,7 +362,7 @@ def generate( @app.command( no_args_is_help=True, - help="Use LLM as a judge to critique a set of SDG question-answering items.", + help="Use LLM as a judge to critique a set of SDG Q&A items.", ) def critique( input_source: Annotated[ From 8c916c621ea590fc01451a8c2e2083aa3ca9154b Mon Sep 17 00:00:00 2001 From: Kush Gupta Date: Fri, 13 Jun 2025 15:31:38 -0400 Subject: [PATCH 4/6] tests and revert import changes Signed-off-by: Kush Gupta --- docling_sdg/cli/qa.py | 2 +- docling_sdg/qa/base.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docling_sdg/cli/qa.py b/docling_sdg/cli/qa.py index 7f0017c..2ce1e9a 100644 --- a/docling_sdg/cli/qa.py +++ b/docling_sdg/cli/qa.py @@ -11,7 +11,7 @@ from rich.console import Console from docling.datamodel.base_models import FormatToExtensions, InputFormat -from docling_core.types.doc.labels import DocItemLabel +from docling_core.types.doc import DocItemLabel from docling_core.types.io import DocumentStream from docling_core.utils.file import resolve_source_to_path diff --git a/docling_sdg/qa/base.py b/docling_sdg/qa/base.py index 50bb878..c2e58ee 100644 --- a/docling_sdg/qa/base.py +++ b/docling_sdg/qa/base.py @@ -13,8 +13,8 @@ SecretStr, ) -from docling_core.transforms.chunker.hierarchical_chunker import DocChunk, DocMeta -from docling_core.types.doc.labels import DocItemLabel +from docling_core.transforms.chunker import DocChunk, DocMeta +from docling_core.types.doc import DocItemLabel from docling_core.types.nlp.qa import QAPair from docling_sdg.qa.prompts.critique_prompts import ( From 1fd560725d2685abd0264f4195defde1c726fe10 Mon Sep 17 00:00:00 2001 From: Kush Gupta Date: Fri, 13 Jun 2025 15:37:28 -0400 Subject: [PATCH 5/6] test file for agnostic CLI Signed-off-by: Kush Gupta --- tests/test_qa_cli.py | 880 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 880 insertions(+) create mode 100644 tests/test_qa_cli.py diff --git a/tests/test_qa_cli.py b/tests/test_qa_cli.py new file mode 100644 index 0000000..3ef4335 --- /dev/null +++ b/tests/test_qa_cli.py @@ -0,0 +1,880 @@ +import logging +import os +import tempfile +from pathlib import Path +from typing import Dict, List +from unittest import mock + +import pytest +from llama_index.llms.ibm.base import GenTextParamsMetaNames +from pydantic import AnyUrl, SecretStr +from typer import Abort +from typer.testing import CliRunner + +from docling_sdg.cli.qa import _resolve_input_paths, app, set_llm_options_from_env +from docling_sdg.qa.base import LlmOptions, LlmProvider + +runner = CliRunner() + +# Assisted by: Jules (Gemini 2.5 pro) + +def test_app_help() -> None: + result = runner.invoke(app, ["--help"]) + assert result.exit_code == 0 + assert "sample" in result.stdout + assert "generate" in result.stdout + assert "critique" in result.stdout + + +def test_resolve_input_paths_single_file() -> None: + with tempfile.TemporaryDirectory() as tmpdir: + workdir = Path(tmpdir) + test_file_name = "test_file.txt" + test_file_abs = workdir / test_file_name + test_file_abs.touch() + + input_file_str = str(test_file_abs) + + def mock_resolve_path_side_effect(source: str, workdir: Path) -> Path: + path_source = Path(source) + if path_source.exists(): + return path_source + raise FileNotFoundError + + with mock.patch( + "docling_sdg.cli.qa.resolve_source_to_path", + side_effect=mock_resolve_path_side_effect, + ) as mock_resolve: + resolved = _resolve_input_paths([input_file_str], workdir) + mock_resolve.assert_called_once_with(source=input_file_str, workdir=workdir) + assert len(resolved) == 1 + assert isinstance(resolved[0], Path) + assert resolved[0] == test_file_abs + + +def test_resolve_input_paths_multiple_files() -> None: + with tempfile.TemporaryDirectory() as tmpdir: + workdir = Path(tmpdir) + test_file1_abs = workdir / "test_file1.txt" + test_file1_abs.touch() + test_file2_abs = workdir / "test_file2.pdf" + test_file2_abs.touch() + + input_file1_str = str(test_file1_abs) + input_file2_str = str(test_file2_abs) + + def mock_resolve_side_effect(source: str, workdir: Path) -> Path: + path_source = Path(source) + if path_source.exists() and path_source.is_file(): + return path_source + raise FileNotFoundError("File not found by mock_resolve_side_effect") + + with mock.patch( + "docling_sdg.cli.qa.resolve_source_to_path", + side_effect=mock_resolve_side_effect, + ) as mock_resolve: + resolved = _resolve_input_paths([input_file1_str, input_file2_str], workdir) + assert mock_resolve.call_count == 2 + calls = [ + mock.call(source=input_file1_str, workdir=workdir), + mock.call(source=input_file2_str, workdir=workdir), + ] + mock_resolve.assert_has_calls(calls, any_order=True) + assert len(resolved) == 2 + assert set(resolved) == {test_file1_abs, test_file2_abs} + + +def test_resolve_input_paths_directory() -> None: + with tempfile.TemporaryDirectory() as tmpdir_str: + workdir = Path(tmpdir_str) + + dir_name_inside_workdir = "my_test_dir" + actual_test_dir_abs = workdir / dir_name_inside_workdir + actual_test_dir_abs.mkdir() + + (actual_test_dir_abs / "file1.txt").touch() + (actual_test_dir_abs / "file2.pdf").touch() + (actual_test_dir_abs / "other.doc").touch() + sub_dir_abs = actual_test_dir_abs / "sub" + sub_dir_abs.mkdir() + (sub_dir_abs / "file3.json").touch() + + input_dir_abs_str = str(actual_test_dir_abs) + + text_format_val: str = "text" + pdf_format_val: str = "pdf" + json_format_val: str = "json" + + mocked_format_to_extensions: Dict[str, List[str]] = { + text_format_val: ["txt"], + pdf_format_val: ["pdf"], + json_format_val: ["json"], + } + + def resolve_side_effect_for_dir(source: str, workdir: Path) -> None: + if source == input_dir_abs_str: + raise IsADirectoryError(f"'{source}' is a directory.") + raise ValueError( + f"resolve_side_effect_for_dir called with unexpected source: {source}" + ) + + with ( + mock.patch( + "docling.datamodel.base_models.FormatToExtensions", + new=mocked_format_to_extensions, + ), + mock.patch( + "docling_sdg.cli.qa.resolve_source_to_path", + side_effect=resolve_side_effect_for_dir, + ) as mock_resolve_source, + ): + resolved = _resolve_input_paths([input_dir_abs_str], workdir) + + mock_resolve_source.assert_called_once_with( + source=input_dir_abs_str, workdir=workdir + ) + + assert len(resolved) == 3 + expected_paths = { + actual_test_dir_abs / "file1.txt", + actual_test_dir_abs / "file2.pdf", + sub_dir_abs / "file3.json", + } + assert set(resolved) == expected_paths + + +def test_resolve_input_paths_non_existent_file() -> None: + with tempfile.TemporaryDirectory() as tmpdir: + workdir = Path(tmpdir) + non_existent_file_abs_str = str(workdir / "non_existent_file.txt") + + with mock.patch( + "docling_sdg.cli.qa.resolve_source_to_path", + side_effect=FileNotFoundError("File not found"), + ) as mock_resolve: + with pytest.raises(Abort): + _resolve_input_paths([non_existent_file_abs_str], workdir) + mock_resolve.assert_called_once_with( + source=non_existent_file_abs_str, workdir=workdir + ) + + +def test_resolve_input_paths_non_existent_dir_as_input() -> None: + with tempfile.TemporaryDirectory() as tmpdir: + workdir = Path(tmpdir) + input_non_existent_dir_abs_str = str(workdir / "non_existent_dir") + + def resolve_mock_side_effect(source: str, workdir: Path) -> None: + if source == input_non_existent_dir_abs_str: + raise IsADirectoryError(f"{source} is (allegedly) a directory.") + raise FileNotFoundError( + f"File not found by resolve_mock_side_effect for {source}" + ) + + with mock.patch( + "docling_sdg.cli.qa.resolve_source_to_path", + side_effect=resolve_mock_side_effect, + ) as mock_resolve: + with pytest.raises(Abort) as excinfo: + _resolve_input_paths([input_non_existent_dir_abs_str], workdir) + + assert excinfo.type is Abort + + mock_resolve.assert_called_once_with( + source=input_non_existent_dir_abs_str, workdir=workdir + ) + + +@mock.patch("docling_sdg.cli.qa.resolve_source_to_path") +def test_resolve_input_paths_url(mock_resolve_source_path_for_url: mock.Mock) -> None: + with tempfile.TemporaryDirectory() as tmpdir: + workdir = Path(tmpdir) + mock_url_content_path = workdir / "downloaded_file.pdf" + + mock_resolve_source_path_for_url.return_value = mock_url_content_path + + input_url = "http://example.com/somefile.pdf" + resolved = _resolve_input_paths([input_url], workdir) + + mock_resolve_source_path_for_url.assert_called_once_with( + source=input_url, workdir=workdir + ) + assert len(resolved) == 1 + assert resolved[0] == mock_url_content_path + + +@mock.patch( + "docling_sdg.cli.qa.resolve_source_to_path", + side_effect=FileNotFoundError("Mocked URL FileNotFoundError"), +) +def test_resolve_input_paths_url_not_found( + mock_resolve_source_error_for_url: mock.Mock, +) -> None: + with tempfile.TemporaryDirectory() as tmpdir: + workdir = Path(tmpdir) + input_url = "http://example.com/non_existent.pdf" + with pytest.raises(Abort): + _resolve_input_paths([input_url], workdir) + mock_resolve_source_error_for_url.assert_called_once_with( + source=input_url, workdir=workdir + ) + + +@mock.patch("docling_sdg.cli.qa.resolve_source_to_path") +def test_resolve_input_paths_mixed_sources( + mock_resolve_source_mixed_case: mock.Mock, +) -> None: + with tempfile.TemporaryDirectory() as tmpdir: + workdir = Path(tmpdir) + + actual_local_file_abs = workdir / "local_file.txt" + actual_local_file_abs.touch() + local_file_input_str = str(actual_local_file_abs) + + url_input_str = "http://example.com/file.pdf" + mock_url_content_path = workdir / "downloaded_url_file.pdf" + + def side_effect_for_mixed(source: str, workdir: Path) -> Path: + if source == url_input_str: + return mock_url_content_path + elif source == local_file_input_str: + path_obj = Path(source) + if path_obj.exists(): + return path_obj + else: + raise FileNotFoundError( + f"Mocked local (abs str): {source} not found by side_effect" + ) + else: + raise ValueError( + f"Unexpected source in mock side_effect_for_mixed: {source}" + ) + + mock_resolve_source_mixed_case.side_effect = side_effect_for_mixed + + input_sources = [local_file_input_str, url_input_str] + resolved = _resolve_input_paths(input_sources, workdir) + + assert len(resolved) == 2 + expected_resolved_paths = {actual_local_file_abs, mock_url_content_path} + assert set(resolved) == expected_resolved_paths + + expected_calls = [ + mock.call(source=local_file_input_str, workdir=workdir), + mock.call(source=url_input_str, workdir=workdir), + ] + mock_resolve_source_mixed_case.assert_has_calls(expected_calls, any_order=True) + + +# Tests for set_llm_options_from_env +def test_set_llm_options_from_env_generic_options_watsonx() -> None: + options = LlmOptions(api_key=SecretStr("dummy_key")) + provider = LlmProvider.WATSONX + env_vars = { + "WATSONX_URL": "http://watsonx.example.com", + "WATSONX_MODEL_ID": "test_model_123", + "WATSONX_MAX_NEW_TOKENS": "150", + } + with mock.patch.dict(os.environ, env_vars): + set_llm_options_from_env(options, provider) + + assert options.url == AnyUrl("http://watsonx.example.com") + assert options.model_id == "test_model_123" + assert options.max_new_tokens == 150 + + +def test_set_llm_options_from_env_generic_options_openai() -> None: + options = LlmOptions(api_key=SecretStr("dummy_key")) + provider = LlmProvider.OPENAI + env_vars = { + "OPENAI_URL": "http://openai.example.com", + "OPENAI_MODEL_ID": "gpt-4", + "OPENAI_MAX_NEW_TOKENS": "200", + } + with mock.patch.dict(os.environ, env_vars): + set_llm_options_from_env(options, provider) + + assert options.url == AnyUrl("http://openai.example.com") + assert options.model_id == "gpt-4" + assert options.max_new_tokens == 200 + + +def test_set_llm_options_from_env_watsonx_specific_params() -> None: + options = LlmOptions(api_key=SecretStr("dummy_key")) + provider = LlmProvider.WATSONX + env_vars = { + "WATSONX_URL": "http://watsonx.example.com", + "WATSONX_MODEL_ID": "test_model_123", + "WATSONX_DECODING_METHOD": "greedy", + "WATSONX_MIN_NEW_TOKENS": "10", + "WATSONX_TEMPERATURE": "0.7", + "WATSONX_TOP_K": "50", + "WATSONX_TOP_P": "0.9", + } + with mock.patch.dict(os.environ, env_vars): + set_llm_options_from_env(options, provider) + + assert options.additional_params[GenTextParamsMetaNames.DECODING_METHOD] == "greedy" + assert options.additional_params[GenTextParamsMetaNames.MIN_NEW_TOKENS] == 10 + assert options.additional_params[GenTextParamsMetaNames.TEMPERATURE] == 0.7 + assert options.additional_params[GenTextParamsMetaNames.TOP_K] == 50 + assert options.additional_params[GenTextParamsMetaNames.TOP_P] == 0.9 + + +def test_llm_options_from_env_watsonx_specific_params_no_init_additional_params() -> ( + None +): + options = LlmOptions(api_key=SecretStr("dummy_key")) + provider = LlmProvider.WATSONX + env_vars = { + "WATSONX_DECODING_METHOD": "sample", + } + with mock.patch.dict(os.environ, env_vars): + set_llm_options_from_env(options, provider) + + assert options.additional_params is not None + assert options.additional_params[GenTextParamsMetaNames.DECODING_METHOD] == "sample" + + +def test_set_llm_options_from_env_no_env_vars_set() -> None: + options = LlmOptions( + api_key=SecretStr("dummy_key"), url=AnyUrl("http://default.url") + ) + original_url = options.url + original_model_id = options.model_id + original_max_new_tokens = options.max_new_tokens + + provider = LlmProvider.OPENAI_LIKE + + with mock.patch.dict(os.environ, {}, clear=True): + set_llm_options_from_env(options, provider) + + assert options.url == original_url + assert options.model_id == original_model_id + assert options.max_new_tokens == original_max_new_tokens + if provider == LlmProvider.WATSONX: + assert options.additional_params is None + + +def test_set_llm_options_from_env_partial_env_vars() -> None: + options = LlmOptions(api_key=SecretStr("dummy_key")) + provider = LlmProvider.WATSONX + env_vars = { + "WATSONX_URL": "http://partial.example.com", + # MODEL_ID not set + "WATSONX_MAX_NEW_TOKENS": "50", + # This requires additional_params to be not None + "WATSONX_DECODING_METHOD": "greedy", + } + options.additional_params = {} + + with mock.patch.dict(os.environ, env_vars, clear=True): + set_llm_options_from_env(options, provider) + + assert options.url == AnyUrl("http://partial.example.com") + assert options.model_id == "mistralai/mixtral-8x7b-instruct-v01" + assert options.max_new_tokens == 50 + assert GenTextParamsMetaNames.DECODING_METHOD not in options.additional_params + assert GenTextParamsMetaNames.TEMPERATURE not in options.additional_params + + +# Tests for `sample` CLI command +@mock.patch("docling_sdg.cli.qa.PassageSampler") +@mock.patch("docling_sdg.cli.qa._resolve_input_paths") +def test_sample_command_single_file( + mock_resolve_paths: mock.Mock, mock_passage_sampler_cls: mock.Mock +) -> None: + runner = CliRunner() + mock_sampler_instance = mock.Mock() + mock_passage_sampler_cls.return_value = mock_sampler_instance + + with tempfile.TemporaryDirectory() as tmpdir: + workdir = Path(tmpdir) + test_file = workdir / "input1.pdf" + test_file.touch() + + mock_resolve_paths.return_value = [test_file] + + result = runner.invoke(app, ["sample", str(test_file)]) + + assert result.exit_code == 0 + mock_resolve_paths.assert_called_once() + mock_passage_sampler_cls.assert_called_once() + mock_sampler_instance.sample.assert_called_once_with([test_file]) + assert "Q&A Sample finished" in result.stdout + + +@mock.patch("docling_sdg.cli.qa.PassageSampler") +@mock.patch("docling_sdg.cli.qa._resolve_input_paths") +def test_sample_command_multiple_files( + mock_resolve_paths: mock.Mock, mock_passage_sampler_cls: mock.Mock +) -> None: + runner = CliRunner() + mock_sampler_instance = mock.Mock() + mock_passage_sampler_cls.return_value = mock_sampler_instance + + with tempfile.TemporaryDirectory() as tmpdir: + workdir = Path(tmpdir) + file1 = workdir / "doc1.pdf" + file1.touch() + file2 = workdir / "doc2.txt" + file2.touch() + + resolved_files = [file1, file2] + mock_resolve_paths.return_value = resolved_files + + result = runner.invoke(app, ["sample", str(file1), str(file2)]) + + assert result.exit_code == 0 + mock_resolve_paths.assert_called_once() + mock_passage_sampler_cls.assert_called_once() + mock_sampler_instance.sample.assert_called_once_with(resolved_files) + + +@mock.patch("docling_sdg.cli.qa.PassageSampler") +@mock.patch("docling_sdg.cli.qa._resolve_input_paths") +def test_sample_command_with_options( + mock_resolve_paths: mock.Mock, mock_passage_sampler_cls: mock.Mock +) -> None: + runner = CliRunner() + mock_sampler_instance = mock.Mock() + mock_passage_sampler_cls.return_value = mock_sampler_instance + + with tempfile.TemporaryDirectory() as tmpdir: + workdir = Path(tmpdir) + test_file = workdir / "test.pdf" + test_file.touch() + sample_out_file = workdir / "samples.jsonl" + + mock_resolve_paths.return_value = [test_file] + + result = runner.invoke( + app, + [ + "sample", + str(test_file), + "--sample-file", + str(sample_out_file), + "--chunker", + "hybrid", + "--min-token-count", + "50", + "--max-passages", + "100", + "--doc-items", + "picture", + "--doc-items", + "table", + "--seed", + "42", + ], + ) + + assert result.exit_code == 0 + mock_resolve_paths.assert_called_once() + + args, kwargs = mock_passage_sampler_cls.call_args + assert "sample_options" in kwargs + options_passed = kwargs["sample_options"] + + assert options_passed.sample_file == sample_out_file + + assert options_passed.chunker == "hybrid" + + assert options_passed.min_token_count == 50 + assert options_passed.max_passages == 100 + + assert "picture" in options_passed.doc_items + assert "table" in options_passed.doc_items + + assert options_passed.seed == 42 + + mock_sampler_instance.sample.assert_called_once_with([test_file]) + + +@mock.patch("docling_sdg.cli.qa.PassageSampler") +@mock.patch("docling_sdg.cli.qa._resolve_input_paths") +def test_sample_command_input_file_not_exist( + mock_resolve_paths: mock.Mock, mock_passage_sampler_cls: mock.Mock +) -> None: + runner = CliRunner() + + mock_resolve_paths.side_effect = Abort() + + result = runner.invoke(app, ["sample", "nonexistent.pdf"]) + + assert result.exit_code != 0 + + mock_resolve_paths.assert_called_once_with(["nonexistent.pdf"], mock.ANY) + + mock_passage_sampler_cls.assert_not_called() + + +@mock.patch("docling_sdg.cli.qa.logging.basicConfig") +@mock.patch("docling_sdg.cli.qa.PassageSampler") +@mock.patch("docling_sdg.cli.qa._resolve_input_paths") +def test_sample_command_verbosity_v( + mock_resolve_paths: mock.Mock, + mock_passage_sampler_cls: mock.Mock, + mock_log_config: mock.Mock, +) -> None: + runner = CliRunner() + mock_sampler_instance = mock.Mock() + mock_passage_sampler_cls.return_value = mock_sampler_instance + + with tempfile.TemporaryDirectory() as tmpdir: + test_file = Path(tmpdir) / "logtest.pdf" + test_file.touch() + mock_resolve_paths.return_value = [test_file] + + result = runner.invoke(app, ["sample", str(test_file), "-v"]) + assert result.exit_code == 0 + + mock_log_config.assert_any_call(level=logging.INFO) + + +@mock.patch("docling_sdg.cli.qa.logging.basicConfig") +@mock.patch("docling_sdg.cli.qa.PassageSampler") +@mock.patch("docling_sdg.cli.qa._resolve_input_paths") +def test_sample_command_verbosity_vv( + mock_resolve_paths: mock.Mock, + mock_passage_sampler_cls: mock.Mock, + mock_log_config: mock.Mock, +) -> None: + runner = CliRunner() + mock_sampler_instance = mock.Mock() + mock_passage_sampler_cls.return_value = mock_sampler_instance + + with tempfile.TemporaryDirectory() as tmpdir: + test_file = Path(tmpdir) / "logtest_debug.pdf" + test_file.touch() + mock_resolve_paths.return_value = [test_file] + + result = runner.invoke(app, ["sample", str(test_file), "-vv"]) + assert result.exit_code == 0 + + mock_log_config.assert_any_call(level=logging.DEBUG) + + +# Tests for `generate` CLI command +@mock.patch("docling_sdg.cli.qa.load_dotenv") +@mock.patch("docling_sdg.cli.qa.set_llm_options_from_env") +@mock.patch("docling_sdg.cli.qa.Generator") +def test_generate_command_valid_input( + mock_generator_cls: mock.Mock, + mock_set_llm_opts: mock.Mock, + mock_load_dotenv: mock.Mock, +) -> None: + runner = CliRunner() + mock_generator_instance = mock.Mock() + mock_generator_cls.return_value = mock_generator_instance + + with tempfile.TemporaryDirectory() as tmpdir: + workdir = Path(tmpdir) + sample_input_file = workdir / "sample_passages.jsonl" + sample_input_file.touch() + + env_file = workdir / ".env" + env_file.touch() + + env_vars_for_llm = { + "WATSONX_APIKEY": "testapikey", + "WATSONX_PROJECT_ID": "testprojectid", + } + + with mock.patch.dict(os.environ, env_vars_for_llm): + result = runner.invoke( + app, + ["generate", str(sample_input_file), "--env-file", str(env_file)], + ) + + assert result.exit_code == 0, f"CLI Error: {result.stdout}" + mock_load_dotenv.assert_called_once_with(env_file) + mock_generator_cls.assert_called_once() + + mock_set_llm_opts.assert_called_once() + + mock_generator_instance.generate_from_sample.assert_called_once_with( + sample_input_file + ) + assert "Q&A Generation finished" in result.stdout + + +@mock.patch("docling_sdg.cli.qa.load_dotenv") +@mock.patch("docling_sdg.cli.qa.set_llm_options_from_env") +@mock.patch("docling_sdg.cli.qa.Generator") +def test_generate_command_options_and_provider( + mock_generator_cls: mock.Mock, + mock_set_llm_opts: mock.Mock, + mock_load_dotenv: mock.Mock, +) -> None: + runner = CliRunner() + mock_generator_instance = mock.Mock() + mock_generator_cls.return_value = mock_generator_instance + + with tempfile.TemporaryDirectory() as tmpdir: + workdir = Path(tmpdir) + sample_input_file = workdir / "passages.jsonl" + sample_input_file.touch() + generated_output_file = workdir / "generated_qna.jsonl" + env_file = workdir / "custom.env" + env_file.touch() + + env_vars_for_llm = {"OPENAI_APIKEY": "openaikey"} + + with mock.patch.dict(os.environ, env_vars_for_llm): + result = runner.invoke( + app, + [ + "generate", + str(sample_input_file), + "--generated-file", + str(generated_output_file), + "--max-qac", + "50", + "--provider", + "OPENAI", + "--env-file", + str(env_file), + ], + ) + + assert result.exit_code == 0, f"CLI Error: {result.stdout}" + mock_load_dotenv.assert_called_once_with(env_file) + + args, kwargs = mock_generator_cls.call_args + assert "generate_options" in kwargs + options_passed = kwargs["generate_options"] + + assert options_passed.generated_file == generated_output_file + assert options_passed.max_qac == 50 + assert options_passed.provider == LlmProvider.OPENAI + assert options_passed.api_key.get_secret_value() == "openaikey" + + mock_set_llm_opts.assert_called_once() + + mock_generator_instance.generate_from_sample.assert_called_once_with( + sample_input_file + ) + + +def test_generate_command_input_file_not_exist() -> None: + runner = CliRunner() + with tempfile.TemporaryDirectory() as tmpdir: + env_file = Path(tmpdir) / ".env" + env_file.touch() + result = runner.invoke( + app, ["generate", "nonexistent.jsonl", "--env-file", str(env_file)] + ) + assert result.exit_code != 0 + assert "Error: The input file nonexistent.jsonl does not exist." in result.stdout + + +def test_generate_command_env_file_not_exist() -> None: + runner = CliRunner() + with tempfile.TemporaryDirectory() as tmpdir: + sample_file = Path(tmpdir) / "sample.jsonl" + sample_file.touch() + result = runner.invoke( + app, ["generate", str(sample_file), "--env-file", "nonexistent.env"] + ) + assert result.exit_code != 0 + assert ( + "Error: The environment file nonexistent.env does not exist." in result.stdout + ) + + +@mock.patch("docling_sdg.cli.qa.logging.basicConfig") +@mock.patch("docling_sdg.cli.qa.load_dotenv") +@mock.patch("docling_sdg.cli.qa.set_llm_options_from_env") +@mock.patch("docling_sdg.cli.qa.Generator") +def test_generate_command_verbosity_vv( + mock_generator_cls: mock.Mock, + mock_set_llm: mock.Mock, + mock_load_env: mock.Mock, + mock_log_config: mock.Mock, +) -> None: + runner = CliRunner() + mock_generator_instance = mock.Mock() + mock_generator_cls.return_value = mock_generator_instance + + with tempfile.TemporaryDirectory() as tmpdir: + workdir = Path(tmpdir) + sample_input_file = workdir / "sample_passages_debug.jsonl" + sample_input_file.touch() + env_file = workdir / ".env.debug" + env_file.touch() + + with mock.patch.dict(os.environ, {"WATSONX_APIKEY": "testkey"}): + result = runner.invoke( + app, + [ + "generate", + str(sample_input_file), + "--env-file", + str(env_file), + "-vv", + ], + ) + + assert result.exit_code == 0, f"CLI Error: {result.stdout}" + mock_log_config.assert_any_call(level=logging.DEBUG) + + +# Tests for `critique` CLI command +@mock.patch("docling_sdg.cli.qa.load_dotenv") +@mock.patch("docling_sdg.cli.qa.set_llm_options_from_env") +@mock.patch("docling_sdg.cli.qa.Judge") +def test_critique_command_valid_input( + mock_judge_cls: mock.Mock, mock_set_llm_opts: mock.Mock, mock_load_dotenv: mock.Mock +) -> None: + runner = CliRunner() + mock_judge_instance = mock.Mock() + mock_judge_cls.return_value = mock_judge_instance + + with tempfile.TemporaryDirectory() as tmpdir: + workdir = Path(tmpdir) + qna_input_file = workdir / "generated_qna.jsonl" + qna_input_file.touch() + + env_file = workdir / ".env.critique" + env_file.touch() + + env_vars_for_llm = { + "WATSONX_APIKEY": "critique_apikey", + "WATSONX_PROJECT_ID": "critique_projectid", + } + + with mock.patch.dict(os.environ, env_vars_for_llm): + result = runner.invoke( + app, + ["critique", str(qna_input_file), "--env-file", str(env_file)], + ) + + assert result.exit_code == 0, f"CLI Error: {result.stdout}" + mock_load_dotenv.assert_called_once_with(env_file) + mock_judge_cls.assert_called_once() + + mock_set_llm_opts.assert_called_once() + + mock_judge_instance.critique.assert_called_once_with(qna_input_file) + assert "Q&A Critique finished" in result.stdout + + +@mock.patch("docling_sdg.cli.qa.load_dotenv") +@mock.patch("docling_sdg.cli.qa.set_llm_options_from_env") +@mock.patch("docling_sdg.cli.qa.Judge") +def test_critique_command_options_and_provider( + mock_judge_cls: mock.Mock, mock_set_llm_opts: mock.Mock, mock_load_dotenv: mock.Mock +) -> None: + runner = CliRunner() + mock_judge_instance = mock.Mock() + mock_judge_cls.return_value = mock_judge_instance + + with tempfile.TemporaryDirectory() as tmpdir: + workdir = Path(tmpdir) + qna_input_file = workdir / "qna_to_critique.jsonl" + qna_input_file.touch() + critiqued_output_file = workdir / "critiqued_qna.jsonl" + env_file = workdir / "custom_critique.env" + env_file.touch() + + env_vars_for_llm = {"OPENAI_LIKE_APIKEY": "openaikey"} + + with mock.patch.dict(os.environ, env_vars_for_llm): + result = runner.invoke( + app, + [ + "critique", + str(qna_input_file), + "--critiqued-file", + str(critiqued_output_file), + "--max-qac", + "25", + "--provider", + "OPENAI_LIKE", + "--env-file", + str(env_file), + ], + ) + + assert result.exit_code == 0, f"CLI Error: {result.stdout}" + mock_load_dotenv.assert_called_once_with(env_file) + + args, kwargs = mock_judge_cls.call_args + assert "critique_options" in kwargs + options_passed = kwargs["critique_options"] + + assert options_passed.critiqued_file == critiqued_output_file + assert options_passed.max_qac == 25 + assert options_passed.provider == LlmProvider.OPENAI_LIKE + assert options_passed.api_key.get_secret_value() == "openaikey" + + mock_set_llm_opts.assert_called_once() + + mock_judge_instance.critique.assert_called_once_with(qna_input_file) + + +def test_critique_command_input_file_not_exist() -> None: + runner = CliRunner() + with tempfile.TemporaryDirectory() as tmpdir: + env_file = Path(tmpdir) / ".env" + env_file.touch() + result = runner.invoke( + app, ["critique", "nonexistent_qna.jsonl", "--env-file", str(env_file)] + ) + assert result.exit_code != 0 + assert ( + "Error: The input file nonexistent_qna.jsonl does not exist." in result.stdout + ) + + +def test_critique_command_env_file_not_exist() -> None: + runner = CliRunner() + with tempfile.TemporaryDirectory() as tmpdir: + qna_file = Path(tmpdir) / "qna.jsonl" + qna_file.touch() + result = runner.invoke( + app, ["critique", str(qna_file), "--env-file", "nonexistent.env"] + ) + assert result.exit_code != 0 + assert ( + "Error: The environment file nonexistent.env does not exist." in result.stdout + ) + + +@mock.patch("docling_sdg.cli.qa.logging.basicConfig") +@mock.patch("docling_sdg.cli.qa.load_dotenv") +@mock.patch("docling_sdg.cli.qa.set_llm_options_from_env") +@mock.patch("docling_sdg.cli.qa.Judge") +def test_critique_command_verbosity_v( + mock_judge_cls: mock.Mock, + mock_set_llm: mock.Mock, + mock_load_env: mock.Mock, + mock_log_config: mock.Mock, +) -> None: + runner = CliRunner() + mock_judge_instance = mock.Mock() + mock_judge_cls.return_value = mock_judge_instance + + with tempfile.TemporaryDirectory() as tmpdir: + workdir = Path(tmpdir) + qna_input_file = workdir / "qna_critique_log.jsonl" + qna_input_file.touch() + env_file = workdir / ".env.critique_log" + env_file.touch() + + with mock.patch.dict(os.environ, {"WATSONX_APIKEY": "testkey"}): + result = runner.invoke( + app, + [ + "critique", + str(qna_input_file), + "--env-file", + str(env_file), + "-v", + ], + ) + + assert result.exit_code == 0, f"CLI Error: {result.stdout}" + mock_log_config.assert_any_call(level=logging.INFO) From f9d943ef45cb980b12f58439bbe0b1ae1ddaa4e4 Mon Sep 17 00:00:00 2001 From: Kush Gupta Date: Fri, 13 Jun 2025 15:40:47 -0400 Subject: [PATCH 6/6] ruff-format Signed-off-by: Kush Gupta --- tests/test_qa_cli.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_qa_cli.py b/tests/test_qa_cli.py index 3ef4335..edd71e8 100644 --- a/tests/test_qa_cli.py +++ b/tests/test_qa_cli.py @@ -18,6 +18,7 @@ # Assisted by: Jules (Gemini 2.5 pro) + def test_app_help() -> None: result = runner.invoke(app, ["--help"]) assert result.exit_code == 0