33
44from abc import ABC , abstractmethod
55from enum import Enum
6+ import logging
7+ import os
68from pathlib import Path
79from typing import Any , Generic , List , Optional , TypeVar , Union
810
1214
1315from .base import ConfigBase
1416from .errors import InvalidConfigError
15- from .utils .constants import MAX_TEMPERATURE , MAX_TOP_P , MIN_TEMPERATURE , MIN_TOP_P
17+ from .utils .constants import (
18+ MAX_TEMPERATURE ,
19+ MAX_TOP_P ,
20+ MIN_TEMPERATURE ,
21+ MIN_TOP_P ,
22+ NVIDIA_API_KEY_ENV_VAR_NAME ,
23+ NVIDIA_PROVIDER_NAME ,
24+ OPENAI_API_KEY_ENV_VAR_NAME ,
25+ OPENAI_PROVIDER_NAME ,
26+ )
1627from .utils .io_helpers import smart_load_yaml
1728
29+ logger = logging .getLogger (__name__ )
30+
1831
1932class Modality (str , Enum ):
2033 IMAGE = "image"
@@ -204,9 +217,14 @@ class ModelConfig(ConfigBase):
204217 provider : Optional [str ] = None
205218
206219
207- def load_model_configs (model_configs : Union [list [ModelConfig ], str , Path , None ]) -> list [ModelConfig ]:
208- if model_configs is None :
209- return []
class ModelProvider(ConfigBase):
    """Connection settings for a single OpenAI-compatible model provider."""

    # Unique provider identifier (e.g. the NVIDIA/OpenAI provider names);
    # referenced by ModelConfig.provider to select an endpoint.
    name: str
    # Base URL of the provider's REST API.
    endpoint: str
    # Client/protocol family; only OpenAI-style endpoints appear in this file.
    provider_type: str = "openai"
    # NOTE(review): get_default_providers stores the *environment-variable
    # name* here (e.g. NVIDIA_API_KEY_ENV_VAR_NAME), not the secret itself —
    # confirm downstream code resolves it via the environment.
    api_key: str | None = None
226+
227+ def load_model_configs (model_configs : Union [list [ModelConfig ], str , Path ]) -> list [ModelConfig ]:
210228 if isinstance (model_configs , list ) and all (isinstance (mc , ModelConfig ) for mc in model_configs ):
211229 return model_configs
212230 json_config = smart_load_yaml (model_configs )
@@ -215,3 +233,107 @@ def load_model_configs(model_configs: Union[list[ModelConfig], str, Path, None])
215233 "The list of model configs must be provided under model_configs in the configuration file."
216234 )
217235 return [ModelConfig .model_validate (mc ) for mc in json_config ["model_configs" ]]
236+
237+
def get_default_text_alias_inference_parameters() -> InferenceParameters:
    """Return the default sampling settings for the generic text alias."""
    defaults = InferenceParameters(temperature=0.85, top_p=0.95)
    return defaults
243+
244+
def get_default_reasoning_alias_inference_parameters() -> InferenceParameters:
    """Return the default sampling settings for the reasoning alias.

    Uses a lower temperature than the text/vision aliases for more
    deterministic reasoning output.
    """
    defaults = InferenceParameters(temperature=0.35, top_p=0.95)
    return defaults
250+
251+
def get_default_vision_alias_inference_parameters() -> InferenceParameters:
    """Return the default sampling settings for the vision alias."""
    defaults = InferenceParameters(temperature=0.85, top_p=0.95)
    return defaults
257+
258+
def get_default_nvidia_model_configs() -> list[ModelConfig]:
    """Return the built-in NVIDIA model aliases.

    Returns an empty list (after logging a warning) when the NVIDIA API key
    environment variable is not set, so the defaults never advertise models
    that cannot be reached.
    """
    if not get_nvidia_api_key():
        logger.warning(
            f"🔑 {NVIDIA_API_KEY_ENV_VAR_NAME!r} environment variable is not set. Please set it to your API key from 'https://build.nvidia.com' if you want to use the default NVIDIA model configs."
        )
        return []
    # (alias suffix, model id, inference-parameter factory) per default alias.
    alias_table = [
        ("text", "nvidia/nvidia-nemotron-nano-9b-v2", get_default_text_alias_inference_parameters),
        ("reasoning", "openai/gpt-oss-20b", get_default_reasoning_alias_inference_parameters),
        ("vision", "nvidia/nemotron-nano-12b-v2-vl", get_default_vision_alias_inference_parameters),
    ]
    return [
        ModelConfig(
            alias=f"{NVIDIA_PROVIDER_NAME}-{suffix}",
            model=model_id,
            provider=NVIDIA_PROVIDER_NAME,
            inference_parameters=make_params(),
        )
        for suffix, model_id, make_params in alias_table
    ]
285+
286+
def get_default_openai_model_configs() -> list[ModelConfig]:
    """Return the built-in OpenAI model aliases.

    Returns an empty list (after logging a warning) when the OpenAI API key
    environment variable is not set, so the defaults never advertise models
    that cannot be reached.
    """
    if not get_openai_api_key():
        logger.warning(
            f"🔑 {OPENAI_API_KEY_ENV_VAR_NAME!r} environment variable is not set. Please set it to your API key from 'https://platform.openai.com/api-keys' if you want to use the default OpenAI model configs."
        )
        return []
    # (alias suffix, model id, inference-parameter factory) per default alias.
    alias_table = [
        ("text", "gpt-4.1", get_default_text_alias_inference_parameters),
        ("reasoning", "gpt-5", get_default_reasoning_alias_inference_parameters),
        ("vision", "gpt-5", get_default_vision_alias_inference_parameters),
    ]
    return [
        ModelConfig(
            alias=f"{OPENAI_PROVIDER_NAME}-{suffix}",
            model=model_id,
            provider=OPENAI_PROVIDER_NAME,
            inference_parameters=make_params(),
        )
        for suffix, model_id, make_params in alias_table
    ]
313+
314+
def get_default_model_configs() -> list[ModelConfig]:
    """Return every built-in model alias across all supported providers."""
    configs: list[ModelConfig] = []
    configs.extend(get_default_nvidia_model_configs())
    configs.extend(get_default_openai_model_configs())
    return configs
317+
318+
def get_default_providers() -> list[ModelProvider]:
    """Return the built-in NVIDIA and OpenAI provider endpoint definitions."""
    nvidia_provider = ModelProvider(
        name=NVIDIA_PROVIDER_NAME,
        endpoint="https://integrate.api.nvidia.com/v1",
        # NOTE(review): api_key carries the env-var *name*, not the secret —
        # presumably resolved from the environment downstream; confirm.
        api_key=NVIDIA_API_KEY_ENV_VAR_NAME,
    )
    openai_provider = ModelProvider(
        name=OPENAI_PROVIDER_NAME,
        endpoint="https://api.openai.com/v1",
        api_key=OPENAI_API_KEY_ENV_VAR_NAME,
    )
    return [nvidia_provider, openai_provider]
332+
333+
def get_nvidia_api_key() -> Optional[str]:
    """Read the NVIDIA API key from the environment; None when unset."""
    return os.environ.get(NVIDIA_API_KEY_ENV_VAR_NAME)
336+
337+
def get_openai_api_key() -> Optional[str]:
    """Read the OpenAI API key from the environment; None when unset."""
    return os.environ.get(OPENAI_API_KEY_ENV_VAR_NAME)