Skip to content

Commit 26834b9

Browse files
committed
some refactoring; add dataset sizes to log messages
1 parent c93c731 commit 26834b9

File tree

6 files changed

+176
-150
lines changed

6 files changed

+176
-150
lines changed

src/data_designer/cli/controllers/download_controller.py

Lines changed: 27 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import subprocess
55
from pathlib import Path
66

7-
from data_designer.cli.services.download_service import DownloadService
7+
from data_designer.cli.services.download_service import DATASET_SIZES, DownloadService
88
from data_designer.cli.ui import (
99
confirm_action,
1010
console,
@@ -15,6 +15,7 @@
1515
print_text,
1616
select_multiple_with_arrows,
1717
)
18+
from data_designer.cli.utils import check_ngc_cli_available, get_ngc_version
1819

1920
NGC_URL = "https://catalog.ngc.nvidia.com/"
2021
NGC_CLI_INSTALL_URL = "https://org.ngc.nvidia.com/setup/installers/cli"
@@ -40,8 +41,8 @@ def run_personas(self, locales: list[str] | None, all_locales: bool, dry_run: bo
4041
print_info(f"Datasets will be saved to: {self.service.get_managed_assets_directory()}")
4142
console.print()
4243

43-
# Check NGC CLI availability (skip in dry run mode)
44-
if not dry_run and not self._check_ngc_cli():
44+
# Check NGC CLI availability (skip checking in dry run mode)
45+
if not dry_run and not check_ngc_cli_with_instructions():
4546
return
4647

4748
# Determine which locales to download
@@ -57,8 +58,8 @@ def run_personas(self, locales: list[str] | None, all_locales: bool, dry_run: bo
5758
print_text(f"📦 {action} {len(selected_locales)} Nemotron-Persona dataset(s):")
5859
for locale in selected_locales:
5960
already_downloaded = self.service.is_locale_downloaded(locale)
60-
status = " (already exists, will update)" if already_downloaded else ""
61-
print_text(f" • {locale}{status}")
61+
status = " - already exists, will update" if already_downloaded else ""
62+
print_text(f" • {locale} ({DATASET_SIZES[locale]}){status}")
6263

6364
console.print()
6465

@@ -87,29 +88,11 @@ def run_personas(self, locales: list[str] | None, all_locales: bool, dry_run: bo
8788
console.print()
8889
if successful:
8990
print_success(f"Successfully downloaded {len(successful)} dataset(s): {', '.join(successful)}")
90-
print_info(f"Location: {self.service.get_managed_assets_directory()}")
91+
print_info(f"Saved datasets to: {self.service.get_managed_assets_directory()}")
9192

9293
if failed:
9394
print_error(f"Failed to download {len(failed)} dataset(s): {', '.join(failed)}")
9495

95-
def _check_ngc_cli(self) -> bool:
96-
"""Check if NGC CLI is installed and guide user if not."""
97-
if self.service.check_ngc_cli_available():
98-
version = self.service.get_ngc_version()
99-
if version:
100-
print_info(f"NGC CLI: {version}")
101-
return True
102-
103-
print_error("NGC CLI not found!")
104-
console.print()
105-
print_text("The NGC CLI is required to download the Nemotron-Personas datasets.")
106-
console.print()
107-
print_text("To download the Nemotron-Personas datasets, follow these steps:")
108-
print_text(f" 1. Create an NVIDIA NGC account: {NGC_URL}")
109-
print_text(f" 2. Install the NGC CLI: {NGC_CLI_INSTALL_URL}")
110-
print_text(" 3. Following the install instructions to set up the NGC CLI")
111-
print_text(" 4. Run 'data-designer download personas'")
112-
11396
def _determine_locales(self, locales: list[str] | None, all_locales: bool) -> list[str]:
11497
"""Determine which locales to download based on user input.
11598
@@ -190,3 +173,23 @@ def _download_locale(self, locale: str) -> bool:
190173
print_error(f"✗ Failed to download Nemotron-Persona dataset for {locale}")
191174
print_error(f"Unexpected error: {e}")
192175
return False
176+
177+
178+
def check_ngc_cli_with_instructions() -> bool:
179+
"""Check if NGC CLI is installed and guide user if not."""
180+
if check_ngc_cli_available():
181+
version = get_ngc_version()
182+
if version:
183+
print_info(f"NGC CLI: {version}")
184+
return True
185+
186+
print_error("NGC CLI not found!")
187+
console.print()
188+
print_text("The NGC CLI is required to download the Nemotron-Personas datasets.")
189+
console.print()
190+
print_text("To download the Nemotron-Personas datasets, follow these steps:")
191+
print_text(f" 1. Create an NVIDIA NGC account: {NGC_URL}")
192+
print_text(f" 2. Install the NGC CLI: {NGC_CLI_INSTALL_URL}")
193+
print_text(" 3. Following the install instructions to set up the NGC CLI")
194+
print_text(" 4. Run 'data-designer download personas'")
195+
return False

src/data_designer/cli/services/download_service.py

Lines changed: 13 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,15 @@
77
import tempfile
88
from pathlib import Path
99

10+
DATASET_SIZES = {
11+
"en_US": "1.24 GB",
12+
"en_IN": "2.39 GB",
13+
"hi_Deva_IN": "4.14 GB",
14+
"hi_Latn_IN": "2.7 GB",
15+
"ja_JP": "1.69 GB",
16+
}
17+
SUPPORTED_LOCALES = list[str](DATASET_SIZES.keys())
1018
DATASET_PREFIX = "nemotron-personas-dataset-"
11-
SUPPORTED_LOCALES = ["en_US", "en_IN", "hi_Deva_IN", "hi_Latn_IN", "ja_JP"]
1219

1320

1421
class DownloadService:
@@ -18,31 +25,6 @@ def __init__(self, config_dir: Path):
1825
self.config_dir = config_dir
1926
self.managed_assets_dir = config_dir / "managed-assets" / "datasets"
2027

21-
def check_ngc_cli_available(self) -> bool:
22-
"""Check if NGC CLI is installed and available.
23-
24-
Returns:
25-
True if NGC CLI is in PATH and executable, False otherwise.
26-
"""
27-
if shutil.which("ngc") is None:
28-
return False
29-
30-
return self.get_ngc_version() is not None
31-
32-
def get_ngc_version(self) -> str | None:
33-
"""Get the NGC CLI version if available."""
34-
try:
35-
result = subprocess.run(
36-
["ngc", "--version"],
37-
capture_output=True,
38-
text=True,
39-
check=True,
40-
timeout=5,
41-
)
42-
return result.stdout.strip()
43-
except (subprocess.CalledProcessError, subprocess.TimeoutExpired, FileNotFoundError):
44-
return None
45-
4628
def get_available_locales(self) -> dict[str, str]:
4729
"""Get dictionary of available persona locales (locale code -> locale code)."""
4830
return {locale: locale for locale in SUPPORTED_LOCALES}
@@ -73,14 +55,14 @@ def download_persona_dataset(self, locale: str) -> Path:
7355
"registry",
7456
"resource",
7557
"download-version",
76-
f"nvidia/nemotron-personas/{_get_dataset_name(locale)}",
58+
f"nvidia/nemotron-personas/{_get_downloaded_dataset_name(locale)}",
7759
"--dest",
7860
temp_dir,
7961
]
8062

8163
subprocess.run(cmd, check=True)
8264

83-
dataset_pattern = _get_dataset_name(locale)
65+
dataset_pattern = _get_downloaded_dataset_name(locale)
8466
download_pattern = f"{temp_dir}/{dataset_pattern}*/*.parquet"
8567
parquet_files = glob.glob(download_pattern)
8668

@@ -114,16 +96,14 @@ def is_locale_downloaded(self, locale: str) -> bool:
11496
if not self.managed_assets_dir.exists():
11597
return False
11698

117-
# Check for parquet files matching this locale in managed assets
118-
dataset_pattern = _get_dataset_name(locale)
11999
# Look for the parquet file named exactly "<locale>.parquet" in the managed assets directory
120-
parquet_files = glob.glob(str(self.managed_assets_dir / f"{dataset_pattern}*.parquet"))
100+
parquet_files = glob.glob(str(self.managed_assets_dir / f"{locale}.parquet"))
121101

122102
return len(parquet_files) > 0
123103

124104

125-
def _get_dataset_name(locale: str) -> str:
126-
"""Build dataset name pattern for the given locale.
105+
def _get_downloaded_dataset_name(locale: str) -> str:
106+
"""Build the downloaded dataset name pattern for the given locale.
127107
128108
Args:
129109
locale: Locale code (e.g., 'en_US', 'ja_JP')

src/data_designer/cli/utils.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,40 @@
11
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: Apache-2.0
33

4+
import shutil
5+
import subprocess
6+
7+
8+
def check_ngc_cli_available() -> bool:
9+
"""Check if NGC CLI is installed and available.
10+
11+
Returns:
12+
True if NGC CLI is in PATH and executable, False otherwise.
13+
"""
14+
if shutil.which("ngc") is None:
15+
return False
16+
17+
return get_ngc_version() is not None
18+
19+
20+
def get_ngc_version() -> str | None:
21+
"""Get the NGC CLI version if available.
22+
23+
Returns:
24+
NGC CLI version string if available, None otherwise.
25+
"""
26+
try:
27+
result = subprocess.run(
28+
["ngc", "--version"],
29+
capture_output=True,
30+
text=True,
31+
check=True,
32+
timeout=5,
33+
)
34+
return result.stdout.strip()
35+
except (subprocess.CalledProcessError, subprocess.TimeoutExpired, FileNotFoundError):
36+
return None
37+
438

539
def validate_url(url: str) -> bool:
640
"""Validate that a string is a valid URL.

tests/cli/controllers/test_download_controller.py

Lines changed: 27 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def controller_with_datasets(tmp_path: Path) -> DownloadController:
2323
# Create managed assets directory with sample parquet files
2424
managed_assets_dir = tmp_path / "managed-assets" / "datasets"
2525
managed_assets_dir.mkdir(parents=True, exist_ok=True)
26-
(managed_assets_dir / "nemotron-personas-dataset-en_us_0.parquet").touch()
26+
(managed_assets_dir / "en_US.parquet").touch()
2727
return controller
2828

2929

@@ -36,7 +36,7 @@ def test_init(tmp_path: Path) -> None:
3636

3737
@patch("data_designer.cli.controllers.download_controller.confirm_action", return_value=False)
3838
@patch("data_designer.cli.controllers.download_controller.select_multiple_with_arrows", return_value=["en_US"])
39-
@patch.object(DownloadController, "_check_ngc_cli", return_value=True)
39+
@patch("data_designer.cli.controllers.download_controller.check_ngc_cli_with_instructions", return_value=True)
4040
def test_run_personas_user_cancels_confirmation(
4141
mock_check_ngc: MagicMock,
4242
mock_select: MagicMock,
@@ -58,7 +58,7 @@ def test_run_personas_user_cancels_confirmation(
5858

5959
@patch.object(DownloadController, "_download_locale", return_value=True)
6060
@patch("data_designer.cli.controllers.download_controller.confirm_action", return_value=True)
61-
@patch.object(DownloadController, "_check_ngc_cli", return_value=True)
61+
@patch("data_designer.cli.controllers.download_controller.check_ngc_cli_with_instructions", return_value=True)
6262
def test_run_personas_with_all_flag(
6363
mock_check_ngc: MagicMock,
6464
mock_confirm: MagicMock,
@@ -85,7 +85,7 @@ def test_run_personas_with_all_flag(
8585

8686
@patch.object(DownloadController, "_download_locale", return_value=True)
8787
@patch("data_designer.cli.controllers.download_controller.confirm_action", return_value=True)
88-
@patch.object(DownloadController, "_check_ngc_cli", return_value=True)
88+
@patch("data_designer.cli.controllers.download_controller.check_ngc_cli_with_instructions", return_value=True)
8989
def test_run_personas_with_specific_locales(
9090
mock_check_ngc: MagicMock,
9191
mock_confirm: MagicMock,
@@ -105,7 +105,7 @@ def test_run_personas_with_specific_locales(
105105
assert "ja_JP" in downloaded_locales
106106

107107

108-
@patch.object(DownloadController, "_check_ngc_cli", return_value=True)
108+
@patch("data_designer.cli.controllers.download_controller.check_ngc_cli_with_instructions", return_value=True)
109109
def test_run_personas_with_invalid_locales(
110110
mock_check_ngc: MagicMock,
111111
controller: DownloadController,
@@ -122,7 +122,7 @@ def test_run_personas_with_invalid_locales(
122122
@patch.object(DownloadController, "_download_locale", return_value=True)
123123
@patch("data_designer.cli.controllers.download_controller.confirm_action", return_value=True)
124124
@patch("data_designer.cli.controllers.download_controller.select_multiple_with_arrows", return_value=["en_US"])
125-
@patch.object(DownloadController, "_check_ngc_cli", return_value=True)
125+
@patch("data_designer.cli.controllers.download_controller.check_ngc_cli_with_instructions", return_value=True)
126126
def test_run_personas_interactive_selection(
127127
mock_check_ngc: MagicMock,
128128
mock_select: MagicMock,
@@ -147,7 +147,7 @@ def test_run_personas_interactive_selection(
147147

148148

149149
@patch("data_designer.cli.controllers.download_controller.select_multiple_with_arrows", return_value=None)
150-
@patch.object(DownloadController, "_check_ngc_cli", return_value=True)
150+
@patch("data_designer.cli.controllers.download_controller.check_ngc_cli_with_instructions", return_value=True)
151151
def test_run_personas_interactive_cancelled(
152152
mock_check_ngc: MagicMock,
153153
mock_select: MagicMock,
@@ -165,7 +165,7 @@ def test_run_personas_interactive_cancelled(
165165
# Function should exit early
166166

167167

168-
@patch.object(DownloadController, "_check_ngc_cli", return_value=False)
168+
@patch("data_designer.cli.controllers.download_controller.check_ngc_cli_with_instructions", return_value=False)
169169
def test_run_personas_ngc_cli_not_available(
170170
mock_check_ngc: MagicMock,
171171
controller: DownloadController,
@@ -177,20 +177,24 @@ def test_run_personas_ngc_cli_not_available(
177177
mock_check_ngc.assert_called_once()
178178

179179

180-
def test_check_ngc_cli_available_with_version(controller: DownloadController) -> None:
181-
"""Test _check_ngc_cli displays version when NGC CLI is available."""
182-
with patch.object(controller.service, "check_ngc_cli_available", return_value=True):
183-
with patch.object(controller.service, "get_ngc_version", return_value="NGC CLI 3.41.4"):
184-
result = controller._check_ngc_cli()
180+
def test_check_ngc_cli_available_with_version() -> None:
181+
"""Test check_ngc_cli_with_instructions displays version when NGC CLI is available."""
182+
from data_designer.cli.controllers.download_controller import check_ngc_cli_with_instructions
183+
184+
with patch("data_designer.cli.controllers.download_controller.check_ngc_cli_available", return_value=True):
185+
with patch("data_designer.cli.controllers.download_controller.get_ngc_version", return_value="NGC CLI 3.41.4"):
186+
result = check_ngc_cli_with_instructions()
185187

186188
assert result is True
187189

188190

189-
def test_check_ngc_cli_available_without_version(controller: DownloadController) -> None:
190-
"""Test _check_ngc_cli when version cannot be determined."""
191-
with patch.object(controller.service, "check_ngc_cli_available", return_value=True):
192-
with patch.object(controller.service, "get_ngc_version", return_value=None):
193-
result = controller._check_ngc_cli()
191+
def test_check_ngc_cli_available_without_version() -> None:
192+
"""Test check_ngc_cli_with_instructions when version cannot be determined."""
193+
from data_designer.cli.controllers.download_controller import check_ngc_cli_with_instructions
194+
195+
with patch("data_designer.cli.controllers.download_controller.check_ngc_cli_available", return_value=True):
196+
with patch("data_designer.cli.controllers.download_controller.get_ngc_version", return_value=None):
197+
result = check_ngc_cli_with_instructions()
194198

195199
assert result is True
196200

@@ -276,7 +280,7 @@ def test_download_locale_generic_error(controller: DownloadController) -> None:
276280

277281
@patch.object(DownloadController, "_download_locale")
278282
@patch("data_designer.cli.controllers.download_controller.confirm_action", return_value=True)
279-
@patch.object(DownloadController, "_check_ngc_cli", return_value=True)
283+
@patch("data_designer.cli.controllers.download_controller.check_ngc_cli_with_instructions", return_value=True)
280284
def test_run_personas_mixed_success_and_failure(
281285
mock_check_ngc: MagicMock,
282286
mock_confirm: MagicMock,
@@ -295,7 +299,7 @@ def test_run_personas_mixed_success_and_failure(
295299

296300
@patch.object(DownloadController, "_download_locale", return_value=True)
297301
@patch("data_designer.cli.controllers.download_controller.confirm_action", return_value=True)
298-
@patch.object(DownloadController, "_check_ngc_cli", return_value=True)
302+
@patch("data_designer.cli.controllers.download_controller.check_ngc_cli_with_instructions", return_value=True)
299303
def test_run_personas_shows_existing_status(
300304
mock_check_ngc: MagicMock,
301305
mock_confirm: MagicMock,
@@ -311,7 +315,7 @@ def test_run_personas_shows_existing_status(
311315

312316
@patch.object(DownloadController, "_download_locale")
313317
@patch("data_designer.cli.controllers.download_controller.confirm_action")
314-
@patch.object(DownloadController, "_check_ngc_cli")
318+
@patch("data_designer.cli.controllers.download_controller.check_ngc_cli_with_instructions")
315319
def test_run_personas_with_dry_run_flag(
316320
mock_check_ngc: MagicMock,
317321
mock_confirm: MagicMock,
@@ -333,7 +337,7 @@ def test_run_personas_with_dry_run_flag(
333337

334338
@patch.object(DownloadController, "_download_locale")
335339
@patch("data_designer.cli.controllers.download_controller.confirm_action")
336-
@patch.object(DownloadController, "_check_ngc_cli")
340+
@patch("data_designer.cli.controllers.download_controller.check_ngc_cli_with_instructions")
337341
def test_run_personas_with_all_and_dry_run(
338342
mock_check_ngc: MagicMock,
339343
mock_confirm: MagicMock,
@@ -356,7 +360,7 @@ def test_run_personas_with_all_and_dry_run(
356360
@patch.object(DownloadController, "_download_locale")
357361
@patch("data_designer.cli.controllers.download_controller.confirm_action")
358362
@patch("data_designer.cli.controllers.download_controller.select_multiple_with_arrows", return_value=["en_US"])
359-
@patch.object(DownloadController, "_check_ngc_cli")
363+
@patch("data_designer.cli.controllers.download_controller.check_ngc_cli_with_instructions")
360364
def test_run_personas_interactive_with_dry_run(
361365
mock_check_ngc: MagicMock,
362366
mock_select: MagicMock,

0 commit comments

Comments
 (0)