Skip to content

Commit e1fde41

Browse files
committed
deleting providers should delete associated model configs
1 parent d4e7e57 commit e1fde41

File tree

4 files changed

+189
-16
lines changed

4 files changed

+189
-16
lines changed

src/data_designer/cli/controllers/provider_controller.py

Lines changed: 79 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@
55
from pathlib import Path
66

77
from data_designer.cli.forms.provider_builder import ProviderFormBuilder
8+
from data_designer.cli.repositories.model_repository import ModelRepository
89
from data_designer.cli.repositories.provider_repository import ProviderRepository
10+
from data_designer.cli.services.model_service import ModelService
911
from data_designer.cli.services.provider_service import ProviderService
1012
from data_designer.cli.ui import (
1113
confirm_action,
@@ -15,6 +17,7 @@
1517
print_header,
1618
print_info,
1719
print_success,
20+
print_warning,
1821
select_with_arrows,
1922
)
2023
from data_designer.engine.model_provider import ModelProvider
@@ -27,6 +30,8 @@ def __init__(self, config_dir: Path):
2730
self.config_dir = config_dir
2831
self.repository = ProviderRepository(config_dir)
2932
self.service = ProviderService(self.repository)
33+
self.model_repository = ModelRepository(config_dir)
34+
self.model_service = ModelService(self.model_repository)
3035

3136
def run(self) -> None:
3237
"""Main entry point for provider configuration."""
@@ -182,14 +187,45 @@ def _handle_delete(self) -> None:
182187
if selected_name is None:
183188
return
184189

190+
# Check for associated models
191+
associated_models = self.model_service.find_by_provider(selected_name)
192+
185193
# Confirm deletion
186194
console.print()
187-
if confirm_action(f"Delete provider '{selected_name}'?", default=False):
188-
try:
189-
self.service.delete(selected_name)
190-
print_success(f"Provider '{selected_name}' deleted successfully")
191-
except ValueError as e:
192-
print_error(f"Failed to delete provider: {e}")
195+
196+
if associated_models:
197+
# Notify user about associated models
198+
model_count = len(associated_models)
199+
model_aliases = ", ".join([f"'{m.alias}'" for m in associated_models])
200+
201+
print_warning(f"Provider '{selected_name}' has {model_count} associated model config(s): {model_aliases}")
202+
console.print()
203+
204+
# Ask if user wants to delete provider and associated models
205+
if confirm_action(
206+
f"Delete provider '{selected_name}' and its {model_count} associated model config(s)?", default=False
207+
):
208+
try:
209+
# Delete associated models first
210+
model_aliases_to_delete = [m.alias for m in associated_models]
211+
self.model_service.delete_by_aliases(model_aliases_to_delete)
212+
213+
# Then delete the provider
214+
self.service.delete(selected_name)
215+
216+
print_success(
217+
f"Provider '{selected_name}' and {model_count} associated model(s) deleted successfully"
218+
)
219+
except ValueError as e:
220+
print_error(f"Failed to delete provider and associated models: {e}")
221+
else:
222+
# No associated models, proceed with simple deletion
223+
if confirm_action(f"Delete provider '{selected_name}'?", default=False):
224+
try:
225+
self.service.delete(selected_name)
226+
print_success(f"Provider '{selected_name}' deleted successfully")
227+
except ValueError as e:
228+
print_error(f"Failed to delete provider: {e}")
193229

194230
def _handle_delete_all(self) -> None:
195231
"""Handle deleting all providers."""
@@ -198,20 +234,48 @@ def _handle_delete_all(self) -> None:
198234
print_error("No providers to delete")
199235
return
200236

237+
# Check for associated models across all providers
238+
all_models = self.model_service.list_all()
239+
provider_names_set = {p.name for p in providers}
240+
associated_models = [m for m in all_models if m.provider in provider_names_set]
241+
201242
# List providers to be deleted
202243
console.print()
203244
provider_count = len(providers)
204245
provider_names = ", ".join([f"'{p.name}'" for p in providers])
205246

206-
if confirm_action(
207-
f"⚠️ Delete ALL ({provider_count}) provider(s): {provider_names}?\n This action cannot be undone.",
208-
default=False,
209-
):
210-
try:
211-
self.repository.delete()
212-
print_success(f"All ({provider_count}) provider(s) deleted successfully")
213-
except Exception as e:
214-
print_error(f"Failed to delete all providers: {e}")
247+
if associated_models:
248+
model_count = len(associated_models)
249+
print_warning(f"Deleting all providers will also affect {model_count} associated model config(s)")
250+
console.print()
251+
252+
if confirm_action(
253+
f"⚠️ Delete ALL ({provider_count}) provider(s): {provider_names} and {model_count} associated model(s)?\n This action cannot be undone.",
254+
default=False,
255+
):
256+
try:
257+
# Delete all associated models first
258+
model_aliases_to_delete = [m.alias for m in associated_models]
259+
self.model_service.delete_by_aliases(model_aliases_to_delete)
260+
261+
# Then delete all providers
262+
self.repository.delete()
263+
264+
print_success(
265+
f"All ({provider_count}) provider(s) and {model_count} associated model(s) deleted successfully"
266+
)
267+
except Exception as e:
268+
print_error(f"Failed to delete all providers and associated models: {e}")
269+
else:
270+
if confirm_action(
271+
f"⚠️ Delete ALL ({provider_count}) provider(s): {provider_names}?\n This action cannot be undone.",
272+
default=False,
273+
):
274+
try:
275+
self.repository.delete()
276+
print_success(f"All ({provider_count}) provider(s) deleted successfully")
277+
except Exception as e:
278+
print_error(f"Failed to delete all providers: {e}")
215279

216280
def _handle_change_default(self) -> None:
217281
"""Handle changing the default provider."""

src/data_designer/cli/services/model_service.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,18 @@ def get_by_alias(self, alias: str) -> ModelConfig | None:
2121
models = self.list_all()
2222
return next((m for m in models if m.alias == alias), None)
2323

24+
def find_by_provider(self, provider_name: str) -> list[ModelConfig]:
25+
"""Find all models associated with a provider.
26+
27+
Args:
28+
provider_name: Name of the provider to search for
29+
30+
Returns:
31+
List of models associated with the provider
32+
"""
33+
models = self.list_all()
34+
return [m for m in models if m.provider == provider_name]
35+
2436
def add(self, model: ModelConfig) -> None:
2537
"""Add a new model.
2638
@@ -79,3 +91,26 @@ def delete(self, alias: str) -> None:
7991
self.repository.save(registry)
8092
else:
8193
self.repository.delete()
94+
95+
def delete_by_aliases(self, aliases: list[str]) -> None:
96+
"""Delete multiple models by their aliases.
97+
98+
Args:
99+
aliases: List of model aliases to delete
100+
101+
Raises:
102+
ValueError: If no models configured
103+
"""
104+
if not aliases:
105+
return
106+
107+
registry = self.repository.load()
108+
if not registry:
109+
raise ValueError("No models configured")
110+
111+
registry.model_configs = [m for m in registry.model_configs if m.alias not in aliases]
112+
113+
if registry.model_configs:
114+
self.repository.save(registry)
115+
else:
116+
self.repository.delete()

tests/cli/services/test_model_service.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,3 +35,32 @@ def test_update(
3535
def test_delete(stub_model_service: ModelService, stub_model_configs: list[ModelConfig]):
3636
stub_model_service.delete("test-alias-1")
3737
assert stub_model_service.list_all() == stub_model_configs[1:]
38+
39+
40+
def test_find_by_provider(stub_model_service: ModelService, stub_model_configs: list[ModelConfig]):
41+
# Both test models have provider="test-provider-1"
42+
models = stub_model_service.find_by_provider("test-provider-1")
43+
assert len(models) == 2
44+
assert models == stub_model_configs
45+
46+
# Non-existent provider should return empty list
47+
models = stub_model_service.find_by_provider("non-existent-provider")
48+
assert models == []
49+
50+
51+
def test_delete_by_aliases(stub_model_service: ModelService, stub_model_configs: list[ModelConfig]):
52+
# Delete both models
53+
stub_model_service.delete_by_aliases(["test-alias-1", "test-alias-2"])
54+
assert stub_model_service.list_all() == []
55+
56+
57+
def test_delete_by_aliases_partial(stub_model_service: ModelService, stub_model_configs: list[ModelConfig]):
58+
# Delete only one model
59+
stub_model_service.delete_by_aliases(["test-alias-1"])
60+
assert stub_model_service.list_all() == stub_model_configs[1:]
61+
62+
63+
def test_delete_by_aliases_empty_list(stub_model_service: ModelService, stub_model_configs: list[ModelConfig]):
64+
# Deleting empty list should do nothing
65+
stub_model_service.delete_by_aliases([])
66+
assert stub_model_service.list_all() == stub_model_configs

tests/cli/services/test_provider_service.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
11
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: Apache-2.0
33

4+
from pathlib import Path
5+
6+
from data_designer.cli.repositories.model_repository import ModelConfigRegistry, ModelRepository
7+
from data_designer.cli.services.model_service import ModelService
48
from data_designer.cli.services.provider_service import ProviderService
5-
from data_designer.config.models import ModelProvider
9+
from data_designer.config.models import ModelConfig, ModelProvider
610

711

812
def test_list_all(stub_provider_service: ProviderService, stub_model_providers: list[ModelProvider]):
@@ -41,3 +45,44 @@ def test_set_default(stub_provider_service: ProviderService, stub_model_provider
4145

4246
def test_get_default(stub_provider_service: ProviderService, stub_model_providers: list[ModelProvider]):
4347
assert stub_provider_service.get_default() == "test-provider-1"
48+
49+
50+
def test_delete_provider_with_associated_models(
51+
tmp_path: Path, stub_model_providers: list[ModelProvider], stub_model_configs: list[ModelConfig]
52+
):
53+
"""Test integration: deleting a provider and its associated models."""
54+
# Setup: Create provider and model services
55+
provider_service = ProviderService(
56+
ModelRepository(tmp_path) # This should be ProviderRepository
57+
)
58+
model_service = ModelService(ModelRepository(tmp_path))
59+
60+
# Save provider and models
61+
from data_designer.cli.repositories.provider_repository import ModelProviderRegistry, ProviderRepository
62+
63+
provider_repo = ProviderRepository(tmp_path)
64+
provider_repo.save(ModelProviderRegistry(providers=stub_model_providers, default="test-provider-1"))
65+
66+
model_repo = ModelRepository(tmp_path)
67+
model_repo.save(ModelConfigRegistry(model_configs=stub_model_configs))
68+
69+
# Verify initial state: 2 providers, 2 models (both associated with test-provider-1)
70+
provider_service = ProviderService(provider_repo)
71+
assert len(provider_service.list_all()) == 2
72+
assert len(model_service.list_all()) == 2
73+
74+
# Find models associated with test-provider-1
75+
associated_models = model_service.find_by_provider("test-provider-1")
76+
assert len(associated_models) == 2
77+
78+
# Delete the associated models
79+
model_aliases = [m.alias for m in associated_models]
80+
model_service.delete_by_aliases(model_aliases)
81+
82+
# Delete the provider
83+
provider_service.delete("test-provider-1")
84+
85+
# Verify final state: 1 provider, 0 models
86+
assert len(provider_service.list_all()) == 1
87+
assert len(model_service.list_all()) == 0
88+
assert provider_service.get_by_name("test-provider-1") is None

0 commit comments

Comments
 (0)