Skip to content

Commit ebd73b4

Browse files
committed
pull user-defined model configs and providers if available
1 parent f719998 commit ebd73b4

File tree

4 files changed

+159
-114
lines changed

4 files changed

+159
-114
lines changed
Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
import logging
5+
import os
6+
from typing import Optional
7+
8+
from data_designer.cli.utils import get_model_config_path, get_model_provider_path, load_config_file
9+
10+
from .models import InferenceParameters, ModelConfig, ModelProvider
11+
from .utils.constants import (
12+
NVIDIA_API_KEY_ENV_VAR_NAME,
13+
NVIDIA_PROVIDER_NAME,
14+
OPENAI_API_KEY_ENV_VAR_NAME,
15+
OPENAI_PROVIDER_NAME,
16+
)
17+
18+
logger = logging.getLogger(__name__)
19+
20+
21+
def get_default_text_alias_inference_parameters() -> InferenceParameters:
    """Build the default sampling parameters for the ``*-text`` model aliases.

    Returns:
        A new ``InferenceParameters`` with ``temperature=0.85`` and ``top_p=0.95``.
    """
    sampling_kwargs = {"temperature": 0.85, "top_p": 0.95}
    return InferenceParameters(**sampling_kwargs)
26+
27+
28+
def get_default_reasoning_alias_inference_parameters() -> InferenceParameters:
    """Build the default sampling parameters for the ``*-reasoning`` model aliases.

    Returns:
        A new ``InferenceParameters`` with ``temperature=0.35`` and ``top_p=0.95``.
    """
    sampling_kwargs = {"temperature": 0.35, "top_p": 0.95}
    return InferenceParameters(**sampling_kwargs)
33+
34+
35+
def get_default_vision_alias_inference_parameters() -> InferenceParameters:
    """Build the default sampling parameters for the ``*-vision`` model aliases.

    Returns:
        A new ``InferenceParameters`` with ``temperature=0.85`` and ``top_p=0.95``.
    """
    sampling_kwargs = {"temperature": 0.85, "top_p": 0.95}
    return InferenceParameters(**sampling_kwargs)
40+
41+
42+
def get_default_nvidia_model_configs() -> list[ModelConfig]:
    """Return the built-in NVIDIA model configs (text, reasoning, vision).

    Returns:
        Three ``ModelConfig`` entries bound to ``NVIDIA_PROVIDER_NAME``, or an
        empty list (after logging a warning) when ``get_nvidia_api_key()``
        yields nothing truthy.
    """
    api_key = get_nvidia_api_key()
    if not api_key:
        logger.warning(
            f"🔑 {NVIDIA_API_KEY_ENV_VAR_NAME!r} environment variable is not set. Please set it to your API key from 'https://build.nvidia.com' if you want to use the default NVIDIA model configs."
        )
        return []

    # Each row: (alias, model identifier, factory for that alias' inference parameters).
    default_specs = (
        (
            f"{NVIDIA_PROVIDER_NAME}-text",
            "nvidia/nvidia-nemotron-nano-9b-v2",
            get_default_text_alias_inference_parameters,
        ),
        (
            f"{NVIDIA_PROVIDER_NAME}-reasoning",
            "openai/gpt-oss-20b",
            get_default_reasoning_alias_inference_parameters,
        ),
        (
            f"{NVIDIA_PROVIDER_NAME}-vision",
            "nvidia/nemotron-nano-12b-v2-vl",
            get_default_vision_alias_inference_parameters,
        ),
    )
    return [
        ModelConfig(
            alias=alias,
            model=model_id,
            provider=NVIDIA_PROVIDER_NAME,
            inference_parameters=make_parameters(),
        )
        for alias, model_id, make_parameters in default_specs
    ]
68+
69+
70+
def get_default_openai_model_configs() -> list[ModelConfig]:
    """Return the built-in OpenAI model configs (text, reasoning, vision).

    Returns:
        Three ``ModelConfig`` entries bound to ``OPENAI_PROVIDER_NAME``, or an
        empty list (after logging a warning) when ``get_openai_api_key()``
        yields nothing truthy.
    """
    api_key = get_openai_api_key()
    if not api_key:
        logger.warning(
            f"🔑 {OPENAI_API_KEY_ENV_VAR_NAME!r} environment variable is not set. Please set it to your API key from 'https://platform.openai.com/api-keys' if you want to use the default OpenAI model configs."
        )
        return []

    # Each row: (alias, model identifier, factory for that alias' inference parameters).
    default_specs = (
        (
            f"{OPENAI_PROVIDER_NAME}-text",
            "gpt-4.1",
            get_default_text_alias_inference_parameters,
        ),
        (
            f"{OPENAI_PROVIDER_NAME}-reasoning",
            "gpt-5",
            get_default_reasoning_alias_inference_parameters,
        ),
        (
            f"{OPENAI_PROVIDER_NAME}-vision",
            "gpt-5",
            get_default_vision_alias_inference_parameters,
        ),
    )
    return [
        ModelConfig(
            alias=alias,
            model=model_id,
            provider=OPENAI_PROVIDER_NAME,
            inference_parameters=make_parameters(),
        )
        for alias, model_id, make_parameters in default_specs
    ]
96+
97+
98+
def get_user_defined_default_model_configs() -> list[ModelConfig]:
    """Load model configs from the user's model-config file, if one exists.

    Returns:
        The validated entries found under the ``"model_configs"`` key of the
        file at ``get_model_config_path()``. An empty list is returned when
        the file does not exist or does not contain that key.
    """
    config_path = get_model_config_path()
    if not config_path.exists():
        return []

    loaded_config = load_config_file(config_path)
    if "model_configs" not in loaded_config:
        return []

    logger.info(f"♻️ Found user-defined default model configs in {str(config_path)!r}")
    return [ModelConfig.model_validate(entry) for entry in loaded_config["model_configs"]]
106+
107+
108+
def get_default_model_configs() -> list[ModelConfig]:
    """Return the default model configs, preferring user-defined ones.

    Returns:
        The result of ``get_user_defined_default_model_configs()`` when it is
        non-empty; otherwise the built-in NVIDIA configs followed by the
        built-in OpenAI configs.
    """
    user_configs = get_user_defined_default_model_configs()
    if user_configs:
        return user_configs

    nvidia_configs = get_default_nvidia_model_configs()
    openai_configs = get_default_openai_model_configs()
    return nvidia_configs + openai_configs
113+
114+
115+
def get_default_providers() -> list[ModelProvider]:
    """Return the default model providers, preferring user-defined ones.

    Returns:
        The result of ``get_user_defined_default_providers()`` when it is
        non-empty; otherwise the built-in NVIDIA and OpenAI providers, in
        that order.
    """
    user_providers = get_user_defined_default_providers()
    if user_providers:
        return user_providers

    # The built-in providers pass the environment-variable *name* constant as api_key.
    nvidia_provider = ModelProvider(
        name=NVIDIA_PROVIDER_NAME,
        endpoint="https://integrate.api.nvidia.com/v1",
        api_key=NVIDIA_API_KEY_ENV_VAR_NAME,
    )
    openai_provider = ModelProvider(
        name=OPENAI_PROVIDER_NAME,
        endpoint="https://api.openai.com/v1",
        api_key=OPENAI_API_KEY_ENV_VAR_NAME,
    )
    return [nvidia_provider, openai_provider]
131+
132+
133+
def get_user_defined_default_providers() -> list[ModelProvider]:
    """Load model providers from the user's provider file, if one exists.

    Returns:
        The validated entries found under the ``"providers"`` key of the file
        at ``get_model_provider_path()``. An empty list is returned when the
        file does not exist or does not contain that key.
    """
    provider_path = get_model_provider_path()
    if not provider_path.exists():
        return []

    loaded_config = load_config_file(provider_path)
    if "providers" not in loaded_config:
        return []

    logger.info(f"♻️ Found user-defined default model providers in {str(provider_path)!r}")
    return [ModelProvider.model_validate(entry) for entry in loaded_config["providers"]]
141+
142+
143+
def get_nvidia_api_key() -> Optional[str]:
    """Read the NVIDIA API key from the process environment.

    Returns:
        The value of the environment variable named by
        ``NVIDIA_API_KEY_ENV_VAR_NAME``, or ``None`` when it is not set.
    """
    return os.environ.get(NVIDIA_API_KEY_ENV_VAR_NAME)
145+
146+
147+
def get_openai_api_key() -> Optional[str]:
    """Read the OpenAI API key from the process environment.

    Returns:
        The value of the environment variable named by
        ``OPENAI_API_KEY_ENV_VAR_NAME``, or ``None`` when it is not set.
    """
    return os.environ.get(OPENAI_API_KEY_ENV_VAR_NAME)

src/data_designer/config/models.py

Lines changed: 0 additions & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from abc import ABC, abstractmethod
55
from enum import Enum
66
import logging
7-
import os
87
from pathlib import Path
98
from typing import Any, Generic, List, Optional, TypeVar, Union
109

@@ -19,10 +18,6 @@
1918
MAX_TOP_P,
2019
MIN_TEMPERATURE,
2120
MIN_TOP_P,
22-
NVIDIA_API_KEY_ENV_VAR_NAME,
23-
NVIDIA_PROVIDER_NAME,
24-
OPENAI_API_KEY_ENV_VAR_NAME,
25-
OPENAI_PROVIDER_NAME,
2621
)
2722
from .utils.io_helpers import smart_load_yaml
2823

@@ -233,107 +228,3 @@ def load_model_configs(model_configs: Union[list[ModelConfig], str, Path]) -> li
233228
"The list of model configs must be provided under model_configs in the configuration file."
234229
)
235230
return [ModelConfig.model_validate(mc) for mc in json_config["model_configs"]]
236-
237-
238-
def get_default_text_alias_inference_parameters() -> InferenceParameters:
239-
return InferenceParameters(
240-
temperature=0.85,
241-
top_p=0.95,
242-
)
243-
244-
245-
def get_default_reasoning_alias_inference_parameters() -> InferenceParameters:
246-
return InferenceParameters(
247-
temperature=0.35,
248-
top_p=0.95,
249-
)
250-
251-
252-
def get_default_vision_alias_inference_parameters() -> InferenceParameters:
253-
return InferenceParameters(
254-
temperature=0.85,
255-
top_p=0.95,
256-
)
257-
258-
259-
def get_default_nvidia_model_configs() -> list[ModelConfig]:
260-
if not get_nvidia_api_key():
261-
logger.warning(
262-
f"🔑 {NVIDIA_API_KEY_ENV_VAR_NAME!r} environment variable is not set. Please set it to your API key from 'https://build.nvidia.com' if you want to use the default NVIDIA model configs."
263-
)
264-
return []
265-
return [
266-
ModelConfig(
267-
alias=f"{NVIDIA_PROVIDER_NAME}-text",
268-
model="nvidia/nvidia-nemotron-nano-9b-v2",
269-
provider=NVIDIA_PROVIDER_NAME,
270-
inference_parameters=get_default_text_alias_inference_parameters(),
271-
),
272-
ModelConfig(
273-
alias=f"{NVIDIA_PROVIDER_NAME}-reasoning",
274-
model="openai/gpt-oss-20b",
275-
provider=NVIDIA_PROVIDER_NAME,
276-
inference_parameters=get_default_reasoning_alias_inference_parameters(),
277-
),
278-
ModelConfig(
279-
alias=f"{NVIDIA_PROVIDER_NAME}-vision",
280-
model="nvidia/nemotron-nano-12b-v2-vl",
281-
provider=NVIDIA_PROVIDER_NAME,
282-
inference_parameters=get_default_vision_alias_inference_parameters(),
283-
),
284-
]
285-
286-
287-
def get_default_openai_model_configs() -> list[ModelConfig]:
288-
if not get_openai_api_key():
289-
logger.warning(
290-
f"🔑 {OPENAI_API_KEY_ENV_VAR_NAME!r} environment variable is not set. Please set it to your API key from 'https://platform.openai.com/api-keys' if you want to use the default OpenAI model configs."
291-
)
292-
return []
293-
return [
294-
ModelConfig(
295-
alias=f"{OPENAI_PROVIDER_NAME}-text",
296-
model="gpt-4.1",
297-
provider=OPENAI_PROVIDER_NAME,
298-
inference_parameters=get_default_text_alias_inference_parameters(),
299-
),
300-
ModelConfig(
301-
alias=f"{OPENAI_PROVIDER_NAME}-reasoning",
302-
model="gpt-5",
303-
provider=OPENAI_PROVIDER_NAME,
304-
inference_parameters=get_default_reasoning_alias_inference_parameters(),
305-
),
306-
ModelConfig(
307-
alias=f"{OPENAI_PROVIDER_NAME}-vision",
308-
model="gpt-5",
309-
provider=OPENAI_PROVIDER_NAME,
310-
inference_parameters=get_default_vision_alias_inference_parameters(),
311-
),
312-
]
313-
314-
315-
def get_default_model_configs() -> list[ModelConfig]:
316-
return get_default_nvidia_model_configs() + get_default_openai_model_configs()
317-
318-
319-
def get_default_providers() -> list[ModelProvider]:
320-
return [
321-
ModelProvider(
322-
name=NVIDIA_PROVIDER_NAME,
323-
endpoint="https://integrate.api.nvidia.com/v1",
324-
api_key=NVIDIA_API_KEY_ENV_VAR_NAME,
325-
),
326-
ModelProvider(
327-
name=OPENAI_PROVIDER_NAME,
328-
endpoint="https://api.openai.com/v1",
329-
api_key=OPENAI_API_KEY_ENV_VAR_NAME,
330-
),
331-
]
332-
333-
334-
def get_nvidia_api_key() -> Optional[str]:
335-
return os.getenv(NVIDIA_API_KEY_ENV_VAR_NAME)
336-
337-
338-
def get_openai_api_key() -> Optional[str]:
339-
return os.getenv(OPENAI_API_KEY_ENV_VAR_NAME)

src/data_designer/config/utils/visualization.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@
2222

2323
from ..base import ConfigBase
2424
from ..columns import DataDesignerColumnType
25-
from ..models import ModelConfig, ModelProvider, get_nvidia_api_key, get_openai_api_key
25+
from ..default_model_settings import get_nvidia_api_key, get_openai_api_key
26+
from ..models import ModelConfig, ModelProvider
2627
from ..sampler_params import SamplerType
2728
from .code_lang import code_lang_to_syntax_lexer
2829
from .constants import NVIDIA_API_KEY_ENV_VAR_NAME, OPENAI_API_KEY_ENV_VAR_NAME
@@ -296,19 +297,25 @@ def display_model_providers_table(model_providers: list[ModelProvider]) -> None:
296297
api_key = model_provider.api_key
297298
if model_provider.api_key == OPENAI_API_KEY_ENV_VAR_NAME:
298299
if get_openai_api_key() is not None:
299-
api_key = get_openai_api_key()[:1] + "********"
300+
api_key = mask_api_key(get_openai_api_key())
300301
else:
301302
api_key = f"* {OPENAI_API_KEY_ENV_VAR_NAME!r} not set in environment variables * "
302303
elif model_provider.api_key == NVIDIA_API_KEY_ENV_VAR_NAME:
303304
if get_nvidia_api_key() is not None:
304-
api_key = get_nvidia_api_key()[:1] + "********"
305+
api_key = mask_api_key(get_nvidia_api_key())
305306
else:
306307
api_key = f"* {NVIDIA_API_KEY_ENV_VAR_NAME!r} not set in environment variables *"
308+
else:
309+
api_key = mask_api_key(model_provider.api_key)
307310
table_model_providers.add_row(model_provider.name, model_provider.endpoint, api_key)
308311
group = Group(Rule(title="Model Providers"), table_model_providers)
309312
console.print(group)
310313

311314

315+
def mask_api_key(api_key: str) -> str:
    """Return a display-safe form of *api_key*.

    Only the first character (if any) is kept; it is followed by a fixed run
    of sixteen asterisks, so the output does not reveal the key's length.

    Args:
        api_key: The secret to mask.

    Returns:
        The first character of ``api_key`` followed by ``"****************"``.
    """
    visible_prefix = api_key[:1]
    return f"{visible_prefix}****************"
317+
318+
312319
def convert_to_row_element(elem):
313320
try:
314321
elem = Pretty(json.loads(elem))

src/data_designer/interface/data_designer.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,11 @@
88

99
from data_designer.config.analysis.dataset_profiler import DatasetProfilerResults
1010
from data_designer.config.config_builder import DataDesignerConfigBuilder
11+
from data_designer.config.default_model_settings import get_default_model_configs, get_default_providers
1112
from data_designer.config.interface import DataDesignerInterface
1213
from data_designer.config.models import (
1314
ModelConfig,
1415
ModelProvider,
15-
get_default_model_configs,
16-
get_default_providers,
1716
)
1817
from data_designer.config.preview_results import PreviewResults
1918
from data_designer.config.seed import LocalSeedDatasetReference

0 commit comments

Comments
 (0)