|
| 1 | +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
| 2 | +# SPDX-License-Identifier: Apache-2.0 |
| 3 | + |
| 4 | +import subprocess |
| 5 | +from pathlib import Path |
| 6 | + |
| 7 | +from data_designer.cli.repositories.persona_repository import PersonaRepository |
| 8 | +from data_designer.cli.services.download_service import DownloadService |
| 9 | +from data_designer.cli.ui import ( |
| 10 | + confirm_action, |
| 11 | + console, |
| 12 | + print_error, |
| 13 | + print_header, |
| 14 | + print_info, |
| 15 | + print_success, |
| 16 | + print_text, |
| 17 | + select_multiple_with_arrows, |
| 18 | +) |
| 19 | +from data_designer.cli.utils import check_ngc_cli_available, get_ngc_version |
| 20 | + |
| 21 | +NGC_URL = "https://catalog.ngc.nvidia.com/" |
| 22 | +NGC_CLI_INSTALL_URL = "https://org.ngc.nvidia.com/setup/installers/cli" |
| 23 | + |
| 24 | + |
| 25 | +class DownloadController: |
| 26 | + """Controller for asset download workflows.""" |
| 27 | + |
| 28 | + def __init__(self, config_dir: Path): |
| 29 | + self.config_dir = config_dir |
| 30 | + self.persona_repository = PersonaRepository() |
| 31 | + self.service = DownloadService(config_dir, self.persona_repository) |
| 32 | + |
| 33 | + def list_personas(self) -> None: |
| 34 | + """List available persona datasets and their sizes.""" |
| 35 | + print_header("Available Nemotron-Persona Datasets") |
| 36 | + console.print() |
| 37 | + |
| 38 | + available_locales = self.persona_repository.list_all() |
| 39 | + |
| 40 | + print_text("📦 Available locales:") |
| 41 | + console.print() |
| 42 | + |
| 43 | + for locale in available_locales: |
| 44 | + already_downloaded = self.service.is_locale_downloaded(locale.code) |
| 45 | + status = " (downloaded)" if already_downloaded else "" |
| 46 | + print_text(f" • {locale.code}: {locale.size}{status}") |
| 47 | + |
| 48 | + console.print() |
| 49 | + print_info(f"Total: {len(available_locales)} datasets available") |
| 50 | + |
| 51 | + def run_personas(self, locales: list[str] | None, all_locales: bool, dry_run: bool = False) -> None: |
| 52 | + """Main entry point for persona dataset downloads. |
| 53 | +
|
| 54 | + Args: |
| 55 | + locales: List of locale codes to download (if provided via CLI flags) |
| 56 | + all_locales: If True, download all available locales |
| 57 | + dry_run: If True, only show what would be downloaded without actually downloading |
| 58 | + """ |
| 59 | + header = "Download Nemotron-Persona Datasets (Dry Run)" if dry_run else "Download Nemotron-Persona Datasets" |
| 60 | + print_header(header) |
| 61 | + print_info(f"Datasets will be saved to: {self.service.get_managed_assets_directory()}") |
| 62 | + console.print() |
| 63 | + |
| 64 | + # Check NGC CLI availability (skip checking in dry run mode) |
| 65 | + if not dry_run and not check_ngc_cli_with_instructions(): |
| 66 | + return |
| 67 | + |
| 68 | + # Determine which locales to download |
| 69 | + selected_locales = self._determine_locales(locales, all_locales) |
| 70 | + |
| 71 | + if not selected_locales: |
| 72 | + print_info("No locales selected") |
| 73 | + return |
| 74 | + |
| 75 | + # Show what will be downloaded |
| 76 | + console.print() |
| 77 | + action = "Would download" if dry_run else "Will download" |
| 78 | + print_text(f"📦 {action} {len(selected_locales)} Nemotron-Persona dataset(s):") |
| 79 | + for locale_code in selected_locales: |
| 80 | + locale = self.persona_repository.get_by_code(locale_code) |
| 81 | + already_downloaded = self.service.is_locale_downloaded(locale_code) |
| 82 | + status = " - already exists, will update" if already_downloaded else "" |
| 83 | + size = locale.size if locale else "unknown" |
| 84 | + print_text(f" • {locale_code} ({size}){status}") |
| 85 | + |
| 86 | + console.print() |
| 87 | + |
| 88 | + # In dry run mode, exit here |
| 89 | + if dry_run: |
| 90 | + print_info("Dry run complete - no files were downloaded") |
| 91 | + return |
| 92 | + |
| 93 | + # Confirm download |
| 94 | + if not confirm_action("Proceed with download?", default=True): |
| 95 | + print_info("Download cancelled") |
| 96 | + return |
| 97 | + |
| 98 | + # Download each locale |
| 99 | + console.print() |
| 100 | + successful = [] |
| 101 | + failed = [] |
| 102 | + |
| 103 | + for locale in selected_locales: |
| 104 | + if self._download_locale(locale): |
| 105 | + successful.append(locale) |
| 106 | + else: |
| 107 | + failed.append(locale) |
| 108 | + |
| 109 | + # Summary |
| 110 | + console.print() |
| 111 | + if successful: |
| 112 | + print_success(f"Successfully downloaded {len(successful)} dataset(s): {', '.join(successful)}") |
| 113 | + print_info(f"Saved datasets to: {self.service.get_managed_assets_directory()}") |
| 114 | + |
| 115 | + if failed: |
| 116 | + print_error(f"Failed to download {len(failed)} dataset(s): {', '.join(failed)}") |
| 117 | + |
| 118 | + def _determine_locales(self, locales: list[str] | None, all_locales: bool) -> list[str]: |
| 119 | + """Determine which locales to download based on user input. |
| 120 | +
|
| 121 | + Args: |
| 122 | + locales: List of locales from CLI flags (may be None) |
| 123 | + all_locales: Whether to download all locales |
| 124 | +
|
| 125 | + Returns: |
| 126 | + List of locale codes to download |
| 127 | + """ |
| 128 | + available_locales = self.service.get_available_locales() |
| 129 | + |
| 130 | + # If --all flag is set, return all locales |
| 131 | + if all_locales: |
| 132 | + return list(available_locales.keys()) |
| 133 | + |
| 134 | + # If locales specified via flags, validate and return them |
| 135 | + if locales: |
| 136 | + invalid_locales = [loc for loc in locales if loc not in available_locales] |
| 137 | + if invalid_locales: |
| 138 | + print_error(f"Invalid locale(s): {', '.join(invalid_locales)}") |
| 139 | + print_info(f"Available locales: {', '.join(available_locales.keys())}") |
| 140 | + return [] |
| 141 | + return locales |
| 142 | + |
| 143 | + # Interactive multi-select |
| 144 | + return self._select_locales_interactive(available_locales) |
| 145 | + |
| 146 | + def _select_locales_interactive(self, available_locales: dict[str, str]) -> list[str]: |
| 147 | + """Interactive multi-select for locales. |
| 148 | +
|
| 149 | + Args: |
| 150 | + available_locales: Dictionary of {locale_code: description} |
| 151 | +
|
| 152 | + Returns: |
| 153 | + List of selected locale codes |
| 154 | + """ |
| 155 | + console.print() |
| 156 | + print_text("Select locales you want to download:") |
| 157 | + console.print() |
| 158 | + |
| 159 | + selected = select_multiple_with_arrows( |
| 160 | + options=available_locales, |
| 161 | + prompt_text="Use ↑/↓ to navigate, Space to toggle ✓, Enter to confirm:", |
| 162 | + default_keys=None, |
| 163 | + allow_empty=False, |
| 164 | + ) |
| 165 | + |
| 166 | + return selected if selected else [] |
| 167 | + |
| 168 | + def _download_locale(self, locale: str) -> bool: |
| 169 | + """Download a single locale using NGC CLI. |
| 170 | +
|
| 171 | + Args: |
| 172 | + locale: Locale code to download |
| 173 | +
|
| 174 | + Returns: |
| 175 | + True if download succeeded, False otherwise |
| 176 | + """ |
| 177 | + # Print header before download (NGC CLI will show its own progress) |
| 178 | + print_text(f"📦 Downloading Nemotron-Persona dataset for {locale}...") |
| 179 | + console.print() |
| 180 | + |
| 181 | + try: |
| 182 | + self.service.download_persona_dataset(locale) |
| 183 | + console.print() |
| 184 | + print_success(f"✓ Downloaded Nemotron-Persona dataset for {locale}") |
| 185 | + return True |
| 186 | + |
| 187 | + except subprocess.CalledProcessError as e: |
| 188 | + console.print() |
| 189 | + print_error(f"✗ Failed to download Nemotron-Persona dataset for {locale}") |
| 190 | + print_error(f"NGC CLI error: {e}") |
| 191 | + return False |
| 192 | + |
| 193 | + except Exception as e: |
| 194 | + console.print() |
| 195 | + print_error(f"✗ Failed to download Nemotron-Persona dataset for {locale}") |
| 196 | + print_error(f"Unexpected error: {e}") |
| 197 | + return False |
| 198 | + |
| 199 | + |
| 200 | +def check_ngc_cli_with_instructions() -> bool: |
| 201 | + """Check if NGC CLI is installed and guide user if not.""" |
| 202 | + if check_ngc_cli_available(): |
| 203 | + version = get_ngc_version() |
| 204 | + if version: |
| 205 | + print_info(version) |
| 206 | + return True |
| 207 | + |
| 208 | + print_error("NGC CLI not found!") |
| 209 | + console.print() |
| 210 | + print_text("The NGC CLI is required to download the Nemotron-Personas datasets.") |
| 211 | + console.print() |
| 212 | + print_text("To download the Nemotron-Personas datasets, follow these steps:") |
| 213 | + print_text(f" 1. Create an NVIDIA NGC account: {NGC_URL}") |
| 214 | + print_text(f" 2. Install the NGC CLI: {NGC_CLI_INSTALL_URL}") |
| 215 | + print_text(" 3. Following the install instructions to set up the NGC CLI") |
| 216 | + print_text(" 4. Run 'data-designer download personas'") |
| 217 | + return False |
0 commit comments