Skip to content

Commit 933bf3c

Browse files
committed
new tests for cli components
1 parent 35af1be commit 933bf3c

File tree

11 files changed

+272
-61
lines changed

11 files changed

+272
-61
lines changed

src/data_designer/cli/constants.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,25 @@
99
DEFAULT_CONFIG_DIR = Path.home() / ".data-designer"
1010
MODEL_CONFIGS_FILE_NAME = "model_configs.yaml"
1111
MODEL_PROVIDERS_FILE_NAME = "model_providers.yaml"
12+
13+
# Predefined provider templates
14+
PREDEFINED_PROVIDERS = {
15+
"nvidia": {
16+
"name": "nvidia",
17+
"endpoint": "https://integrate.api.nvidia.com/v1",
18+
"provider_type": "openai",
19+
"api_key": "NVIDIA_API_KEY",
20+
},
21+
"openai": {
22+
"name": "openai",
23+
"endpoint": "https://api.openai.com/v1",
24+
"provider_type": "openai",
25+
"api_key": "OPENAI_API_KEY",
26+
},
27+
"anthropic": {
28+
"name": "anthropic",
29+
"endpoint": "https://api.anthropic.com/v1",
30+
"provider_type": "openai",
31+
"api_key": "ANTHROPIC_API_KEY",
32+
},
33+
}

src/data_designer/cli/forms/provider_builder.py

Lines changed: 1 addition & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -3,34 +3,13 @@
33

44
from typing import Any
55

6+
from data_designer.cli.constants import PREDEFINED_PROVIDERS
67
from data_designer.cli.forms.builder import FormBuilder
78
from data_designer.cli.forms.field import SelectField, TextField
89
from data_designer.cli.forms.form import Form
910
from data_designer.cli.utils import validate_url
1011
from data_designer.engine.model_provider import ModelProvider
1112

12-
# Predefined provider templates
13-
PREDEFINED_PROVIDERS = {
14-
"nvidia": {
15-
"name": "nvidia",
16-
"endpoint": "https://integrate.api.nvidia.com/v1",
17-
"provider_type": "openai",
18-
"api_key": "NVIDIA_API_KEY",
19-
},
20-
"openai": {
21-
"name": "openai",
22-
"endpoint": "https://api.openai.com/v1",
23-
"provider_type": "openai",
24-
"api_key": "OPENAI_API_KEY",
25-
},
26-
"anthropic": {
27-
"name": "anthropic",
28-
"endpoint": "https://api.anthropic.com/v1",
29-
"provider_type": "openai",
30-
"api_key": "ANTHROPIC_API_KEY",
31-
},
32-
}
33-
3413

3514
class ProviderFormBuilder(FormBuilder[ModelProvider]):
3615
"""Builds interactive forms for provider configuration."""

src/data_designer/cli/repositories/model_repository.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,21 +3,18 @@
33

44
from pathlib import Path
55

6+
from pydantic import BaseModel
7+
68
from data_designer.cli.constants import MODEL_CONFIGS_FILE_NAME
79
from data_designer.cli.repositories.base import ConfigRepository
810
from data_designer.cli.utils import load_config_file, save_config_file
911
from data_designer.config.models import ModelConfig
1012

1113

12-
class ModelConfigRegistry:
14+
class ModelConfigRegistry(BaseModel):
1315
"""Registry for model configurations."""
1416

15-
def __init__(self, model_configs: list[ModelConfig]):
16-
self.model_configs = model_configs
17-
18-
def model_dump(self, **kwargs) -> dict:
19-
"""Dump to dictionary format."""
20-
return {"model_configs": [mc.model_dump(**kwargs) for mc in self.model_configs]}
17+
model_configs: list[ModelConfig]
2118

2219

2320
class ModelRepository(ConfigRepository[ModelConfigRegistry]):
@@ -35,10 +32,7 @@ def load(self) -> ModelConfigRegistry | None:
3532

3633
try:
3734
config_dict = load_config_file(self.config_file)
38-
if "model_configs" not in config_dict:
39-
return None
40-
model_configs = [ModelConfig.model_validate(mc) for mc in config_dict["model_configs"]]
41-
return ModelConfigRegistry(model_configs)
35+
return ModelConfigRegistry.model_validate(config_dict)
4236
except Exception:
4337
return None
4438

src/data_designer/cli/repositories/provider_repository.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,19 @@
33

44
from pathlib import Path
55

6+
from pydantic import BaseModel
7+
68
from data_designer.cli.constants import MODEL_PROVIDERS_FILE_NAME
79
from data_designer.cli.repositories.base import ConfigRepository
810
from data_designer.cli.utils import load_config_file, save_config_file
9-
from data_designer.engine.model_provider import ModelProviderRegistry
11+
from data_designer.config.models import ModelProvider
12+
13+
14+
class ModelProviderRegistry(BaseModel):
15+
"""Registry for model provider configurations."""
16+
17+
providers: list[ModelProvider]
18+
default: str | None = None
1019

1120

1221
class ProviderRepository(ConfigRepository[ModelProviderRegistry]):

src/data_designer/cli/services/model_service.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ def add(self, model: ModelConfig) -> None:
2929
"""
3030
registry = self.repository.load() or ModelConfigRegistry(model_configs=[])
3131

32-
# Business rule: No duplicate aliases
3332
if any(m.alias == model.alias for m in registry.model_configs):
3433
raise ValueError(f"Model alias '{model.alias}' already exists")
3534

@@ -54,12 +53,10 @@ def update(self, original_alias: str, updated_model: ModelConfig) -> None:
5453
if index is None:
5554
raise ValueError(f"Model '{original_alias}' not found")
5655

57-
# Business rule: Alias change must not conflict
5856
if updated_model.alias != original_alias:
5957
if any(m.alias == updated_model.alias for m in registry.model_configs):
6058
raise ValueError(f"Model alias '{updated_model.alias}' already exists")
6159

62-
# Update
6360
registry.model_configs[index] = updated_model
6461
self.repository.save(registry)
6562

@@ -76,10 +73,8 @@ def delete(self, alias: str) -> None:
7673
if not any(m.alias == alias for m in registry.model_configs):
7774
raise ValueError(f"Model '{alias}' not found")
7875

79-
# Remove model
8076
registry.model_configs = [m for m in registry.model_configs if m.alias != alias]
8177

82-
# Business rule: Delete file if no models left
8378
if registry.model_configs:
8479
self.repository.save(registry)
8580
else:

src/data_designer/cli/services/provider_service.py

Lines changed: 6 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: Apache-2.0
33

4-
from data_designer.cli.repositories.provider_repository import ProviderRepository
5-
from data_designer.engine.model_provider import ModelProvider, ModelProviderRegistry
4+
from data_designer.cli.repositories.provider_repository import ModelProviderRegistry, ProviderRepository
5+
from data_designer.config.models import ModelProvider
66

77

88
class ProviderService:
@@ -27,22 +27,12 @@ def add(self, provider: ModelProvider) -> None:
2727
Raises:
2828
ValueError: If provider name already exists
2929
"""
30-
registry = self.repository.load()
31-
32-
if registry:
33-
# Business rule: No duplicate names
34-
if any(p.name == provider.name for p in registry.providers):
35-
raise ValueError(f"Provider '{provider.name}' already exists")
36-
37-
registry.providers.append(provider)
38-
else:
39-
# Create new registry with first provider
40-
registry = ModelProviderRegistry(providers=[provider], default=provider.name)
30+
registry = self.repository.load() or ModelProviderRegistry(providers=[], default=None)
4131

42-
# Business rule: First provider is default (for existing registries adding first provider)
43-
if len(registry.providers) == 1 and registry.default is None:
44-
registry.default = provider.name
32+
if any(p.name == provider.name for p in registry.providers):
33+
raise ValueError(f"Provider '{provider.name}' already exists")
4534

35+
registry.providers.append(provider)
4636
self.repository.save(registry)
4737

4838
def update(self, original_name: str, updated_provider: ModelProvider) -> None:
@@ -63,15 +53,12 @@ def update(self, original_name: str, updated_provider: ModelProvider) -> None:
6353
if index is None:
6454
raise ValueError(f"Provider '{original_name}' not found")
6555

66-
# Business rule: Name change must not conflict
6756
if updated_provider.name != original_name:
6857
if any(p.name == updated_provider.name for p in registry.providers):
6958
raise ValueError(f"Provider name '{updated_provider.name}' already exists")
7059

71-
# Update
7260
registry.providers[index] = updated_provider
7361

74-
# Business rule: Update default if name changed
7562
if registry.default == original_name and updated_provider.name != original_name:
7663
registry.default = updated_provider.name
7764

@@ -90,14 +77,11 @@ def delete(self, name: str) -> None:
9077
if not any(p.name == name for p in registry.providers):
9178
raise ValueError(f"Provider '{name}' not found")
9279

93-
# Remove provider
9480
registry.providers = [p for p in registry.providers if p.name != name]
9581

96-
# Business rule: Update default if deleted
9782
if registry.default == name:
9883
registry.default = registry.providers[0].name if registry.providers else None
9984

100-
# Business rule: Delete file if no providers left
10185
if registry.providers:
10286
self.repository.save(registry)
10387
else:

tests/cli/conftest.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
from pathlib import Path
2+
3+
import pytest
4+
5+
from data_designer.cli.repositories.model_repository import ModelConfigRegistry, ModelRepository
6+
from data_designer.cli.repositories.provider_repository import ModelProviderRegistry, ProviderRepository
7+
from data_designer.cli.services.model_service import ModelService
8+
from data_designer.cli.services.provider_service import ProviderService
9+
from data_designer.config.models import InferenceParameters, ModelConfig, ModelProvider
10+
11+
12+
@pytest.fixture
13+
def stub_model_configs() -> list[ModelConfig]:
14+
return [
15+
ModelConfig(
16+
alias="test-alias-1",
17+
model="test-model-1",
18+
provider="test-provider-1",
19+
inference_parameters=InferenceParameters(
20+
temperature=0.7, top_p=0.9, max_tokens=2048, max_parallel_requests=4
21+
),
22+
),
23+
ModelConfig(
24+
alias="test-alias-2",
25+
model="test-model-2",
26+
provider="test-provider-1",
27+
inference_parameters=InferenceParameters(
28+
temperature=0.7, top_p=0.9, max_tokens=2048, max_parallel_requests=4
29+
),
30+
),
31+
]
32+
33+
34+
@pytest.fixture
35+
def stub_new_model_config() -> ModelConfig:
36+
return ModelConfig(
37+
alias="test-alias-3",
38+
model="test-model-3",
39+
provider="test-provider-1",
40+
inference_parameters=InferenceParameters(
41+
temperature=0.7,
42+
top_p=0.9,
43+
max_tokens=2048,
44+
max_parallel_requests=4,
45+
timeout=100,
46+
),
47+
)
48+
49+
50+
@pytest.fixture
51+
def stub_model_providers() -> list[ModelProvider]:
52+
return [
53+
ModelProvider(
54+
name="test-provider-1",
55+
endpoint="https://api.example.com/v1",
56+
provider_type="openai",
57+
api_key="test-api-key",
58+
),
59+
ModelProvider(
60+
name="test-provider-2",
61+
endpoint="https://api.example.com/v2",
62+
provider_type="openai",
63+
api_key="test-api-key-2",
64+
),
65+
]
66+
67+
68+
@pytest.fixture
69+
def stub_new_model_provider() -> ModelProvider:
70+
return ModelProvider(
71+
name="test-provider-3",
72+
endpoint="https://api.example.com/v1",
73+
provider_type="openai",
74+
api_key="test-api-key-1",
75+
)
76+
77+
78+
@pytest.fixture
79+
def stub_model_service(tmp_path: Path, stub_model_configs: list[ModelConfig]) -> ModelService:
80+
repository = ModelRepository(tmp_path)
81+
repository.save(ModelConfigRegistry(model_configs=stub_model_configs))
82+
return ModelService(repository)
83+
84+
85+
@pytest.fixture
86+
def stub_provider_service(tmp_path: Path, stub_model_providers: list[ModelProvider]) -> ProviderService:
87+
repository = ProviderRepository(tmp_path)
88+
repository.save(ModelProviderRegistry(providers=stub_model_providers, default=stub_model_providers[0].name))
89+
return ProviderService(repository)
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
from pathlib import Path
2+
3+
from data_designer.cli.constants import MODEL_CONFIGS_FILE_NAME
4+
from data_designer.cli.repositories.model_repository import ModelConfigRegistry, ModelRepository
5+
from data_designer.cli.utils import save_config_file
6+
from data_designer.config.models import ModelConfig
7+
8+
9+
def test_config_file(tmp_path: Path):
10+
repository = ModelRepository(tmp_path)
11+
assert repository.config_file == tmp_path / MODEL_CONFIGS_FILE_NAME
12+
13+
14+
def test_load_does_not_exist(tmp_path: Path):
15+
repository = ModelRepository(tmp_path)
16+
assert repository.load() is None
17+
18+
19+
def test_load_exists(tmp_path: Path, stub_model_configs: list[ModelConfig]):
20+
config_file = tmp_path / MODEL_CONFIGS_FILE_NAME
21+
save_config_file(config_file, ModelConfigRegistry(model_configs=stub_model_configs).model_dump())
22+
repository = ModelRepository(tmp_path)
23+
assert repository.load() is not None
24+
assert repository.load().model_configs == stub_model_configs
25+
26+
27+
def test_save(tmp_path: Path, stub_model_configs: list[ModelConfig]):
28+
repository = ModelRepository(tmp_path)
29+
repository.save(ModelConfigRegistry(model_configs=stub_model_configs))
30+
assert repository.load() is not None
31+
assert repository.load().model_configs == stub_model_configs
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
from pathlib import Path
2+
3+
from data_designer.cli.constants import MODEL_PROVIDERS_FILE_NAME
4+
from data_designer.cli.repositories.provider_repository import ModelProviderRegistry, ProviderRepository
5+
from data_designer.cli.utils import save_config_file
6+
from data_designer.config.models import ModelProvider
7+
8+
9+
def test_config_file(tmp_path: Path):
10+
repository = ProviderRepository(tmp_path)
11+
assert repository.config_file == tmp_path / MODEL_PROVIDERS_FILE_NAME
12+
13+
14+
def test_load_does_not_exist(tmp_path: Path):
15+
repository = ProviderRepository(tmp_path)
16+
assert repository.load() is None
17+
18+
19+
def test_load_exists(tmp_path: Path, stub_model_providers: list[ModelProvider]):
20+
config_file_path = tmp_path / MODEL_PROVIDERS_FILE_NAME
21+
save_config_file(
22+
config_file_path,
23+
ModelProviderRegistry(providers=stub_model_providers, default=stub_model_providers[0].name).model_dump(),
24+
)
25+
repository = ProviderRepository(tmp_path)
26+
assert repository.load() is not None
27+
assert repository.load().providers == stub_model_providers
28+
29+
30+
def test_save(tmp_path: Path, stub_model_providers: list[ModelProvider]):
31+
repository = ProviderRepository(tmp_path)
32+
repository.save(ModelProviderRegistry(providers=stub_model_providers, default=stub_model_providers[0].name))
33+
assert repository.load() is not None
34+
assert repository.load().providers == stub_model_providers

0 commit comments

Comments
 (0)