def personas_command(
    locales: list[str] = typer.Option(
        None,
        "--locale",
        "-l",
        help="Locales to download (en_US, en_IN, hi_Deva_IN, hi_Latn_IN, ja_JP). Can be specified multiple times.",
    ),
    all_locales: bool = typer.Option(
        False,
        "--all",
        help="Download all available locales",
    ),
    dry_run: bool = typer.Option(
        False,
        "--dry-run",
        help="Show what would be downloaded without actually downloading",
    ),
    list_available: bool = typer.Option(
        False,
        "--list",
        help="List available persona datasets and their sizes",
    ),
) -> None:
    """Download Nemotron-Personas datasets for synthetic data generation.

    Examples:
        # List available datasets
        data-designer download personas --list

        # Interactive selection
        data-designer download personas

        # Download specific locales
        data-designer download personas --locale en_US --locale ja_JP

        # Download all available locales
        data-designer download personas --all

        # Preview what would be downloaded
        data-designer download personas --all --dry-run
    """
    download_controller = DownloadController(DATA_DESIGNER_HOME)

    # --list is a read-only query; it never reaches the download workflow.
    if list_available:
        download_controller.list_personas()
        return

    download_controller.run_personas(locales=locales, all_locales=all_locales, dry_run=dry_run)
+# SPDX-License-Identifier: Apache-2.0 + +import subprocess +from pathlib import Path + +from data_designer.cli.repositories.persona_repository import PersonaRepository +from data_designer.cli.services.download_service import DownloadService +from data_designer.cli.ui import ( + confirm_action, + console, + print_error, + print_header, + print_info, + print_success, + print_text, + select_multiple_with_arrows, +) +from data_designer.cli.utils import check_ngc_cli_available, get_ngc_version + +NGC_URL = "https://catalog.ngc.nvidia.com/" +NGC_CLI_INSTALL_URL = "https://org.ngc.nvidia.com/setup/installers/cli" + + +class DownloadController: + """Controller for asset download workflows.""" + + def __init__(self, config_dir: Path): + self.config_dir = config_dir + self.persona_repository = PersonaRepository() + self.service = DownloadService(config_dir, self.persona_repository) + + def list_personas(self) -> None: + """List available persona datasets and their sizes.""" + print_header("Available Nemotron-Persona Datasets") + console.print() + + available_locales = self.persona_repository.list_all() + + print_text("📦 Available locales:") + console.print() + + for locale in available_locales: + already_downloaded = self.service.is_locale_downloaded(locale.code) + status = " (downloaded)" if already_downloaded else "" + print_text(f" • {locale.code}: {locale.size}{status}") + + console.print() + print_info(f"Total: {len(available_locales)} datasets available") + + def run_personas(self, locales: list[str] | None, all_locales: bool, dry_run: bool = False) -> None: + """Main entry point for persona dataset downloads. 
+ + Args: + locales: List of locale codes to download (if provided via CLI flags) + all_locales: If True, download all available locales + dry_run: If True, only show what would be downloaded without actually downloading + """ + header = "Download Nemotron-Persona Datasets (Dry Run)" if dry_run else "Download Nemotron-Persona Datasets" + print_header(header) + print_info(f"Datasets will be saved to: {self.service.get_managed_assets_directory()}") + console.print() + + # Check NGC CLI availability (skip checking in dry run mode) + if not dry_run and not check_ngc_cli_with_instructions(): + return + + # Determine which locales to download + selected_locales = self._determine_locales(locales, all_locales) + + if not selected_locales: + print_info("No locales selected") + return + + # Show what will be downloaded + console.print() + action = "Would download" if dry_run else "Will download" + print_text(f"📦 {action} {len(selected_locales)} Nemotron-Persona dataset(s):") + for locale_code in selected_locales: + locale = self.persona_repository.get_by_code(locale_code) + already_downloaded = self.service.is_locale_downloaded(locale_code) + status = " - already exists, will update" if already_downloaded else "" + size = locale.size if locale else "unknown" + print_text(f" • {locale_code} ({size}){status}") + + console.print() + + # In dry run mode, exit here + if dry_run: + print_info("Dry run complete - no files were downloaded") + return + + # Confirm download + if not confirm_action("Proceed with download?", default=True): + print_info("Download cancelled") + return + + # Download each locale + console.print() + successful = [] + failed = [] + + for locale in selected_locales: + if self._download_locale(locale): + successful.append(locale) + else: + failed.append(locale) + + # Summary + console.print() + if successful: + print_success(f"Successfully downloaded {len(successful)} dataset(s): {', '.join(successful)}") + print_info(f"Saved datasets to: 
{self.service.get_managed_assets_directory()}") + + if failed: + print_error(f"Failed to download {len(failed)} dataset(s): {', '.join(failed)}") + + def _determine_locales(self, locales: list[str] | None, all_locales: bool) -> list[str]: + """Determine which locales to download based on user input. + + Args: + locales: List of locales from CLI flags (may be None) + all_locales: Whether to download all locales + + Returns: + List of locale codes to download + """ + available_locales = self.service.get_available_locales() + + # If --all flag is set, return all locales + if all_locales: + return list(available_locales.keys()) + + # If locales specified via flags, validate and return them + if locales: + invalid_locales = [loc for loc in locales if loc not in available_locales] + if invalid_locales: + print_error(f"Invalid locale(s): {', '.join(invalid_locales)}") + print_info(f"Available locales: {', '.join(available_locales.keys())}") + return [] + return locales + + # Interactive multi-select + return self._select_locales_interactive(available_locales) + + def _select_locales_interactive(self, available_locales: dict[str, str]) -> list[str]: + """Interactive multi-select for locales. + + Args: + available_locales: Dictionary of {locale_code: description} + + Returns: + List of selected locale codes + """ + console.print() + print_text("Select locales you want to download:") + console.print() + + selected = select_multiple_with_arrows( + options=available_locales, + prompt_text="Use ↑/↓ to navigate, Space to toggle ✓, Enter to confirm:", + default_keys=None, + allow_empty=False, + ) + + return selected if selected else [] + + def _download_locale(self, locale: str) -> bool: + """Download a single locale using NGC CLI. 
+ + Args: + locale: Locale code to download + + Returns: + True if download succeeded, False otherwise + """ + # Print header before download (NGC CLI will show its own progress) + print_text(f"📦 Downloading Nemotron-Persona dataset for {locale}...") + console.print() + + try: + self.service.download_persona_dataset(locale) + console.print() + print_success(f"✓ Downloaded Nemotron-Persona dataset for {locale}") + return True + + except subprocess.CalledProcessError as e: + console.print() + print_error(f"✗ Failed to download Nemotron-Persona dataset for {locale}") + print_error(f"NGC CLI error: {e}") + return False + + except Exception as e: + console.print() + print_error(f"✗ Failed to download Nemotron-Persona dataset for {locale}") + print_error(f"Unexpected error: {e}") + return False + + +def check_ngc_cli_with_instructions() -> bool: + """Check if NGC CLI is installed and guide user if not.""" + if check_ngc_cli_available(): + version = get_ngc_version() + if version: + print_info(version) + return True + + print_error("NGC CLI not found!") + console.print() + print_text("The NGC CLI is required to download the Nemotron-Personas datasets.") + console.print() + print_text("To download the Nemotron-Personas datasets, follow these steps:") + print_text(f" 1. Create an NVIDIA NGC account: {NGC_URL}") + print_text(f" 2. Install the NGC CLI: {NGC_CLI_INSTALL_URL}") + print_text(" 3. Following the install instructions to set up the NGC CLI") + print_text(" 4. 
Run 'data-designer download personas'") + return False diff --git a/src/data_designer/cli/main.py b/src/data_designer/cli/main.py index 79932a94..4e3d53e3 100644 --- a/src/data_designer/cli/main.py +++ b/src/data_designer/cli/main.py @@ -3,8 +3,8 @@ import typer +from data_designer.cli.commands import download, models, providers, reset from data_designer.cli.commands import list as list_cmd -from data_designer.cli.commands import models, providers, reset from data_designer.config.default_model_settings import resolve_seed_default_model_settings from data_designer.config.utils.misc import can_run_data_designer_locally @@ -32,7 +32,17 @@ config_app.command(name="list", help="List current configurations")(list_cmd.list_command) config_app.command(name="reset", help="Reset configuration files")(reset.reset_command) +# Create download command group +download_app = typer.Typer( + name="download", + help="Download assets for Data Designer", + no_args_is_help=True, +) +download_app.command(name="personas", help="Download Nemotron-Persona datasets")(download.personas_command) + +# Add command groups to main app app.add_typer(config_app, name="config") +app.add_typer(download_app, name="download") def main() -> None: diff --git a/src/data_designer/cli/repositories/persona_repository.py b/src/data_designer/cli/repositories/persona_repository.py new file mode 100644 index 00000000..40886dfc --- /dev/null +++ b/src/data_designer/cli/repositories/persona_repository.py @@ -0,0 +1,88 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
class PersonaLocale(BaseModel):
    """Metadata describing one persona locale."""

    # Locale code, e.g. 'en_US'
    code: str
    # Human-readable dataset size string
    size: str
    # Dataset name derived from the prefix and the lower-cased code
    dataset_name: str


class PersonaLocaleRegistry(BaseModel):
    """Container for the known persona locales and their shared dataset prefix."""

    locales: list[PersonaLocale]
    dataset_prefix: str = NEMOTRON_PERSONAS_DATASET_PREFIX


class PersonaRepository:
    """Read-only access to the built-in persona locale metadata.

    The data comes from module constants rather than user configuration, so
    this repository exposes lookups only.
    """

    def __init__(self) -> None:
        """Build the registry from the built-in constants."""
        self._registry = self._initialize_registry()

    def _initialize_registry(self) -> PersonaLocaleRegistry:
        """Create one PersonaLocale per entry in NEMOTRON_PERSONAS_DATASET_SIZES."""
        entries = []
        for locale_code, dataset_size in NEMOTRON_PERSONAS_DATASET_SIZES.items():
            entries.append(
                PersonaLocale(
                    code=locale_code,
                    size=dataset_size,
                    dataset_name=f"{NEMOTRON_PERSONAS_DATASET_PREFIX}{locale_code.lower()}",
                )
            )
        return PersonaLocaleRegistry(locales=entries)

    def list_all(self) -> list[PersonaLocale]:
        """Return a new list holding every registered persona locale.

        Returns:
            List of all available persona locales
        """
        return [*self._registry.locales]

    def get_by_code(self, code: str) -> PersonaLocale | None:
        """Look up a locale by its exact code.

        Args:
            code: Locale code (e.g., 'en_US', 'ja_JP')

        Returns:
            The first matching PersonaLocale, or None when no code matches
        """
        for candidate in self._registry.locales:
            if candidate.code == code:
                return candidate
        return None

    def get_dataset_name(self, code: str) -> str | None:
        """Return the dataset name registered for a locale code.

        Args:
            code: Locale code (e.g., 'en_US', 'ja_JP')

        Returns:
            Dataset name if the locale exists, None otherwise
        """
        match = self.get_by_code(code)
        if match:
            return match.dataset_name
        return None

    def get_dataset_prefix(self) -> str:
        """Return the prefix shared by all persona dataset names.

        Returns:
            Dataset prefix string
        """
        return self._registry.dataset_prefix
class DownloadService:
    """Business logic for downloading assets via NGC CLI."""

    def __init__(self, config_dir: Path, persona_repository: PersonaRepository):
        """Store the config directory and derive the managed datasets directory.

        Args:
            config_dir: Base directory; datasets live under ``managed-assets/datasets``.
            persona_repository: Source of persona locale metadata.
        """
        self.config_dir = config_dir
        self.managed_assets_dir = config_dir / "managed-assets" / "datasets"
        self.persona_repository = persona_repository

    def get_available_locales(self) -> dict[str, str]:
        """Get dictionary of available persona locales (locale code -> locale code)."""
        locales = self.persona_repository.list_all()
        return {locale.code: locale.code for locale in locales}

    def download_persona_dataset(self, locale: str) -> Path:
        """Download persona dataset for a specific locale using NGC CLI and move to managed assets.

        Args:
            locale: Locale code (e.g., 'en_US', 'ja_JP')

        Returns:
            Path to the managed assets datasets directory

        Raises:
            ValueError: If locale is invalid
            subprocess.CalledProcessError: If NGC CLI command fails
            FileNotFoundError: If the download produced no parquet files
        """
        locale_obj = self.persona_repository.get_by_code(locale)
        if not locale_obj:
            raise ValueError(f"Invalid locale: {locale}")

        self.managed_assets_dir.mkdir(parents=True, exist_ok=True)

        # Use temporary directory for download
        with tempfile.TemporaryDirectory() as temp_dir:
            # Run NGC CLI download command (without version to get latest)
            cmd = [
                "ngc",
                "registry",
                "resource",
                "download-version",
                f"nvidia/nemotron-personas/{locale_obj.dataset_name}",
                "--dest",
                temp_dir,
            ]

            subprocess.run(cmd, check=True)

            # Only the part after the temp directory is a real wildcard pattern;
            # the directory itself is escaped so glob metacharacters in TMPDIR
            # ('[', ']', '*', '?') are matched literally.
            relative_pattern = f"{locale_obj.dataset_name}*/*.parquet"
            download_pattern = f"{temp_dir}/{relative_pattern}"
            parquet_files = glob.glob(f"{glob.escape(temp_dir)}/{relative_pattern}")

            if not parquet_files:
                raise FileNotFoundError(f"No parquet files found matching pattern: {download_pattern}")

            # Move each parquet file to managed assets
            for parquet_file in parquet_files:
                source = Path(parquet_file)
                dest = self.managed_assets_dir / source.name
                shutil.move(str(source), str(dest))

        return self.managed_assets_dir

    def get_managed_assets_directory(self) -> Path:
        """Get the directory where managed datasets are stored."""
        return self.managed_assets_dir

    def is_locale_downloaded(self, locale: str) -> bool:
        """Check if a locale has already been downloaded to managed assets.

        Args:
            locale: Locale code to check

        Returns:
            True if ``<locale>.parquet`` exists in the managed assets directory
        """
        locale_obj = self.persona_repository.get_by_code(locale)
        if not locale_obj:
            return False

        # Exact-name check. A plain existence test (instead of glob) keeps
        # working when the managed assets path contains glob metacharacters,
        # and is also False when the directory itself is missing.
        return (self.managed_assets_dir / f"{locale}.parquet").exists()
def select_multiple_with_arrows(
    options: dict[str, str],
    prompt_text: str,
    default_keys: list[str] | None = None,
    allow_empty: bool = False,
) -> list[str] | None:
    """Interactive multi-selection with arrow key navigation and space to toggle.

    Uses prompt_toolkit's Application for an inline checkbox-style menu experience.

    Args:
        options: Dictionary of {key: display_text} options
        prompt_text: Prompt to display above options
        default_keys: List of keys that should be pre-selected; keys that are
            not present in ``options`` are ignored
        allow_empty: If True, allows user to submit with no selections

    Returns:
        List of selected keys in the order they appear in ``options``, or None
        if cancelled or if ``options`` is empty
    """
    if not options:
        return None

    # Build list of keys and track selected state
    keys = list(options)
    # Only keep defaults that can actually be displayed and toggled.
    selected_set = {key for key in (default_keys or []) if key in options}
    current_index = 0

    # Store result (mutated by the key handlers below)
    result = {"value": None, "cancelled": False}

    def get_formatted_text() -> list[tuple[str, str]]:
        """Generate the formatted text for the multi-select menu."""
        text = []
        # Add prompt with padding
        padding = " " * LEFT_PADDING
        text.append(("", f"{padding}{prompt_text}\n"))

        # Add options with checkboxes
        for i, key in enumerate(keys):
            display = options[key]
            checkbox = "[✓]" if key in selected_set else "[ ]"

            if i == current_index:
                # Highlighted item with Nord8 color
                text.append((f"fg:{NordColor.NORD8.value} bold", f"{padding} → {checkbox} {display}\n"))
            else:
                # Unselected item
                text.append(("", f"{padding} {checkbox} {display}\n"))

        # Add hint
        count = len(selected_set)
        text.append(
            (
                "fg:#666666",
                f"{padding} (↑/↓: navigate, Space: toggle, Enter: confirm ({count} selected), Esc: cancel)\n",
            )
        )
        return text

    # Create key bindings
    kb = KeyBindings()

    @kb.add("up")
    @kb.add("c-p")  # Ctrl+P
    def _move_up(event) -> None:
        nonlocal current_index
        # Modulo wraps from the first item to the last.
        current_index = (current_index - 1) % len(keys)

    @kb.add("down")
    @kb.add("c-n")  # Ctrl+N
    def _move_down(event) -> None:
        nonlocal current_index
        current_index = (current_index + 1) % len(keys)

    @kb.add("c-h")  # Ctrl+H as alternative
    @kb.add(" ", eager=True)  # Space key - eager to capture immediately
    def _toggle(event) -> None:
        key = keys[current_index]
        if key in selected_set:
            selected_set.remove(key)
        else:
            selected_set.add(key)

    @kb.add("enter")
    def _confirm(event) -> None:
        if not allow_empty and not selected_set:
            # Don't allow empty selection if not permitted
            return
        # Report the selection in display order; iterating the set directly
        # would hand callers an arbitrary order.
        result["value"] = [key for key in keys if key in selected_set]
        event.app.exit()

    @kb.add("escape")
    @kb.add("c-c")  # Ctrl+C
    def _cancel(event) -> None:
        result["cancelled"] = True
        event.app.exit()

    # Create the application
    app = Application(
        layout=Layout(
            HSplit(
                [
                    Window(
                        content=FormattedTextControl(get_formatted_text),
                        dont_extend_height=True,
                        always_hide_cursor=True,
                    )
                ]
            )
        ),
        key_bindings=kb,
        full_screen=False,
        mouse_support=False,
    )

    try:
        # Run the application
        app.run()

        # Handle the result
        if result["cancelled"]:
            print_warning("Cancelled")
            return None
        return result["value"]

    except (KeyboardInterrupt, EOFError):
        print_warning("Cancelled")
        return None
+ """ + try: + result = subprocess.run( + ["ngc", "--version"], + capture_output=True, + text=True, + check=True, + timeout=5, + ) + return result.stdout.strip() + except (subprocess.CalledProcessError, subprocess.TimeoutExpired, FileNotFoundError): + return None + def validate_url(url: str) -> bool: """Validate that a string is a valid URL. diff --git a/src/data_designer/config/sampler_params.py b/src/data_designer/config/sampler_params.py index dc2af8fd..51ba3058 100644 --- a/src/data_designer/config/sampler_params.py +++ b/src/data_designer/config/sampler_params.py @@ -421,8 +421,8 @@ class PersonSamplerParams(ConfigBase): Attributes: locale: Locale string determining the language and geographic region for synthetic people. - Format: language_COUNTRY (e.g., "en_US", "en_GB", "fr_FR", "de_DE", "es_ES", "ja_JP"). - Defaults to "en_US". + Must be a locale supported by a managed Nemotron Personas dataset. The dataset must + be downloaded and available in the managed assets directory. sex: If specified, filters to only sample people of the specified sex. Options: "Male" or "Female". If None, samples both sexes. city: If specified, filters to only sample people from the specified city or cities. Can be diff --git a/src/data_designer/config/utils/constants.py b/src/data_designer/config/utils/constants.py index 1f3efdda..bbb01f99 100644 --- a/src/data_designer/config/utils/constants.py +++ b/src/data_designer/config/utils/constants.py @@ -97,8 +97,6 @@ class NordColor(Enum): MIN_AGE = 0 MAX_AGE = 114 -LOCALES_WITH_MANAGED_DATASETS = ["en_US", "ja_JP", "en_IN", "hi_IN"] - US_STATES_AND_MAJOR_TERRITORIES = { # States "AK", @@ -323,3 +321,16 @@ class NordColor(Enum): "embedding": {"model": "text-embedding-3-large", "inference_parameters": DEFAULT_EMBEDDING_INFERENCE_PARAMS}, }, } + +# Persona locale metadata - used by the CLI and the person sampler. 
# Locale code -> human-readable download size of the matching persona dataset.
NEMOTRON_PERSONAS_DATASET_SIZES = {
    "en_US": "1.24 GB",
    "en_IN": "2.39 GB",
    "hi_Deva_IN": "4.14 GB",
    "hi_Latn_IN": "2.7 GB",
    "ja_JP": "1.69 GB",
}

# Derived from the size table so the two can never drift apart.
LOCALES_WITH_MANAGED_DATASETS: list[str] = list(NEMOTRON_PERSONAS_DATASET_SIZES)

# Dataset names are this prefix followed by the lower-cased locale code.
NEMOTRON_PERSONAS_DATASET_PREFIX = "nemotron-personas-dataset-"
def _install_controller_double(controller_cls: MagicMock) -> MagicMock:
    """Make the patched DownloadController class return a spec'd MagicMock instance."""
    controller_double = MagicMock(spec=DownloadController)
    controller_cls.return_value = controller_double
    return controller_double


@patch("data_designer.cli.commands.download.DownloadController")
def test_personas_command_interactive_mode(mock_download_controller: MagicMock) -> None:
    """With no flags, the command delegates to run_personas with locales=None."""
    controller_double = _install_controller_double(mock_download_controller)

    personas_command(locales=None, all_locales=False, dry_run=False, list_available=False)

    mock_download_controller.assert_called_once_with(DATA_DESIGNER_HOME)
    controller_double.run_personas.assert_called_once_with(locales=None, all_locales=False, dry_run=False)


@patch("data_designer.cli.commands.download.DownloadController")
def test_personas_command_with_specific_locales(mock_download_controller: MagicMock) -> None:
    """Explicit --locale values are forwarded unchanged to run_personas."""
    controller_double = _install_controller_double(mock_download_controller)

    personas_command(locales=["en_US", "ja_JP"], all_locales=False, dry_run=False, list_available=False)

    mock_download_controller.assert_called_once_with(DATA_DESIGNER_HOME)
    controller_double.run_personas.assert_called_once_with(
        locales=["en_US", "ja_JP"], all_locales=False, dry_run=False
    )


@patch("data_designer.cli.commands.download.DownloadController")
def test_personas_command_with_all_flag(mock_download_controller: MagicMock) -> None:
    """--all is forwarded as all_locales=True."""
    controller_double = _install_controller_double(mock_download_controller)

    personas_command(locales=None, all_locales=True, dry_run=False, list_available=False)

    mock_download_controller.assert_called_once_with(DATA_DESIGNER_HOME)
    controller_double.run_personas.assert_called_once_with(locales=None, all_locales=True, dry_run=False)


@patch("data_designer.cli.commands.download.DownloadController")
def test_personas_command_with_dry_run_flag(mock_download_controller: MagicMock) -> None:
    """--dry-run is forwarded as dry_run=True."""
    controller_double = _install_controller_double(mock_download_controller)

    personas_command(locales=["en_US"], all_locales=False, dry_run=True, list_available=False)

    mock_download_controller.assert_called_once_with(DATA_DESIGNER_HOME)
    controller_double.run_personas.assert_called_once_with(locales=["en_US"], all_locales=False, dry_run=True)


@patch("data_designer.cli.commands.download.DownloadController")
def test_personas_command_with_list_flag(mock_download_controller: MagicMock) -> None:
    """--list calls list_personas and never starts the download workflow."""
    controller_double = _install_controller_double(mock_download_controller)

    personas_command(locales=None, all_locales=False, dry_run=False, list_available=True)

    mock_download_controller.assert_called_once_with(DATA_DESIGNER_HOME)
    controller_double.list_personas.assert_called_once()
    controller_double.run_personas.assert_not_called()
+# SPDX-License-Identifier: Apache-2.0 + +import subprocess +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + +from data_designer.cli.controllers.download_controller import DownloadController + + +@pytest.fixture +def controller(tmp_path: Path) -> DownloadController: + """Create a controller instance for testing.""" + return DownloadController(tmp_path) + + +@pytest.fixture +def controller_with_datasets(tmp_path: Path) -> DownloadController: + """Create a controller instance with existing datasets.""" + controller = DownloadController(tmp_path) + # Create managed assets directory with sample parquet files + managed_assets_dir = tmp_path / "managed-assets" / "datasets" + managed_assets_dir.mkdir(parents=True, exist_ok=True) + (managed_assets_dir / "en_US.parquet").touch() + return controller + + +def test_init(tmp_path: Path) -> None: + """Test controller initialization sets up service correctly.""" + controller = DownloadController(tmp_path) + assert controller.config_dir == tmp_path + assert controller.service.config_dir == tmp_path + assert controller.persona_repository is not None + assert controller.service.persona_repository is controller.persona_repository + + +def test_list_personas(controller: DownloadController) -> None: + """Test list_personas displays all available datasets.""" + controller.list_personas() + # Method should complete without errors and print to console + + +def test_list_personas_with_downloaded_datasets(controller_with_datasets: DownloadController) -> None: + """Test list_personas shows (downloaded) status for existing datasets.""" + controller_with_datasets.list_personas() + # Method should complete without errors and show downloaded status + + +@patch("data_designer.cli.controllers.download_controller.confirm_action", return_value=False) +@patch("data_designer.cli.controllers.download_controller.select_multiple_with_arrows", return_value=["en_US"]) 
+@patch("data_designer.cli.controllers.download_controller.check_ngc_cli_with_instructions", return_value=True) +def test_run_personas_user_cancels_confirmation( + mock_check_ngc: MagicMock, + mock_select: MagicMock, + mock_confirm: MagicMock, + controller: DownloadController, +) -> None: + """Test run_personas when user cancels at confirmation prompt.""" + controller.run_personas(locales=None, all_locales=False) + + # Verify NGC check was called + mock_check_ngc.assert_called_once() + + # Verify interactive selection was called + mock_select.assert_called_once() + + # Verify confirmation was requested + mock_confirm.assert_called_once() + + +@patch.object(DownloadController, "_download_locale", return_value=True) +@patch("data_designer.cli.controllers.download_controller.confirm_action", return_value=True) +@patch("data_designer.cli.controllers.download_controller.check_ngc_cli_with_instructions", return_value=True) +def test_run_personas_with_all_flag( + mock_check_ngc: MagicMock, + mock_confirm: MagicMock, + mock_download: MagicMock, + controller: DownloadController, +) -> None: + """Test run_personas with --all flag downloads all locales.""" + controller.run_personas(locales=None, all_locales=True) + + # Verify NGC check was called + mock_check_ngc.assert_called_once() + + # Verify all 5 locales were downloaded + assert mock_download.call_count == 5 + + # Verify each locale was downloaded + downloaded_locales = [call[0][0] for call in mock_download.call_args_list] + assert "en_US" in downloaded_locales + assert "en_IN" in downloaded_locales + assert "hi_Deva_IN" in downloaded_locales + assert "hi_Latn_IN" in downloaded_locales + assert "ja_JP" in downloaded_locales + + +@patch.object(DownloadController, "_download_locale", return_value=True) +@patch("data_designer.cli.controllers.download_controller.confirm_action", return_value=True) +@patch("data_designer.cli.controllers.download_controller.check_ngc_cli_with_instructions", return_value=True) +def 
test_run_personas_with_specific_locales( + mock_check_ngc: MagicMock, + mock_confirm: MagicMock, + mock_download: MagicMock, + controller: DownloadController, +) -> None: + """Test run_personas with specific locale flags.""" + controller.run_personas(locales=["en_US", "ja_JP"], all_locales=False) + + # Verify NGC check was called + mock_check_ngc.assert_called_once() + + # Verify only specified locales were downloaded + assert mock_download.call_count == 2 + downloaded_locales = [call[0][0] for call in mock_download.call_args_list] + assert "en_US" in downloaded_locales + assert "ja_JP" in downloaded_locales + + +@patch("data_designer.cli.controllers.download_controller.check_ngc_cli_with_instructions", return_value=True) +def test_run_personas_with_invalid_locales( + mock_check_ngc: MagicMock, + controller: DownloadController, +) -> None: + """Test run_personas with invalid locale codes.""" + controller.run_personas(locales=["invalid_locale", "en_US"], all_locales=False) + + # Verify NGC check was called + mock_check_ngc.assert_called_once() + + # Function should exit early without attempting download + + +@patch.object(DownloadController, "_download_locale", return_value=True) +@patch("data_designer.cli.controllers.download_controller.confirm_action", return_value=True) +@patch("data_designer.cli.controllers.download_controller.select_multiple_with_arrows", return_value=["en_US"]) +@patch("data_designer.cli.controllers.download_controller.check_ngc_cli_with_instructions", return_value=True) +def test_run_personas_interactive_selection( + mock_check_ngc: MagicMock, + mock_select: MagicMock, + mock_confirm: MagicMock, + mock_download: MagicMock, + controller: DownloadController, +) -> None: + """Test run_personas with interactive locale selection.""" + controller.run_personas(locales=None, all_locales=False) + + # Verify NGC check was called + mock_check_ngc.assert_called_once() + + # Verify interactive selection was called + mock_select.assert_called_once() + + # 
Verify confirmation was requested + mock_confirm.assert_called_once() + + # Verify selected locale was downloaded + mock_download.assert_called_once_with("en_US") + + +@patch("data_designer.cli.controllers.download_controller.select_multiple_with_arrows", return_value=None) +@patch("data_designer.cli.controllers.download_controller.check_ngc_cli_with_instructions", return_value=True) +def test_run_personas_interactive_cancelled( + mock_check_ngc: MagicMock, + mock_select: MagicMock, + controller: DownloadController, +) -> None: + """Test run_personas when user cancels interactive selection.""" + controller.run_personas(locales=None, all_locales=False) + + # Verify NGC check was called + mock_check_ngc.assert_called_once() + + # Verify interactive selection was called + mock_select.assert_called_once() + + # Function should exit early + + +@patch("data_designer.cli.controllers.download_controller.check_ngc_cli_with_instructions", return_value=False) +def test_run_personas_ngc_cli_not_available( + mock_check_ngc: MagicMock, + controller: DownloadController, +) -> None: + """Test run_personas exits early when NGC CLI is not available.""" + controller.run_personas(locales=["en_US"], all_locales=False) + + # Verify NGC check was called + mock_check_ngc.assert_called_once() + + +def test_check_ngc_cli_available_with_version() -> None: + """Test check_ngc_cli_with_instructions displays version when NGC CLI is available.""" + from data_designer.cli.controllers.download_controller import check_ngc_cli_with_instructions + + with patch("data_designer.cli.controllers.download_controller.check_ngc_cli_available", return_value=True): + with patch("data_designer.cli.controllers.download_controller.get_ngc_version", return_value="NGC CLI 3.41.4"): + result = check_ngc_cli_with_instructions() + + assert result is True + + +def test_check_ngc_cli_available_without_version() -> None: + """Test check_ngc_cli_with_instructions when version cannot be determined.""" + from 
data_designer.cli.controllers.download_controller import check_ngc_cli_with_instructions + + with patch("data_designer.cli.controllers.download_controller.check_ngc_cli_available", return_value=True): + with patch("data_designer.cli.controllers.download_controller.get_ngc_version", return_value=None): + result = check_ngc_cli_with_instructions() + + assert result is True + + +def test_determine_locales_with_all_flag(controller: DownloadController) -> None: + """Test _determine_locales returns all locales when all_locales=True.""" + result = controller._determine_locales(locales=None, all_locales=True) + + assert len(result) == 5 + assert "en_US" in result + assert "en_IN" in result + assert "hi_Deva_IN" in result + assert "hi_Latn_IN" in result + assert "ja_JP" in result + + +def test_determine_locales_with_valid_locale_flags(controller: DownloadController) -> None: + """Test _determine_locales with valid locale flags.""" + result = controller._determine_locales(locales=["en_US", "ja_JP"], all_locales=False) + + assert result == ["en_US", "ja_JP"] + + +def test_determine_locales_with_invalid_locale_flags(controller: DownloadController) -> None: + """Test _determine_locales with invalid locale flags.""" + result = controller._determine_locales(locales=["invalid", "en_US"], all_locales=False) + + assert result == [] + + +@patch("data_designer.cli.controllers.download_controller.select_multiple_with_arrows", return_value=["en_US"]) +def test_determine_locales_interactive(mock_select: MagicMock, controller: DownloadController) -> None: + """Test _determine_locales with interactive selection.""" + result = controller._determine_locales(locales=None, all_locales=False) + + assert result == ["en_US"] + mock_select.assert_called_once() + + +def test_select_locales_interactive(controller: DownloadController) -> None: + """Test _select_locales_interactive calls UI function correctly.""" + available_locales = controller.service.get_available_locales() + + with 
patch("data_designer.cli.controllers.download_controller.select_multiple_with_arrows") as mock_select: + mock_select.return_value = ["en_US", "ja_JP"] + result = controller._select_locales_interactive(available_locales) + + assert result == ["en_US", "ja_JP"] + mock_select.assert_called_once() + + +def test_download_locale_success(controller: DownloadController) -> None: + """Test _download_locale successfully downloads a locale.""" + with patch.object(controller.service, "download_persona_dataset"): + result = controller._download_locale("en_US") + + assert result is True + + +def test_download_locale_subprocess_error(controller: DownloadController) -> None: + """Test _download_locale handles subprocess errors.""" + with patch.object( + controller.service, + "download_persona_dataset", + side_effect=subprocess.CalledProcessError(1, "ngc"), + ): + result = controller._download_locale("en_US") + + assert result is False + + +def test_download_locale_generic_error(controller: DownloadController) -> None: + """Test _download_locale handles generic errors.""" + with patch.object( + controller.service, + "download_persona_dataset", + side_effect=Exception("Unexpected error"), + ): + result = controller._download_locale("en_US") + + assert result is False + + +@patch.object(DownloadController, "_download_locale") +@patch("data_designer.cli.controllers.download_controller.confirm_action", return_value=True) +@patch("data_designer.cli.controllers.download_controller.check_ngc_cli_with_instructions", return_value=True) +def test_run_personas_mixed_success_and_failure( + mock_check_ngc: MagicMock, + mock_confirm: MagicMock, + mock_download: MagicMock, + controller: DownloadController, +) -> None: + """Test run_personas with mixed success and failure results.""" + # First download succeeds, second fails + mock_download.side_effect = [True, False] + + controller.run_personas(locales=["en_US", "ja_JP"], all_locales=False) + + # Verify both locales were attempted + assert 
mock_download.call_count == 2 + + +@patch.object(DownloadController, "_download_locale", return_value=True) +@patch("data_designer.cli.controllers.download_controller.confirm_action", return_value=True) +@patch("data_designer.cli.controllers.download_controller.check_ngc_cli_with_instructions", return_value=True) +def test_run_personas_shows_existing_status( + mock_check_ngc: MagicMock, + mock_confirm: MagicMock, + mock_download: MagicMock, + controller_with_datasets: DownloadController, +) -> None: + """Test run_personas shows (already exists, will update) status for existing datasets.""" + controller_with_datasets.run_personas(locales=["en_US"], all_locales=False) + + # Verify download was attempted (it would show the "already exists" message) + mock_download.assert_called_once_with("en_US") + + +@patch.object(DownloadController, "_download_locale") +@patch("data_designer.cli.controllers.download_controller.confirm_action") +@patch("data_designer.cli.controllers.download_controller.check_ngc_cli_with_instructions") +def test_run_personas_with_dry_run_flag( + mock_check_ngc: MagicMock, + mock_confirm: MagicMock, + mock_download: MagicMock, + controller: DownloadController, +) -> None: + """Test run_personas with --dry-run flag does not download or check NGC CLI.""" + controller.run_personas(locales=["en_US", "ja_JP"], all_locales=False, dry_run=True) + + # Verify NGC check was NOT called in dry run mode + mock_check_ngc.assert_not_called() + + # Verify confirmation was NOT requested in dry run mode + mock_confirm.assert_not_called() + + # Verify no downloads were attempted + mock_download.assert_not_called() + + +@patch.object(DownloadController, "_download_locale") +@patch("data_designer.cli.controllers.download_controller.confirm_action") +@patch("data_designer.cli.controllers.download_controller.check_ngc_cli_with_instructions") +def test_run_personas_with_all_and_dry_run( + mock_check_ngc: MagicMock, + mock_confirm: MagicMock, + mock_download: MagicMock, + 
controller: DownloadController, +) -> None: + """Test run_personas with --all and --dry-run flags.""" + controller.run_personas(locales=None, all_locales=True, dry_run=True) + + # Verify NGC check was NOT called in dry run mode + mock_check_ngc.assert_not_called() + + # Verify confirmation was NOT requested in dry run mode + mock_confirm.assert_not_called() + + # Verify no downloads were attempted + mock_download.assert_not_called() + + +@patch.object(DownloadController, "_download_locale") +@patch("data_designer.cli.controllers.download_controller.confirm_action") +@patch("data_designer.cli.controllers.download_controller.select_multiple_with_arrows", return_value=["en_US"]) +@patch("data_designer.cli.controllers.download_controller.check_ngc_cli_with_instructions") +def test_run_personas_interactive_with_dry_run( + mock_check_ngc: MagicMock, + mock_select: MagicMock, + mock_confirm: MagicMock, + mock_download: MagicMock, + controller: DownloadController, +) -> None: + """Test run_personas with interactive selection and --dry-run flag.""" + controller.run_personas(locales=None, all_locales=False, dry_run=True) + + # Verify NGC check was NOT called in dry run mode + mock_check_ngc.assert_not_called() + + # Verify interactive selection WAS called (user still needs to select) + mock_select.assert_called_once() + + # Verify confirmation was NOT requested in dry run mode + mock_confirm.assert_not_called() + + # Verify no downloads were attempted + mock_download.assert_not_called() diff --git a/tests/cli/repositories/test_persona_repository.py b/tests/cli/repositories/test_persona_repository.py new file mode 100644 index 00000000..e1e1ba1a --- /dev/null +++ b/tests/cli/repositories/test_persona_repository.py @@ -0,0 +1,171 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +import pytest + +from data_designer.cli.repositories.persona_repository import PersonaLocale, PersonaRepository + + +@pytest.fixture +def repository() -> PersonaRepository: + """Create a repository instance for testing.""" + return PersonaRepository() + + +def test_init(repository: PersonaRepository) -> None: + """Test repository initialization creates registry.""" + assert repository._registry is not None + assert len(repository._registry.locales) == 5 + assert repository._registry.dataset_prefix == "nemotron-personas-dataset-" + + +def test_list_all(repository: PersonaRepository) -> None: + """Test listing all available locales.""" + locales = repository.list_all() + + assert isinstance(locales, list) + assert len(locales) == 5 + + # Verify all expected locales are present + locale_codes = {locale.code for locale in locales} + assert locale_codes == {"en_US", "en_IN", "hi_Deva_IN", "hi_Latn_IN", "ja_JP"} + + # Verify each locale has required fields + for locale in locales: + assert isinstance(locale, PersonaLocale) + assert locale.code + assert locale.size + assert locale.dataset_name + + +def test_get_by_code_valid_locale(repository: PersonaRepository) -> None: + """Test getting a locale by valid code.""" + locale = repository.get_by_code("en_US") + + assert locale is not None + assert locale.code == "en_US" + assert locale.size == "1.24 GB" + assert locale.dataset_name == "nemotron-personas-dataset-en_us" + + +def test_get_by_code_all_locales(repository: PersonaRepository) -> None: + """Test getting each locale by code.""" + test_cases = [ + ("en_US", "1.24 GB", "nemotron-personas-dataset-en_us"), + ("en_IN", "2.39 GB", "nemotron-personas-dataset-en_in"), + ("hi_Deva_IN", "4.14 GB", "nemotron-personas-dataset-hi_deva_in"), + ("hi_Latn_IN", "2.7 GB", "nemotron-personas-dataset-hi_latn_in"), + ("ja_JP", "1.69 GB", "nemotron-personas-dataset-ja_jp"), + ] + + for code, expected_size, expected_dataset_name in test_cases: + 
locale = repository.get_by_code(code) + assert locale is not None + assert locale.code == code + assert locale.size == expected_size + assert locale.dataset_name == expected_dataset_name + + +def test_get_by_code_invalid_locale(repository: PersonaRepository) -> None: + """Test getting a locale by invalid code returns None.""" + locale = repository.get_by_code("invalid_locale") + assert locale is None + + +def test_get_by_code_case_sensitive(repository: PersonaRepository) -> None: + """Test that locale code lookup is case sensitive.""" + locale = repository.get_by_code("en_us") # lowercase + assert locale is None + + locale = repository.get_by_code("EN_US") # uppercase + assert locale is None + + +def test_get_dataset_name_valid_locale(repository: PersonaRepository) -> None: + """Test getting dataset name for valid locale.""" + dataset_name = repository.get_dataset_name("en_US") + assert dataset_name == "nemotron-personas-dataset-en_us" + + +def test_get_dataset_name_invalid_locale(repository: PersonaRepository) -> None: + """Test getting dataset name for invalid locale returns None.""" + dataset_name = repository.get_dataset_name("invalid_locale") + assert dataset_name is None + + +def test_get_dataset_name_lowercase_conversion(repository: PersonaRepository) -> None: + """Test that dataset names use lowercase locale codes.""" + # Verify that mixed-case locale codes result in lowercase dataset names + locale = repository.get_by_code("hi_Deva_IN") + assert locale is not None + assert locale.dataset_name == "nemotron-personas-dataset-hi_deva_in" + assert locale.dataset_name.islower() or "_" in locale.dataset_name + + +def test_get_dataset_prefix(repository: PersonaRepository) -> None: + """Test getting dataset prefix.""" + prefix = repository.get_dataset_prefix() + assert prefix == "nemotron-personas-dataset-" + + +def test_persona_locale_model() -> None: + """Test PersonaLocale Pydantic model.""" + locale = PersonaLocale( + code="en_US", + size="1.24 GB", + 
dataset_name="nemotron-personas-dataset-en_us", + ) + + assert locale.code == "en_US" + assert locale.size == "1.24 GB" + assert locale.dataset_name == "nemotron-personas-dataset-en_us" + + +def test_persona_locale_model_validation() -> None: + """Test PersonaLocale model validates required fields.""" + with pytest.raises(Exception): # Pydantic validation error + PersonaLocale(code="en_US") # Missing required fields + + +def test_repository_immutability(repository: PersonaRepository) -> None: + """Test that modifying returned list doesn't affect repository.""" + locales = repository.list_all() + original_count = len(locales) + + # Try to modify the returned list + locales.append( + PersonaLocale( + code="test", + size="1 GB", + dataset_name="test-dataset", + ) + ) + + # Verify repository data is unchanged + fresh_locales = repository.list_all() + assert len(fresh_locales) == original_count + + +def test_locale_size_formats(repository: PersonaRepository) -> None: + """Test that all locale sizes are in expected format.""" + locales = repository.list_all() + + for locale in locales: + # Verify size contains GB and is properly formatted + assert "GB" in locale.size + # Extract numeric part and verify it's valid + size_value = locale.size.replace(" GB", "").replace("GB", "") + assert float(size_value) > 0 + + +def test_dataset_name_consistency(repository: PersonaRepository) -> None: + """Test that all dataset names follow consistent pattern.""" + locales = repository.list_all() + prefix = repository.get_dataset_prefix() + + for locale in locales: + # All dataset names should start with the prefix + assert locale.dataset_name.startswith(prefix) + # All dataset names should end with lowercase locale code + expected_suffix = locale.code.lower() + assert locale.dataset_name.endswith(expected_suffix) diff --git a/tests/cli/services/test_download_service.py b/tests/cli/services/test_download_service.py new file mode 100644 index 00000000..6e3fce22 --- /dev/null +++ 
b/tests/cli/services/test_download_service.py @@ -0,0 +1,270 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import subprocess +from pathlib import Path +from unittest.mock import MagicMock, call, patch + +import pytest + +from data_designer.cli.repositories.persona_repository import PersonaRepository +from data_designer.cli.services.download_service import DownloadService + + +@pytest.fixture +def persona_repository() -> PersonaRepository: + """Create a persona repository instance for testing.""" + return PersonaRepository() + + +@pytest.fixture +def service(tmp_path: Path, persona_repository: PersonaRepository) -> DownloadService: + """Create a service instance for testing.""" + return DownloadService(tmp_path, persona_repository) + + +@pytest.fixture +def service_with_datasets(tmp_path: Path, persona_repository: PersonaRepository) -> DownloadService: + """Create a service instance with existing datasets.""" + service = DownloadService(tmp_path, persona_repository) + # Create managed assets directory with sample parquet files + managed_assets_dir = tmp_path / "managed-assets" / "datasets" + managed_assets_dir.mkdir(parents=True, exist_ok=True) + + # Create sample parquet files for en_US and ja_JP + (managed_assets_dir / "en_US.parquet").touch() + (managed_assets_dir / "ja_JP.parquet").touch() + + return service + + +def test_init(tmp_path: Path, persona_repository: PersonaRepository) -> None: + """Test service initialization sets up paths correctly.""" + service = DownloadService(tmp_path, persona_repository) + assert service.config_dir == tmp_path + assert service.managed_assets_dir == tmp_path / "managed-assets" / "datasets" + assert service.persona_repository is persona_repository + + +def test_get_available_locales(service: DownloadService) -> None: + """Test getting available locales returns correct dictionary.""" + locales = service.get_available_locales() + + 
assert isinstance(locales, dict) + assert len(locales) == 5 + assert "en_US" in locales + assert "en_IN" in locales + assert "hi_Deva_IN" in locales + assert "hi_Latn_IN" in locales + assert "ja_JP" in locales + + # Verify values are locale codes (not descriptions) + assert locales["en_US"] == "en_US" + assert locales["ja_JP"] == "ja_JP" + + +def test_get_managed_assets_directory(service: DownloadService, tmp_path: Path) -> None: + """Test getting managed assets directory path.""" + expected = tmp_path / "managed-assets" / "datasets" + assert service.get_managed_assets_directory() == expected + + +def test_is_locale_downloaded_returns_true(service_with_datasets: DownloadService) -> None: + """Test checking if locale is downloaded when files exist.""" + assert service_with_datasets.is_locale_downloaded("en_US") is True + assert service_with_datasets.is_locale_downloaded("ja_JP") is True + + +def test_is_locale_downloaded_returns_false(service_with_datasets: DownloadService) -> None: + """Test checking if locale is downloaded when files don't exist.""" + assert service_with_datasets.is_locale_downloaded("en_IN") is False + assert service_with_datasets.is_locale_downloaded("hi_Deva_IN") is False + + +def test_is_locale_downloaded_invalid_locale(service: DownloadService) -> None: + """Test checking if invalid locale is downloaded.""" + assert service.is_locale_downloaded("invalid_locale") is False + + +def test_is_locale_downloaded_no_directory(service: DownloadService) -> None: + """Test checking if locale is downloaded when directory doesn't exist.""" + assert service.is_locale_downloaded("en_US") is False + + +def test_download_persona_dataset_invalid_locale(service: DownloadService) -> None: + """Test downloading with invalid locale raises ValueError.""" + with pytest.raises(ValueError, match="Invalid locale: invalid_locale"): + service.download_persona_dataset("invalid_locale") + + +@patch("data_designer.cli.services.download_service.glob.glob") 
+@patch("data_designer.cli.services.download_service.subprocess.run") +@patch("data_designer.cli.services.download_service.tempfile.TemporaryDirectory") +def test_download_persona_dataset_success( + mock_temp_dir: MagicMock, + mock_subprocess: MagicMock, + mock_glob: MagicMock, + service: DownloadService, + tmp_path: Path, +) -> None: + """Test successful persona dataset download.""" + # Setup mock temporary directory + temp_dir_path = "/tmp/test_temp_dir" + mock_temp_dir_instance = MagicMock() + mock_temp_dir_instance.__enter__.return_value = temp_dir_path + mock_temp_dir_instance.__exit__.return_value = None + mock_temp_dir.return_value = mock_temp_dir_instance + + # Setup mock parquet files + mock_parquet_files = [ + f"{temp_dir_path}/nemotron-personas-dataset-en_us_v0.0.1/file_0.parquet", + f"{temp_dir_path}/nemotron-personas-dataset-en_us_v0.0.1/file_1.parquet", + ] + mock_glob.return_value = mock_parquet_files + + # Mock shutil.move to avoid actual file operations + with patch("data_designer.cli.services.download_service.shutil.move") as mock_move: + result = service.download_persona_dataset("en_US") + + # Verify subprocess was called correctly + expected_cmd = [ + "ngc", + "registry", + "resource", + "download-version", + "nvidia/nemotron-personas/nemotron-personas-dataset-en_us", + "--dest", + temp_dir_path, + ] + mock_subprocess.assert_called_once_with(expected_cmd, check=True) + + # Verify glob pattern + expected_pattern = f"{temp_dir_path}/nemotron-personas-dataset-en_us*/*.parquet" + mock_glob.assert_called_once_with(expected_pattern) + + # Verify files were moved + assert mock_move.call_count == 2 + expected_calls = [ + call(mock_parquet_files[0], str(service.managed_assets_dir / "file_0.parquet")), + call(mock_parquet_files[1], str(service.managed_assets_dir / "file_1.parquet")), + ] + mock_move.assert_has_calls(expected_calls) + + # Verify result + assert result == service.managed_assets_dir + + # Verify managed assets directory was created + assert 
service.managed_assets_dir.exists() + + +@patch("data_designer.cli.services.download_service.glob.glob") +@patch("data_designer.cli.services.download_service.subprocess.run") +@patch("data_designer.cli.services.download_service.tempfile.TemporaryDirectory") +def test_download_persona_dataset_no_parquet_files( + mock_temp_dir: MagicMock, + mock_subprocess: MagicMock, + mock_glob: MagicMock, + service: DownloadService, +) -> None: + """Test download fails when no parquet files are found.""" + # Setup mock temporary directory + temp_dir_path = "/tmp/test_temp_dir" + mock_temp_dir_instance = MagicMock() + mock_temp_dir_instance.__enter__.return_value = temp_dir_path + mock_temp_dir_instance.__exit__.return_value = None + mock_temp_dir.return_value = mock_temp_dir_instance + + # Mock glob to return empty list + mock_glob.return_value = [] + + # Should raise FileNotFoundError + with pytest.raises(FileNotFoundError, match="No parquet files found matching pattern"): + service.download_persona_dataset("en_US") + + +@patch("data_designer.cli.services.download_service.subprocess.run") +@patch("data_designer.cli.services.download_service.tempfile.TemporaryDirectory") +def test_download_persona_dataset_ngc_cli_error( + mock_temp_dir: MagicMock, + mock_subprocess: MagicMock, + service: DownloadService, +) -> None: + """Test download handles NGC CLI subprocess errors.""" + # Setup mock temporary directory + temp_dir_path = "/tmp/test_temp_dir" + mock_temp_dir_instance = MagicMock() + mock_temp_dir_instance.__enter__.return_value = temp_dir_path + mock_temp_dir_instance.__exit__.return_value = None + mock_temp_dir.return_value = mock_temp_dir_instance + + # Mock subprocess to raise CalledProcessError + mock_subprocess.side_effect = subprocess.CalledProcessError(1, "ngc") + + # Should propagate the error + with pytest.raises(subprocess.CalledProcessError): + service.download_persona_dataset("en_US") + + +@patch("data_designer.cli.services.download_service.glob.glob") 
+@patch("data_designer.cli.services.download_service.subprocess.run") +@patch("data_designer.cli.services.download_service.tempfile.TemporaryDirectory") +def test_download_persona_dataset_multiple_locales( + mock_temp_dir: MagicMock, + mock_subprocess: MagicMock, + mock_glob: MagicMock, + service: DownloadService, +) -> None: + """Test downloading multiple different locales.""" + # Setup mock temporary directory + temp_dir_path = "/tmp/test_temp_dir" + mock_temp_dir_instance = MagicMock() + mock_temp_dir_instance.__enter__.return_value = temp_dir_path + mock_temp_dir_instance.__exit__.return_value = None + mock_temp_dir.return_value = mock_temp_dir_instance + + with patch("data_designer.cli.services.download_service.shutil.move"): + # Download en_US + mock_glob.return_value = [f"{temp_dir_path}/nemotron-personas-dataset-en_us_v0.0.1/file.parquet"] + service.download_persona_dataset("en_US") + + # Download ja_JP + mock_glob.return_value = [f"{temp_dir_path}/nemotron-personas-dataset-ja_jp_v0.0.1/file.parquet"] + service.download_persona_dataset("ja_JP") + + # Verify both locales were downloaded with correct resources + assert mock_subprocess.call_count == 2 + + # Check first call was for en_US + first_call_args = mock_subprocess.call_args_list[0][0][0] + assert "nvidia/nemotron-personas/nemotron-personas-dataset-en_us" in first_call_args + + # Check second call was for ja_JP + second_call_args = mock_subprocess.call_args_list[1][0][0] + assert "nvidia/nemotron-personas/nemotron-personas-dataset-ja_jp" in second_call_args + + +@patch("data_designer.cli.services.download_service.glob.glob") +@patch("data_designer.cli.services.download_service.subprocess.run") +@patch("data_designer.cli.services.download_service.tempfile.TemporaryDirectory") +def test_download_persona_dataset_lowercase_handling( + mock_temp_dir: MagicMock, + mock_subprocess: MagicMock, + mock_glob: MagicMock, + service: DownloadService, +) -> None: + """Test that glob pattern uses lowercase locale for 
NGC directory naming.""" + # Setup mock temporary directory + temp_dir_path = "/tmp/test_temp_dir" + mock_temp_dir_instance = MagicMock() + mock_temp_dir_instance.__enter__.return_value = temp_dir_path + mock_temp_dir_instance.__exit__.return_value = None + mock_temp_dir.return_value = mock_temp_dir_instance + + mock_glob.return_value = [f"{temp_dir_path}/nemotron-personas-dataset-hi_deva_in_v0.0.1/file.parquet"] + + with patch("data_designer.cli.services.download_service.shutil.move"): + service.download_persona_dataset("hi_Deva_IN") + + # Verify glob was called with lowercase locale + expected_pattern = f"{temp_dir_path}/nemotron-personas-dataset-hi_deva_in*/*.parquet" + mock_glob.assert_called_once_with(expected_pattern) diff --git a/tests/cli/test_cli_utils.py b/tests/cli/test_cli_utils.py index ece84e59..623ef18d 100644 --- a/tests/cli/test_cli_utils.py +++ b/tests/cli/test_cli_utils.py @@ -1,11 +1,15 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 +import subprocess from pathlib import Path +from unittest.mock import MagicMock, patch import pytest from data_designer.cli.utils import ( + check_ngc_cli_available, + get_ngc_version, validate_numeric_range, validate_url, ) @@ -109,3 +113,72 @@ def test_validate_numeric_range() -> None: is_valid, value = validate_numeric_range("abc", 0.0, 1.0) assert not is_valid assert value is None + + +@patch("data_designer.cli.utils.shutil.which") +@patch("data_designer.cli.utils.get_ngc_version") +def test_check_ngc_cli_available_returns_true(mock_get_version: MagicMock, mock_which: MagicMock) -> None: + """Test NGC CLI availability check when NGC CLI is installed.""" + mock_which.return_value = "/usr/local/bin/ngc" + mock_get_version.return_value = "NGC CLI 3.41.4" + + assert check_ngc_cli_available() is True + mock_which.assert_called_once_with("ngc") + + +@patch("data_designer.cli.utils.shutil.which") +def test_check_ngc_cli_available_returns_false(mock_which: MagicMock) -> None: + """Test NGC CLI availability check when NGC CLI is not installed.""" + mock_which.return_value = None + + assert check_ngc_cli_available() is False + mock_which.assert_called_once_with("ngc") + + +@patch("data_designer.cli.utils.subprocess.run") +def test_get_ngc_version_success(mock_run: MagicMock) -> None: + """Test getting NGC CLI version successfully.""" + mock_result = MagicMock() + mock_result.stdout = "NGC CLI 3.41.4\n" + mock_run.return_value = mock_result + + version = get_ngc_version() + + assert version == "NGC CLI 3.41.4" + mock_run.assert_called_once_with( + ["ngc", "--version"], + capture_output=True, + text=True, + check=True, + timeout=5, + ) + + +@patch("data_designer.cli.utils.subprocess.run") +def test_get_ngc_version_handles_error(mock_run: MagicMock) -> None: + """Test getting NGC CLI version when command fails.""" + mock_run.side_effect = subprocess.CalledProcessError(1, "ngc") + + version = get_ngc_version() + + assert version is None 
+ + +@patch("data_designer.cli.utils.subprocess.run") +def test_get_ngc_version_handles_timeout(mock_run: MagicMock) -> None: + """Test getting NGC CLI version when command times out.""" + mock_run.side_effect = subprocess.TimeoutExpired("ngc", 5) + + version = get_ngc_version() + + assert version is None + + +@patch("data_designer.cli.utils.subprocess.run") +def test_get_ngc_version_handles_file_not_found(mock_run: MagicMock) -> None: + """Test getting NGC CLI version when NGC CLI is not found.""" + mock_run.side_effect = FileNotFoundError() + + version = get_ngc_version() + + assert version is None