Skip to content

Commit bddeb4f

Browse files
committed
Add support for multiple default model providers and config
1 parent 74bdf6f commit bddeb4f

File tree

5 files changed

+124
-32
lines changed

5 files changed

+124
-32
lines changed

src/data_designer/config/models.py

Lines changed: 81 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
MIN_TOP_P,
2222
NVIDIA_API_KEY_ENV_VAR_NAME,
2323
NVIDIA_PROVIDER_NAME,
24+
OPENAI_API_KEY_ENV_VAR_NAME,
25+
OPENAI_PROVIDER_NAME,
2426
)
2527
from .utils.io_helpers import smart_load_yaml
2628

@@ -233,38 +235,98 @@ def load_model_configs(model_configs: Union[list[ModelConfig], str, Path]) -> li
233235
return [ModelConfig.model_validate(mc) for mc in json_config["model_configs"]]
234236

235237

238+
def get_default_text_alias_inference_parameters() -> InferenceParameters:
    """Return the sampling parameters used by the default "text" model aliases.

    Returns:
        A newly constructed ``InferenceParameters`` with ``temperature=0.85``
        and ``top_p=0.95``.
    """
    text_defaults = InferenceParameters(top_p=0.95, temperature=0.85)
    return text_defaults
243+
244+
245+
def get_default_reasoning_alias_inference_parameters() -> InferenceParameters:
    """Return the sampling parameters used by the default "reasoning" model aliases.

    Returns:
        A newly constructed ``InferenceParameters`` with ``temperature=0.35``
        and ``top_p=0.95``.
    """
    reasoning_defaults = InferenceParameters(top_p=0.95, temperature=0.35)
    return reasoning_defaults
250+
251+
252+
def get_default_vision_alias_inference_parameters() -> InferenceParameters:
    """Return the sampling parameters used by the default "vision" model aliases.

    Returns:
        A newly constructed ``InferenceParameters`` with ``temperature=0.85``
        and ``top_p=0.95``.
    """
    vision_defaults = InferenceParameters(top_p=0.95, temperature=0.85)
    return vision_defaults
257+
258+
236259
def get_default_nvidia_model_configs() -> list[ModelConfig]:
    """Build the default model configs that use the NVIDIA provider.

    Returns:
        An empty list, after logging a warning, when ``get_nvidia_api_key()``
        returns a falsy value. Otherwise three configs, in order: the text,
        reasoning and vision aliases, each prefixed with the provider name.
    """
    nvidia_key = get_nvidia_api_key()
    if not nvidia_key:
        # Without a key the defaults are unusable, so none are offered.
        logger.warning(
            f"🔑 {NVIDIA_API_KEY_ENV_VAR_NAME!r} environment variable is not set. Please set it to your API key from 'https://build.nvidia.com' if you want to use the default NVIDIA model configs."
        )
        return []

    text_config = ModelConfig(
        alias=f"{NVIDIA_PROVIDER_NAME}-text",
        model="nvidia/nvidia-nemotron-nano-9b-v2",
        provider=NVIDIA_PROVIDER_NAME,
        inference_parameters=get_default_text_alias_inference_parameters(),
    )
    reasoning_config = ModelConfig(
        alias=f"{NVIDIA_PROVIDER_NAME}-reasoning",
        model="openai/gpt-oss-20b",
        provider=NVIDIA_PROVIDER_NAME,
        inference_parameters=get_default_reasoning_alias_inference_parameters(),
    )
    vision_config = ModelConfig(
        alias=f"{NVIDIA_PROVIDER_NAME}-vision",
        model="nvidia/nemotron-nano-12b-v2-vl",
        provider=NVIDIA_PROVIDER_NAME,
        inference_parameters=get_default_vision_alias_inference_parameters(),
    )
    return [text_config, reasoning_config, vision_config]
285+
286+
287+
def get_default_openai_model_configs() -> list[ModelConfig]:
    """Build the default model configs that use the OpenAI provider.

    Returns:
        An empty list, after logging a warning, when ``get_openai_api_key()``
        returns a falsy value. Otherwise three configs, in order: the text,
        reasoning and vision aliases, each prefixed with the provider name.
    """
    openai_key = get_openai_api_key()
    if not openai_key:
        # Without a key the defaults are unusable, so none are offered.
        logger.warning(
            f"🔑 {OPENAI_API_KEY_ENV_VAR_NAME!r} environment variable is not set. Please set it to your API key from 'https://platform.openai.com/api-keys' if you want to use the default OpenAI model configs."
        )
        return []

    text_config = ModelConfig(
        alias=f"{OPENAI_PROVIDER_NAME}-text",
        model="gpt-4.1",
        provider=OPENAI_PROVIDER_NAME,
        inference_parameters=get_default_text_alias_inference_parameters(),
    )
    # NOTE(review): some OpenAI reasoning models restrict sampling parameters
    # such as temperature/top_p -- confirm "gpt-5" accepts the values below.
    reasoning_config = ModelConfig(
        alias=f"{OPENAI_PROVIDER_NAME}-reasoning",
        model="gpt-5",
        provider=OPENAI_PROVIDER_NAME,
        inference_parameters=get_default_reasoning_alias_inference_parameters(),
    )
    vision_config = ModelConfig(
        alias=f"{OPENAI_PROVIDER_NAME}-vision",
        model="gpt-5",
        provider=OPENAI_PROVIDER_NAME,
        inference_parameters=get_default_vision_alias_inference_parameters(),
    )
    return [text_config, reasoning_config, vision_config]
313+
314+
315+
def get_default_model_configs() -> list[ModelConfig]:
    """Return the NVIDIA default configs followed by the OpenAI default configs.

    Returns:
        A new list concatenating the results of
        ``get_default_nvidia_model_configs()`` and
        ``get_default_openai_model_configs()``, in that order.
    """
    nvidia_configs = get_default_nvidia_model_configs()
    openai_configs = get_default_openai_model_configs()
    return [*nvidia_configs, *openai_configs]
317+
318+
319+
def get_default_providers() -> list[ModelProvider]:
    """Return the built-in model providers: NVIDIA first, then OpenAI.

    Each provider's ``api_key`` is set to the environment-variable *name*
    constant, not to a secret value.
    NOTE(review): presumably the name is resolved to the real key elsewhere --
    confirm against the code that consumes ``ModelProvider.api_key``.
    """
    nvidia_provider = ModelProvider(
        name=NVIDIA_PROVIDER_NAME,
        endpoint="https://integrate.api.nvidia.com/v1",
        api_key=NVIDIA_API_KEY_ENV_VAR_NAME,
    )
    openai_provider = ModelProvider(
        name=OPENAI_PROVIDER_NAME,
        endpoint="https://api.openai.com/v1",
        api_key=OPENAI_API_KEY_ENV_VAR_NAME,
    )
    return [nvidia_provider, openai_provider]
270332

@@ -273,9 +335,5 @@ def get_nvidia_api_key() -> Optional[str]:
273335
return os.getenv(NVIDIA_API_KEY_ENV_VAR_NAME)
274336

275337

276-
def get_default_nvidia_model_provider() -> ModelProvider:
277-
return ModelProvider(
278-
name=NVIDIA_PROVIDER_NAME,
279-
endpoint="https://integrate.api.nvidia.com/v1",
280-
api_key=NVIDIA_API_KEY_ENV_VAR_NAME,
281-
)
338+
def get_openai_api_key() -> Optional[str]:
    """Look up the OpenAI API key in the process environment.

    Returns:
        The value of the variable named by ``OPENAI_API_KEY_ENV_VAR_NAME``,
        or ``None`` when that variable is not present.
    """
    env_value = os.environ.get(OPENAI_API_KEY_ENV_VAR_NAME)
    return env_value

src/data_designer/config/utils/constants.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,3 +260,6 @@ class NordColor(Enum):
260260

261261
NVIDIA_PROVIDER_NAME = "nvidia"
262262
NVIDIA_API_KEY_ENV_VAR_NAME = "NVIDIA_API_KEY"
263+
264+
OPENAI_PROVIDER_NAME = "openai"
265+
OPENAI_API_KEY_ENV_VAR_NAME = "OPENAI_API_KEY"

src/data_designer/config/utils/visualization.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,10 @@
2222

2323
from ..base import ConfigBase
2424
from ..columns import DataDesignerColumnType
25-
from ..models import ModelConfig, ModelProvider
25+
from ..models import ModelConfig, ModelProvider, get_nvidia_api_key, get_openai_api_key
2626
from ..sampler_params import SamplerType
2727
from .code_lang import code_lang_to_syntax_lexer
28+
from .constants import NVIDIA_API_KEY_ENV_VAR_NAME, OPENAI_API_KEY_ENV_VAR_NAME
2829
from .errors import DatasetSampleDisplayError
2930

3031
if TYPE_CHECKING:
@@ -274,7 +275,15 @@ def display_model_configs_table(model_configs: list[ModelConfig]) -> None:
274275
str(model_config.inference_parameters.temperature),
275276
str(model_config.inference_parameters.top_p),
276277
)
277-
group = Group(Rule(title="Model Configs"), table_model_configs)
278+
group_args: list = [Rule(title="Model Configs"), table_model_configs]
279+
if len(model_configs) == 0:
280+
subtitle = Text(
281+
"‼️ No model configs found. Please provide at least one model config to the config builder",
282+
style="dim",
283+
justify="center",
284+
)
285+
group_args.insert(1, subtitle)
286+
group = Group(*group_args)
278287
console.print(group)
279288

280289

@@ -284,7 +293,18 @@ def display_model_providers_table(model_providers: list[ModelProvider]) -> None:
284293
table_model_providers.add_column("Endpoint")
285294
table_model_providers.add_column("API Key")
286295
for model_provider in model_providers:
287-
table_model_providers.add_row(model_provider.name, model_provider.endpoint, model_provider.api_key)
296+
api_key = model_provider.api_key
297+
if model_provider.api_key == OPENAI_API_KEY_ENV_VAR_NAME:
298+
if get_openai_api_key() is not None:
299+
api_key = get_openai_api_key()[:1] + "********"
300+
else:
301+
api_key = f"* {OPENAI_API_KEY_ENV_VAR_NAME!r} not set in environment variables * "
302+
elif model_provider.api_key == NVIDIA_API_KEY_ENV_VAR_NAME:
303+
if get_nvidia_api_key() is not None:
304+
api_key = get_nvidia_api_key()[:1] + "********"
305+
else:
306+
api_key = f"* {NVIDIA_API_KEY_ENV_VAR_NAME!r} not set in environment variables *"
307+
table_model_providers.add_row(model_provider.name, model_provider.endpoint, api_key)
288308
group = Group(Rule(title="Model Providers"), table_model_providers)
289309
console.print(group)
290310

src/data_designer/engine/models/registry.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,9 @@ def get_model_provider(self, *, model_alias: str) -> ModelProvider:
7171
def run_health_check(self) -> None:
7272
logger.info("🩺 Running health checks for models...")
7373
for model in self._models.values():
74-
logger.info(f" |-- 👀 Checking '{model.model_name}'...")
74+
logger.info(
75+
f" |-- 👀 Checking {model.model_name!r} in provider named {model.model_provider_name!r} for model alias {model.model_alias!r}..."
76+
)
7577
try:
7678
model.generate(
7779
prompt="Hello!",

src/data_designer/interface/data_designer.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,16 @@
1212
from data_designer.config.models import (
1313
ModelConfig,
1414
ModelProvider,
15-
get_default_nvidia_model_configs,
16-
get_default_nvidia_model_provider,
15+
get_default_model_configs,
16+
get_default_providers,
1717
)
1818
from data_designer.config.preview_results import PreviewResults
1919
from data_designer.config.seed import LocalSeedDatasetReference
20-
from data_designer.config.utils.constants import DEFAULT_NUM_RECORDS
20+
from data_designer.config.utils.constants import (
21+
DEFAULT_NUM_RECORDS,
22+
NVIDIA_API_KEY_ENV_VAR_NAME,
23+
OPENAI_API_KEY_ENV_VAR_NAME,
24+
)
2125
from data_designer.config.utils.info import InterfaceInfo
2226
from data_designer.config.utils.io_helpers import write_seed_dataset
2327
from data_designer.engine.analysis.dataset_profiler import (
@@ -227,10 +231,15 @@ def preview(
227231
)
228232

229233
def get_default_model_configs(self) -> list[ModelConfig]:
    """Return the default model configs, warning when there are none.

    Returns:
        The list produced by the module-level ``get_default_model_configs``
        helper. It is returned unchanged even when it is empty.
    """
    # Inside a method body this name resolves to the helper imported from
    # data_designer.config.models, not to this method.
    default_configs = get_default_model_configs()
    if default_configs:
        return default_configs

    logger.warning(
        f"‼️ Neither {NVIDIA_API_KEY_ENV_VAR_NAME!r} nor {OPENAI_API_KEY_ENV_VAR_NAME!r} environment variables are set. Please set at least one of them if you want to use the default model configs."
    )
    return default_configs
231240

232241
def get_default_model_providers(self) -> list[ModelProvider]:
    """Return the built-in model providers from ``get_default_providers()``."""
    default_providers = get_default_providers()
    return default_providers
234243

235244
def set_buffer_size(self, buffer_size: int) -> None:
236245
"""Set the buffer size for dataset generation.

0 commit comments

Comments
 (0)