|
@@ -27,57 +27,32 @@
 logger = logging.getLogger(__name__)
 
 
-def get_default_text_alias_inference_parameters() -> ChatCompletionInferenceParams:
-    return ChatCompletionInferenceParams(
-        temperature=0.85,
-        top_p=0.95,
-    )
-
-
-def get_default_reasoning_alias_inference_parameters() -> ChatCompletionInferenceParams:
-    return ChatCompletionInferenceParams(
-        temperature=0.35,
-        top_p=0.95,
-    )
-
-
-def get_default_vision_alias_inference_parameters() -> ChatCompletionInferenceParams:
-    return ChatCompletionInferenceParams(
-        temperature=0.85,
-        top_p=0.95,
-    )
-
-
-def get_default_embedding_alias_inference_parameters(provider: str) -> EmbeddingInferenceParams:
-    args = dict(encoding_format="float")
-    if provider == "nvidia":
-        args["extra_body"] = {"input_type": "query"}
-    return EmbeddingInferenceParams(**args)
-
-
 def get_default_inference_parameters(
-    model_alias: Literal["text", "reasoning", "vision", "embedding"], provider: str
+    model_alias: Literal["text", "reasoning", "vision", "embedding"],
+    inference_parameters: dict[str, Any],
 ) -> InferenceParamsT:
     if model_alias == "reasoning":
-        return get_default_reasoning_alias_inference_parameters()
+        return ChatCompletionInferenceParams(**inference_parameters)
     elif model_alias == "vision":
-        return get_default_vision_alias_inference_parameters()
+        return ChatCompletionInferenceParams(**inference_parameters)
     elif model_alias == "embedding":
-        return get_default_embedding_alias_inference_parameters(provider)
+        return EmbeddingInferenceParams(**inference_parameters)
     else:
-        return get_default_text_alias_inference_parameters()
+        return ChatCompletionInferenceParams(**inference_parameters)
 
 
 def get_builtin_model_configs() -> list[ModelConfig]:
     model_configs = []
     for provider, model_alias_map in PREDEFINED_PROVIDERS_MODEL_MAP.items():
-        for model_alias, model_id in model_alias_map.items():
+        for model_alias, settings in model_alias_map.items():
             model_configs.append(
                 ModelConfig(
                     alias=f"{provider}-{model_alias}",
-                    model=model_id,
+                    model=settings["model"],
                     provider=provider,
-                    inference_parameters=get_default_inference_parameters(model_alias, provider),
+                    inference_parameters=get_default_inference_parameters(
+                        model_alias, settings["inference_parameters"]
+                    ),
                 )
             )
     return model_configs
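The rewritten loop assumes each entry of PREDEFINED_PROVIDERS_MODEL_MAP is now a settings dict rather than a bare model-id string. Below is a minimal sketch of the shape the new code expects; the parameter values are carried over from the deleted helpers, while the model ids are placeholders, not values from the repo:

# Hypothetical shape of PREDEFINED_PROVIDERS_MODEL_MAP after this change.
# The "inference_parameters" dicts mirror the defaults the deleted helpers
# hardcoded: text/vision used temperature=0.85, reasoning used 0.35, and
# embedding used encoding_format="float" plus an nvidia-only extra_body.
PREDEFINED_PROVIDERS_MODEL_MAP = {
    "nvidia": {
        "text": {
            "model": "example/text-model",  # placeholder model id
            "inference_parameters": {"temperature": 0.85, "top_p": 0.95},
        },
        "reasoning": {
            "model": "example/reasoning-model",  # placeholder model id
            "inference_parameters": {"temperature": 0.35, "top_p": 0.95},
        },
        "embedding": {
            "model": "example/embedding-model",  # placeholder model id
            "inference_parameters": {
                "encoding_format": "float",
                # the provider-specific case the old embedding helper branched on
                "extra_body": {"input_type": "query"},
            },
        },
    },
}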
|
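With that shape in place, get_default_inference_parameters reduces to picking the params class per alias ("embedding" builds EmbeddingInferenceParams; every other alias builds ChatCompletionInferenceParams from the same dict), and get_builtin_model_configs flattens the map into one ModelConfig per provider-alias pair. A quick usage sketch, assuming both functions are imported from this module:

# One ModelConfig per (provider, alias) pair, e.g. alias "nvidia-text".
for cfg in get_builtin_model_configs():
    print(cfg.alias, cfg.model, cfg.inference_parameters)

# The dispatch on its own: a "reasoning" alias yields ChatCompletionInferenceParams.
params = get_default_inference_parameters(
    "reasoning", {"temperature": 0.35, "top_p": 0.95}
)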