Skip to content

Commit 17846b2

Browse files
committed
Added tests for default model settings
1 parent ebd73b4 commit 17846b2

File tree

3 files changed

+233
-11
lines changed

3 files changed

+233
-11
lines changed

src/data_designer/config/default_model_settings.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,16 @@ def get_default_model_configs() -> list[ModelConfig]:
112112
return get_default_nvidia_model_configs() + get_default_openai_model_configs()
113113

114114

115+
def get_user_defined_default_providers() -> list[ModelProvider]:
116+
pre_defined_model_provider_path = get_model_provider_path()
117+
if pre_defined_model_provider_path.exists():
118+
config_dict = load_config_file(pre_defined_model_provider_path)
119+
if "providers" in config_dict:
120+
logger.info(f"♻️ Found user-defined default model providers in {str(pre_defined_model_provider_path)!r}")
121+
return [ModelProvider.model_validate(p) for p in config_dict["providers"]]
122+
return []
123+
124+
115125
def get_default_providers() -> list[ModelProvider]:
116126
user_defined_default_providers = get_user_defined_default_providers()
117127
if len(user_defined_default_providers) > 0:
@@ -130,16 +140,6 @@ def get_default_providers() -> list[ModelProvider]:
130140
]
131141

132142

133-
def get_user_defined_default_providers() -> list[ModelProvider]:
134-
pre_defined_model_provider_path = get_model_provider_path()
135-
if pre_defined_model_provider_path.exists():
136-
config_dict = load_config_file(pre_defined_model_provider_path)
137-
if "providers" in config_dict:
138-
logger.info(f"♻️ Found user-defined default model providers in {str(pre_defined_model_provider_path)!r}")
139-
return [ModelProvider.model_validate(p) for p in config_dict["providers"]]
140-
return []
141-
142-
143143
def get_nvidia_api_key() -> Optional[str]:
144144
return os.getenv(NVIDIA_API_KEY_ENV_VAR_NAME)
145145

Lines changed: 216 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,216 @@
1+
from pathlib import Path
2+
from unittest.mock import patch
3+
4+
from data_designer.config.default_model_settings import (
5+
get_default_model_configs,
6+
get_default_nvidia_model_configs,
7+
get_default_openai_model_configs,
8+
get_default_providers,
9+
get_user_defined_default_model_configs,
10+
get_user_defined_default_providers,
11+
)
12+
from data_designer.config.models import InferenceParameters, ModelConfig, ModelProvider
13+
14+
15+
@patch("data_designer.config.default_model_settings.get_nvidia_api_key")
16+
def test_get_default_nvidia_model_configs(mock_get_nvidia_api_key):
17+
mock_get_nvidia_api_key.return_value = "nv-some-api-key"
18+
nvidia_model_configs = get_default_nvidia_model_configs()
19+
assert len(nvidia_model_configs) == 3
20+
assert nvidia_model_configs[0].alias == "nvidia-text"
21+
assert nvidia_model_configs[0].model == "nvidia/nvidia-nemotron-nano-9b-v2"
22+
assert nvidia_model_configs[0].provider == "nvidia"
23+
assert nvidia_model_configs[0].inference_parameters is not None
24+
25+
assert nvidia_model_configs[1].alias == "nvidia-reasoning"
26+
assert nvidia_model_configs[1].model == "openai/gpt-oss-20b"
27+
assert nvidia_model_configs[1].provider == "nvidia"
28+
assert nvidia_model_configs[1].inference_parameters is not None
29+
30+
assert nvidia_model_configs[2].alias == "nvidia-vision"
31+
assert nvidia_model_configs[2].model == "nvidia/nemotron-nano-12b-v2-vl"
32+
assert nvidia_model_configs[2].provider == "nvidia"
33+
assert nvidia_model_configs[2].inference_parameters is not None
34+
35+
36+
@patch("data_designer.config.default_model_settings.get_nvidia_api_key")
37+
def test_get_default_nvidia_model_configs_no_api_key(mock_get_nvidia_api_key):
38+
mock_get_nvidia_api_key.return_value = None
39+
nvidia_model_configs = get_default_nvidia_model_configs()
40+
assert len(nvidia_model_configs) == 0
41+
42+
43+
@patch("data_designer.config.default_model_settings.get_openai_api_key")
44+
def test_get_default_openai_model_configs(mock_get_openai_api_key):
45+
mock_get_openai_api_key.return_value = "sk-some-api-key"
46+
openai_model_configs = get_default_openai_model_configs()
47+
assert len(openai_model_configs) == 3
48+
assert openai_model_configs[0].alias == "openai-text"
49+
assert openai_model_configs[0].model == "gpt-4.1"
50+
assert openai_model_configs[0].provider == "openai"
51+
assert openai_model_configs[0].inference_parameters is not None
52+
53+
assert openai_model_configs[1].alias == "openai-reasoning"
54+
assert openai_model_configs[1].model == "gpt-5"
55+
assert openai_model_configs[1].provider == "openai"
56+
assert openai_model_configs[1].inference_parameters is not None
57+
58+
assert openai_model_configs[2].alias == "openai-vision"
59+
assert openai_model_configs[2].model == "gpt-5"
60+
assert openai_model_configs[2].provider == "openai"
61+
assert openai_model_configs[2].inference_parameters is not None
62+
63+
64+
@patch("data_designer.config.default_model_settings.get_openai_api_key")
65+
def test_get_default_openai_model_configs_no_api_key(mock_get_openai_api_key):
66+
mock_get_openai_api_key.return_value = None
67+
openai_model_configs = get_default_openai_model_configs()
68+
assert len(openai_model_configs) == 0
69+
70+
71+
@patch("data_designer.config.default_model_settings.get_model_config_path")
72+
def test_get_user_defined_default_model_configs(mock_get_model_config_path, tmp_path: Path):
73+
model_configs_path = tmp_path / "model_configs.yaml"
74+
mock_get_model_config_path.return_value = model_configs_path
75+
(tmp_path / "model_configs.yaml").write_text(
76+
"""
77+
model_configs:
78+
- alias: test-model-1
79+
model: test/model-id
80+
provider: model-provider
81+
inference_parameters:
82+
temperature: 0.8
83+
top_p: 0.9
84+
- alias: test-model-2
85+
model: test/model-id-2
86+
provider: model-provider-2
87+
inference_parameters:
88+
temperature: 0.8
89+
top_p: 0.9
90+
"""
91+
)
92+
user_defined_model_configs = get_user_defined_default_model_configs()
93+
assert len(user_defined_model_configs) == 2
94+
assert user_defined_model_configs[0].alias == "test-model-1"
95+
assert user_defined_model_configs[0].model == "test/model-id"
96+
assert user_defined_model_configs[0].provider == "model-provider"
97+
assert user_defined_model_configs[0].inference_parameters is not None
98+
assert user_defined_model_configs[1].alias == "test-model-2"
99+
assert user_defined_model_configs[1].model == "test/model-id-2"
100+
assert user_defined_model_configs[1].provider == "model-provider-2"
101+
assert user_defined_model_configs[1].inference_parameters is not None
102+
103+
104+
@patch("data_designer.config.default_model_settings.get_model_config_path")
105+
def test_get_user_defined_default_model_configs_no_user_defined_configs(mock_get_model_config_path, tmp_path: Path):
106+
mock_get_model_config_path.return_value = tmp_path / "model_configs.yaml"
107+
assert len(get_user_defined_default_model_configs()) == 0
108+
109+
110+
@patch("data_designer.config.default_model_settings.get_default_nvidia_model_configs")
111+
@patch("data_designer.config.default_model_settings.get_default_openai_model_configs")
112+
@patch("data_designer.config.default_model_settings.get_user_defined_default_model_configs")
113+
def test_get_default_model_configs_no_user_defined_configs(
114+
mock_get_user_defined_default_model_configs,
115+
mock_get_default_openai_model_configs,
116+
mock_get_default_nvidia_model_configs,
117+
):
118+
mock_get_default_nvidia_model_configs.return_value = [
119+
ModelConfig(
120+
alias="test-model-1",
121+
model="test/model-id",
122+
provider="nvidia",
123+
inference_parameters=InferenceParameters(temperature=0.8, top_p=0.9),
124+
),
125+
]
126+
mock_get_default_openai_model_configs.return_value = [
127+
ModelConfig(
128+
alias="test-model-2",
129+
model="test/model-id-2",
130+
provider="openai",
131+
inference_parameters=InferenceParameters(temperature=0.8, top_p=0.9),
132+
),
133+
]
134+
mock_get_user_defined_default_model_configs.return_value = []
135+
model_configs = get_default_model_configs()
136+
assert len(model_configs) == 2
137+
assert model_configs[0].alias == "test-model-1"
138+
assert model_configs[0].provider == "nvidia"
139+
assert model_configs[1].alias == "test-model-2"
140+
assert model_configs[1].provider == "openai"
141+
142+
143+
@patch("data_designer.config.default_model_settings.get_user_defined_default_model_configs")
144+
def test_get_default_model_configs_with_user_defined_configs(mock_get_user_defined_default_model_configs):
145+
mock_get_user_defined_default_model_configs.return_value = [
146+
ModelConfig(
147+
alias="test-model-1",
148+
model="test/model-id-1",
149+
provider="model-provider",
150+
inference_parameters=InferenceParameters(temperature=0.8, top_p=0.9),
151+
),
152+
]
153+
model_configs = get_default_model_configs()
154+
assert len(model_configs) == 1
155+
assert model_configs[0].alias == "test-model-1"
156+
assert model_configs[0].provider == "model-provider"
157+
158+
159+
@patch("data_designer.config.default_model_settings.get_model_provider_path")
160+
def test_get_user_defined_default_providers(mock_get_model_provider_path, tmp_path: Path):
161+
model_providers_path = tmp_path / "model_providers.yaml"
162+
mock_get_model_provider_path.return_value = model_providers_path
163+
(tmp_path / "model_providers.yaml").write_text(
164+
"""
165+
providers:
166+
- name: test-provider-1
167+
endpoint: https://api.test-provider-1.com/v1
168+
api_key: test-api-key-1
169+
- name: test-provider-2
170+
endpoint: https://api.test-provider-2.com/v1
171+
api_key: test-api-key-2
172+
"""
173+
)
174+
user_defined_providers = get_user_defined_default_providers()
175+
assert len(user_defined_providers) == 2
176+
assert user_defined_providers[0].name == "test-provider-1"
177+
assert user_defined_providers[0].endpoint == "https://api.test-provider-1.com/v1"
178+
assert user_defined_providers[0].api_key == "test-api-key-1"
179+
assert user_defined_providers[1].name == "test-provider-2"
180+
assert user_defined_providers[1].endpoint == "https://api.test-provider-2.com/v1"
181+
assert user_defined_providers[1].api_key == "test-api-key-2"
182+
183+
184+
@patch("data_designer.config.default_model_settings.get_model_provider_path")
185+
def test_get_user_defined_default_providers_no_user_defined_providers(mock_get_model_provider_path, tmp_path: Path):
186+
mock_get_model_provider_path.return_value = tmp_path / "model_providers.yaml"
187+
assert len(get_user_defined_default_providers()) == 0
188+
189+
190+
@patch("data_designer.config.default_model_settings.get_user_defined_default_providers")
191+
def test_get_default_providers_no_user_defined_providers(mock_get_user_defined_default_providers):
192+
mock_get_user_defined_default_providers.return_value = []
193+
default_providers = get_default_providers()
194+
assert len(default_providers) == 2
195+
assert default_providers[0].name == "nvidia"
196+
assert default_providers[0].endpoint == "https://integrate.api.nvidia.com/v1"
197+
assert default_providers[0].api_key == "NVIDIA_API_KEY"
198+
assert default_providers[1].name == "openai"
199+
assert default_providers[1].endpoint == "https://api.openai.com/v1"
200+
assert default_providers[1].api_key == "OPENAI_API_KEY"
201+
202+
203+
@patch("data_designer.config.default_model_settings.get_user_defined_default_providers")
204+
def test_get_default_providers_with_user_defined_providers(mock_get_user_defined_default_providers):
205+
mock_get_user_defined_default_providers.return_value = [
206+
ModelProvider(
207+
name="test-provider-1",
208+
endpoint="https://api.test-provider-1.com/v1",
209+
api_key="test-api-key-1",
210+
),
211+
]
212+
default_providers = get_default_providers()
213+
assert len(default_providers) == 1
214+
assert default_providers[0].name == "test-provider-1"
215+
assert default_providers[0].endpoint == "https://api.test-provider-1.com/v1"
216+
assert default_providers[0].api_key == "test-api-key-1"

tests/config/utils/test_visualization.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from data_designer.config.config_builder import DataDesignerConfigBuilder
1010
from data_designer.config.utils.code_lang import CodeLang
11-
from data_designer.config.utils.visualization import display_sample_record
11+
from data_designer.config.utils.visualization import display_sample_record, mask_api_key
1212
from data_designer.config.validator_params import CodeValidatorParams
1313

1414

@@ -57,3 +57,9 @@ def test_display_sample_record_twice_no_errors(validation_output, config_builder
5757

5858
# If we reach this point without exceptions, the test passes
5959
assert True
60+
61+
62+
def test_mask_api_key():
63+
assert mask_api_key("sk-1234567890") == "s****************"
64+
assert mask_api_key("") == "****************"
65+
assert mask_api_key("nv-some-api-key") == "n****************"

0 commit comments

Comments
 (0)