Skip to content

Commit abb5d62

Browse files
authored
feat: Add download personas command to the CLI (#146)
* initial implementation * make download app consistent with config app * add dry run * add unit tests * some cleanup * some refactoring; add dataset sizes to log messages * add list option * update so we don't double print NGC CLI * move constants to shared config file * add persona repository
1 parent 3104ae1 commit abb5d62

File tree

17 files changed

+1637
-12
lines changed

17 files changed

+1637
-12
lines changed
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
import typer
5+
6+
from data_designer.cli.controllers.download_controller import DownloadController
7+
from data_designer.config.utils.constants import DATA_DESIGNER_HOME
8+
9+
10+
def personas_command(
    locales: list[str] = typer.Option(
        None,
        "--locale",
        "-l",
        help="Locales to download (en_US, en_IN, hi_Deva_IN, hi_Latn_IN, ja_JP). Can be specified multiple times.",
    ),
    all_locales: bool = typer.Option(
        False,
        "--all",
        help="Download all available locales",
    ),
    dry_run: bool = typer.Option(
        False,
        "--dry-run",
        help="Show what would be downloaded without actually downloading",
    ),
    list_available: bool = typer.Option(
        False,
        "--list",
        help="List available persona datasets and their sizes",
    ),
) -> None:
    """Download Nemotron-Personas datasets for synthetic data generation.

    When ``--list`` is given, the available datasets are listed and nothing
    else happens; otherwise the remaining options are forwarded to the
    download workflow of ``DownloadController``.

    Examples:
        # List available datasets
        data-designer download personas --list

        # Interactive selection
        data-designer download personas

        # Download specific locales
        data-designer download personas --locale en_US --locale ja_JP

        # Download all available locales
        data-designer download personas --all

        # Preview what would be downloaded
        data-designer download personas --all --dry-run
    """
    download_controller = DownloadController(DATA_DESIGNER_HOME)

    # Listing takes precedence over every download-related option.
    if list_available:
        download_controller.list_personas()
        return

    download_controller.run_personas(locales=locales, all_locales=all_locales, dry_run=dry_run)
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: Apache-2.0
33

4+
from data_designer.cli.controllers.download_controller import DownloadController
45
from data_designer.cli.controllers.model_controller import ModelController
56
from data_designer.cli.controllers.provider_controller import ProviderController
67

7-
__all__ = ["ModelController", "ProviderController"]
8+
__all__ = ["DownloadController", "ModelController", "ProviderController"]
Lines changed: 217 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,217 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
import subprocess
5+
from pathlib import Path
6+
7+
from data_designer.cli.repositories.persona_repository import PersonaRepository
8+
from data_designer.cli.services.download_service import DownloadService
9+
from data_designer.cli.ui import (
10+
confirm_action,
11+
console,
12+
print_error,
13+
print_header,
14+
print_info,
15+
print_success,
16+
print_text,
17+
select_multiple_with_arrows,
18+
)
19+
from data_designer.cli.utils import check_ngc_cli_available, get_ngc_version
20+
21+
NGC_URL = "https://catalog.ngc.nvidia.com/"
22+
NGC_CLI_INSTALL_URL = "https://org.ngc.nvidia.com/setup/installers/cli"
23+
24+
25+
class DownloadController:
26+
"""Controller for asset download workflows."""
27+
28+
def __init__(self, config_dir: Path):
29+
self.config_dir = config_dir
30+
self.persona_repository = PersonaRepository()
31+
self.service = DownloadService(config_dir, self.persona_repository)
32+
33+
def list_personas(self) -> None:
34+
"""List available persona datasets and their sizes."""
35+
print_header("Available Nemotron-Persona Datasets")
36+
console.print()
37+
38+
available_locales = self.persona_repository.list_all()
39+
40+
print_text("📦 Available locales:")
41+
console.print()
42+
43+
for locale in available_locales:
44+
already_downloaded = self.service.is_locale_downloaded(locale.code)
45+
status = " (downloaded)" if already_downloaded else ""
46+
print_text(f" • {locale.code}: {locale.size}{status}")
47+
48+
console.print()
49+
print_info(f"Total: {len(available_locales)} datasets available")
50+
51+
def run_personas(self, locales: list[str] | None, all_locales: bool, dry_run: bool = False) -> None:
52+
"""Main entry point for persona dataset downloads.
53+
54+
Args:
55+
locales: List of locale codes to download (if provided via CLI flags)
56+
all_locales: If True, download all available locales
57+
dry_run: If True, only show what would be downloaded without actually downloading
58+
"""
59+
header = "Download Nemotron-Persona Datasets (Dry Run)" if dry_run else "Download Nemotron-Persona Datasets"
60+
print_header(header)
61+
print_info(f"Datasets will be saved to: {self.service.get_managed_assets_directory()}")
62+
console.print()
63+
64+
# Check NGC CLI availability (skip checking in dry run mode)
65+
if not dry_run and not check_ngc_cli_with_instructions():
66+
return
67+
68+
# Determine which locales to download
69+
selected_locales = self._determine_locales(locales, all_locales)
70+
71+
if not selected_locales:
72+
print_info("No locales selected")
73+
return
74+
75+
# Show what will be downloaded
76+
console.print()
77+
action = "Would download" if dry_run else "Will download"
78+
print_text(f"📦 {action} {len(selected_locales)} Nemotron-Persona dataset(s):")
79+
for locale_code in selected_locales:
80+
locale = self.persona_repository.get_by_code(locale_code)
81+
already_downloaded = self.service.is_locale_downloaded(locale_code)
82+
status = " - already exists, will update" if already_downloaded else ""
83+
size = locale.size if locale else "unknown"
84+
print_text(f" • {locale_code} ({size}){status}")
85+
86+
console.print()
87+
88+
# In dry run mode, exit here
89+
if dry_run:
90+
print_info("Dry run complete - no files were downloaded")
91+
return
92+
93+
# Confirm download
94+
if not confirm_action("Proceed with download?", default=True):
95+
print_info("Download cancelled")
96+
return
97+
98+
# Download each locale
99+
console.print()
100+
successful = []
101+
failed = []
102+
103+
for locale in selected_locales:
104+
if self._download_locale(locale):
105+
successful.append(locale)
106+
else:
107+
failed.append(locale)
108+
109+
# Summary
110+
console.print()
111+
if successful:
112+
print_success(f"Successfully downloaded {len(successful)} dataset(s): {', '.join(successful)}")
113+
print_info(f"Saved datasets to: {self.service.get_managed_assets_directory()}")
114+
115+
if failed:
116+
print_error(f"Failed to download {len(failed)} dataset(s): {', '.join(failed)}")
117+
118+
def _determine_locales(self, locales: list[str] | None, all_locales: bool) -> list[str]:
119+
"""Determine which locales to download based on user input.
120+
121+
Args:
122+
locales: List of locales from CLI flags (may be None)
123+
all_locales: Whether to download all locales
124+
125+
Returns:
126+
List of locale codes to download
127+
"""
128+
available_locales = self.service.get_available_locales()
129+
130+
# If --all flag is set, return all locales
131+
if all_locales:
132+
return list(available_locales.keys())
133+
134+
# If locales specified via flags, validate and return them
135+
if locales:
136+
invalid_locales = [loc for loc in locales if loc not in available_locales]
137+
if invalid_locales:
138+
print_error(f"Invalid locale(s): {', '.join(invalid_locales)}")
139+
print_info(f"Available locales: {', '.join(available_locales.keys())}")
140+
return []
141+
return locales
142+
143+
# Interactive multi-select
144+
return self._select_locales_interactive(available_locales)
145+
146+
def _select_locales_interactive(self, available_locales: dict[str, str]) -> list[str]:
147+
"""Interactive multi-select for locales.
148+
149+
Args:
150+
available_locales: Dictionary of {locale_code: description}
151+
152+
Returns:
153+
List of selected locale codes
154+
"""
155+
console.print()
156+
print_text("Select locales you want to download:")
157+
console.print()
158+
159+
selected = select_multiple_with_arrows(
160+
options=available_locales,
161+
prompt_text="Use ↑/↓ to navigate, Space to toggle ✓, Enter to confirm:",
162+
default_keys=None,
163+
allow_empty=False,
164+
)
165+
166+
return selected if selected else []
167+
168+
def _download_locale(self, locale: str) -> bool:
169+
"""Download a single locale using NGC CLI.
170+
171+
Args:
172+
locale: Locale code to download
173+
174+
Returns:
175+
True if download succeeded, False otherwise
176+
"""
177+
# Print header before download (NGC CLI will show its own progress)
178+
print_text(f"📦 Downloading Nemotron-Persona dataset for {locale}...")
179+
console.print()
180+
181+
try:
182+
self.service.download_persona_dataset(locale)
183+
console.print()
184+
print_success(f"✓ Downloaded Nemotron-Persona dataset for {locale}")
185+
return True
186+
187+
except subprocess.CalledProcessError as e:
188+
console.print()
189+
print_error(f"✗ Failed to download Nemotron-Persona dataset for {locale}")
190+
print_error(f"NGC CLI error: {e}")
191+
return False
192+
193+
except Exception as e:
194+
console.print()
195+
print_error(f"✗ Failed to download Nemotron-Persona dataset for {locale}")
196+
print_error(f"Unexpected error: {e}")
197+
return False
198+
199+
200+
def check_ngc_cli_with_instructions() -> bool:
    """Check if NGC CLI is installed and guide user if not.

    When the CLI is available, its version string is printed (if one is
    reported). Otherwise an error and step-by-step setup instructions are
    printed.

    Returns:
        True if the NGC CLI is available, False otherwise.
    """
    if check_ngc_cli_available():
        version = get_ngc_version()
        if version:
            print_info(version)
        return True

    print_error("NGC CLI not found!")
    console.print()
    print_text("The NGC CLI is required to download the Nemotron-Personas datasets.")
    console.print()
    print_text("To download the Nemotron-Personas datasets, follow these steps:")
    print_text(f" 1. Create an NVIDIA NGC account: {NGC_URL}")
    print_text(f" 2. Install the NGC CLI: {NGC_CLI_INSTALL_URL}")
    # Wording fixed to the imperative ("Follow"), matching the other steps.
    print_text(" 3. Follow the install instructions to set up the NGC CLI")
    print_text(" 4. Run 'data-designer download personas'")
    return False

src/data_designer/cli/main.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33

44
import typer
55

6+
from data_designer.cli.commands import download, models, providers, reset
67
from data_designer.cli.commands import list as list_cmd
7-
from data_designer.cli.commands import models, providers, reset
88
from data_designer.config.default_model_settings import resolve_seed_default_model_settings
99
from data_designer.config.utils.misc import can_run_data_designer_locally
1010

@@ -32,7 +32,17 @@
3232
config_app.command(name="list", help="List current configurations")(list_cmd.list_command)
3333
config_app.command(name="reset", help="Reset configuration files")(reset.reset_command)
3434

35+
# Create download command group
36+
download_app = typer.Typer(
37+
name="download",
38+
help="Download assets for Data Designer",
39+
no_args_is_help=True,
40+
)
41+
download_app.command(name="personas", help="Download Nemotron-Persona datasets")(download.personas_command)
42+
43+
# Add command groups to main app
3544
app.add_typer(config_app, name="config")
45+
app.add_typer(download_app, name="download")
3646

3747

3848
def main() -> None:

0 commit comments

Comments
 (0)