3
3
import os
4
4
import shutil
5
5
import time
6
+ from copy import deepcopy
6
7
from typing import List , Optional , Type , Union
7
8
8
9
from deepmerge import always_merger as Merger
@@ -106,11 +107,11 @@ def __init__(
106
107
log : Logger ,
107
108
lm_providers : LmProvidersDict ,
108
109
em_providers : EmProvidersDict ,
109
- allowed_providers : Optional [List [str ]],
110
- blocked_providers : Optional [List [str ]],
111
- allowed_models : Optional [List [str ]],
112
- blocked_models : Optional [List [str ]],
113
110
defaults : dict ,
111
+ allowed_providers : Optional [List [str ]] = None ,
112
+ blocked_providers : Optional [List [str ]] = None ,
113
+ allowed_models : Optional [List [str ]] = None ,
114
+ blocked_models : Optional [List [str ]] = None ,
114
115
* args ,
115
116
** kwargs ,
116
117
):
@@ -127,7 +128,13 @@ def __init__(
127
128
self ._allowed_models = allowed_models
128
129
self ._blocked_models = blocked_models
129
130
self ._defaults = defaults
130
- """Provider defaults."""
131
+ """
132
+ Dictionary that maps config keys (e.g. `model_provider_id`, `fields`) to
133
+ user-specified overrides, set by traitlets configuration.
134
+
135
+ Values in this dictionary should never be mutated as they may refer to
136
+ entries in the global `self.settings` dictionary.
137
+ """
131
138
132
139
self ._last_read : Optional [int ] = None
133
140
"""When the server last read the config file. If the file was not
@@ -218,19 +225,22 @@ def _create_default_config(self, default_config):
218
225
self ._write_config (GlobalConfig (** default_config ))
219
226
220
227
def _init_defaults (self ):
221
- field_list = GlobalConfig .__fields__ .keys ()
222
- properties = self .validator .schema .get ("properties" , {})
223
- field_dict = {
224
- field : properties .get (field ).get ("default" ) for field in field_list
228
+ config_keys = GlobalConfig .__fields__ .keys ()
229
+ schema_properties = self .validator .schema .get ("properties" , {})
230
+ default_config = {
231
+ field : schema_properties .get (field ).get ("default" ) for field in config_keys
225
232
}
226
233
if self ._defaults is None :
227
- return field_dict
234
+ return default_config
228
235
229
- for field in field_list :
230
- default_value = self ._defaults .get (field )
236
+ for config_key in config_keys :
237
+ # we call `deepcopy()` here to avoid directly referring to the
238
+ # values in `self._defaults`, as they map to entries in the global
239
+ # `self.settings` dictionary and may be mutated otherwise.
240
+ default_value = deepcopy (self ._defaults .get (config_key ))
231
241
if default_value is not None :
232
- field_dict [ field ] = default_value
233
- return field_dict
242
+ default_config [ config_key ] = default_value
243
+ return default_config
234
244
235
245
def _read_config (self ) -> GlobalConfig :
236
246
"""Returns the user's current configuration as a GlobalConfig object.
@@ -436,16 +446,21 @@ def completions_lm_provider_params(self):
436
446
)
437
447
438
448
def _provider_params (self , key , listing ):
439
- # get generic fields
449
+ # read config
440
450
config = self ._read_config ()
441
- gid = getattr (config , key )
442
- if not gid :
451
+
452
+ # get model ID (without provider ID component) from model universal ID
453
+ # (with provider component).
454
+ model_uid = getattr (config , key )
455
+ if not model_uid :
443
456
return None
457
+ model_id = model_uid .split (":" , 1 )[1 ]
444
458
445
- lid = gid .split (":" , 1 )[1 ]
459
+ # get config fields (e.g. base API URL, etc.)
460
+ fields = config .fields .get (model_uid , {})
446
461
447
462
# get authn fields
448
- _ , Provider = get_em_provider (gid , listing )
463
+ _ , Provider = get_em_provider (model_uid , listing )
449
464
authn_fields = {}
450
465
if Provider .auth_strategy and Provider .auth_strategy .type == "env" :
451
466
keyword_param = (
@@ -456,7 +471,8 @@ def _provider_params(self, key, listing):
456
471
authn_fields [keyword_param ] = config .api_keys [key_name ]
457
472
458
473
return {
459
- "model_id" : lid ,
474
+ "model_id" : model_id ,
475
+ ** fields ,
460
476
** authn_fields ,
461
477
}
462
478
0 commit comments