Skip to content

Commit 342bb7b

Browse files
authored
Update model fields immediately on save (#1125)
* Add failing test that asserts fields are included in `lm_provider_params`
* Fix `lm_provider_params` property to include fields
* Fix bug that writes to `self.settings["model_parameters"]`
* Add test capturing bug introduced by #421
* Run pre-commit
1 parent 922712c commit 342bb7b

File tree

2 files changed

+120
-20
lines changed

2 files changed

+120
-20
lines changed

packages/jupyter-ai/jupyter_ai/config_manager.py

Lines changed: 36 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import os
44
import shutil
55
import time
6+
from copy import deepcopy
67
from typing import List, Optional, Type, Union
78

89
from deepmerge import always_merger as Merger
@@ -106,11 +107,11 @@ def __init__(
106107
log: Logger,
107108
lm_providers: LmProvidersDict,
108109
em_providers: EmProvidersDict,
109-
allowed_providers: Optional[List[str]],
110-
blocked_providers: Optional[List[str]],
111-
allowed_models: Optional[List[str]],
112-
blocked_models: Optional[List[str]],
113110
defaults: dict,
111+
allowed_providers: Optional[List[str]] = None,
112+
blocked_providers: Optional[List[str]] = None,
113+
allowed_models: Optional[List[str]] = None,
114+
blocked_models: Optional[List[str]] = None,
114115
*args,
115116
**kwargs,
116117
):
@@ -127,7 +128,13 @@ def __init__(
127128
self._allowed_models = allowed_models
128129
self._blocked_models = blocked_models
129130
self._defaults = defaults
130-
"""Provider defaults."""
131+
"""
132+
Dictionary that maps config keys (e.g. `model_provider_id`, `fields`) to
133+
user-specified overrides, set by traitlets configuration.
134+
135+
Values in this dictionary should never be mutated as they may refer to
136+
entries in the global `self.settings` dictionary.
137+
"""
131138

132139
self._last_read: Optional[int] = None
133140
"""When the server last read the config file. If the file was not
@@ -218,19 +225,22 @@ def _create_default_config(self, default_config):
218225
self._write_config(GlobalConfig(**default_config))
219226

220227
def _init_defaults(self):
221-
field_list = GlobalConfig.__fields__.keys()
222-
properties = self.validator.schema.get("properties", {})
223-
field_dict = {
224-
field: properties.get(field).get("default") for field in field_list
228+
config_keys = GlobalConfig.__fields__.keys()
229+
schema_properties = self.validator.schema.get("properties", {})
230+
default_config = {
231+
field: schema_properties.get(field).get("default") for field in config_keys
225232
}
226233
if self._defaults is None:
227-
return field_dict
234+
return default_config
228235

229-
for field in field_list:
230-
default_value = self._defaults.get(field)
236+
for config_key in config_keys:
237+
# we call `deepcopy()` here to avoid directly referring to the
238+
# values in `self._defaults`, as they map to entries in the global
239+
# `self.settings` dictionary and may be mutated otherwise.
240+
default_value = deepcopy(self._defaults.get(config_key))
231241
if default_value is not None:
232-
field_dict[field] = default_value
233-
return field_dict
242+
default_config[config_key] = default_value
243+
return default_config
234244

235245
def _read_config(self) -> GlobalConfig:
236246
"""Returns the user's current configuration as a GlobalConfig object.
@@ -436,16 +446,21 @@ def completions_lm_provider_params(self):
436446
)
437447

438448
def _provider_params(self, key, listing):
439-
# get generic fields
449+
# read config
440450
config = self._read_config()
441-
gid = getattr(config, key)
442-
if not gid:
451+
452+
# get model ID (without provider ID component) from model universal ID
453+
# (with provider component).
454+
model_uid = getattr(config, key)
455+
if not model_uid:
443456
return None
457+
model_id = model_uid.split(":", 1)[1]
444458

445-
lid = gid.split(":", 1)[1]
459+
# get config fields (e.g. base API URL, etc.)
460+
fields = config.fields.get(model_uid, {})
446461

447462
# get authn fields
448-
_, Provider = get_em_provider(gid, listing)
463+
_, Provider = get_em_provider(model_uid, listing)
449464
authn_fields = {}
450465
if Provider.auth_strategy and Provider.auth_strategy.type == "env":
451466
keyword_param = (
@@ -456,7 +471,8 @@ def _provider_params(self, key, listing):
456471
authn_fields[keyword_param] = config.api_keys[key_name]
457472

458473
return {
459-
"model_id": lid,
474+
"model_id": model_id,
475+
**fields,
460476
**authn_fields,
461477
}
462478

packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,25 @@ def schema_path(jp_data_dir):
2525
return str(jp_data_dir / "config_schema.json")
2626

2727

28+
@pytest.fixture
def config_file_with_model_fields(jp_data_dir):
    """
    Fixture that creates a `config.json` file with the chat model set to
    `openai-chat:gpt-4o` and fields for that model. Returns path to the file.

    The file is written into `jp_data_dir` and contains an OpenAI API key and
    an `openai_api_base` field override for the configured chat model.
    """
    config_data = {
        # The key must be spelled exactly `model_provider_id`; a stray
        # trailing colon would leave the chat model unset.
        "model_provider_id": "openai-chat:gpt-4o",
        "embeddings_provider_id": None,
        # Same key name that `configure_with_fields()` passes in `api_keys`.
        # NOTE(review): presumably this is the name the OpenAI provider's
        # auth strategy looks up -- confirm against the provider definition.
        "api_keys": {"OPENAI_API_KEY": "foobar"},
        "send_with_shift_enter": False,
        "fields": {"openai-chat:gpt-4o": {"openai_api_base": "https://example.com"}},
    }
    config_path = jp_data_dir / "config.json"
    with open(config_path, "w", encoding="utf-8") as file:
        json.dump(config_data, file)
    return str(config_path)
45+
46+
2847
@pytest.fixture
2948
def common_cm_kwargs(config_path, schema_path):
3049
"""Kwargs that are commonly used when initializing the CM."""
@@ -175,6 +194,28 @@ def configure_to_openai(cm: ConfigManager):
175194
return LM_GID, EM_GID, LM_LID, EM_LID, API_PARAMS
176195

177196

197+
def configure_with_fields(cm: ConfigManager):
    """
    Updates the given ConfigManager so that the chat model is
    `openai-chat:gpt-4o`, with an API key and a field override for that model.

    Returns the dictionary that `cm.lm_provider_params` is expected to equal
    after the update.
    """
    model_uid = "openai-chat:gpt-4o"
    api_key = "foobar"
    api_base = "https://example.com"

    # Build the whole update up front and apply it in a single call.
    cm.update_config(
        UpdateConfigRequest(
            model_provider_id=model_uid,
            api_keys={"OPENAI_API_KEY": api_key},
            fields={model_uid: {"openai_api_base": api_base}},
        )
    )

    return {
        "model_id": "gpt-4o",
        "openai_api_key": api_key,
        "openai_api_base": api_base,
    }
217+
218+
178219
def test_snapshot_default_config(cm: ConfigManager, snapshot):
179220
config_from_cm: DescribeConfigResponse = cm.get_config()
180221
assert config_from_cm == snapshot(exclude=lambda prop, path: prop == "last_read")
@@ -402,3 +443,46 @@ def test_handle_bad_provider_ids(cm_with_bad_provider_ids):
402443
config_desc = cm_with_bad_provider_ids.get_config()
403444
assert config_desc.model_provider_id is None
404445
assert config_desc.embeddings_provider_id is None
446+
447+
448+
def test_config_manager_returns_fields(cm):
    """
    Checks that model fields set by the user are part of the value returned
    by `ConfigManager.lm_provider_params`.
    """
    expected_params = configure_with_fields(cm)
    actual_params = cm.lm_provider_params
    assert actual_params == expected_params
455+
456+
457+
def test_config_manager_does_not_write_to_defaults(
    config_file_with_model_fields, schema_path
):
    """
    Checks that constructing a `ConfigManager` leaves the dictionary passed as
    `defaults` unchanged when the config file on disk holds different values.
    """
    from copy import deepcopy

    log = logging.getLogger()
    lm_providers = get_lm_providers()
    em_providers = get_em_providers()

    defaults = {
        "model_provider_id": None,
        "embeddings_provider_id": None,
        "api_keys": {},
        "fields": {},
    }
    # Independent snapshot taken before the manager sees `defaults`, so any
    # in-place mutation during construction shows up in the comparison below.
    defaults_before = deepcopy(defaults)

    manager_kwargs = {
        "log": log,
        "lm_providers": lm_providers,
        "em_providers": em_providers,
        "config_path": config_file_with_model_fields,
        "schema_path": schema_path,
        "defaults": defaults,
    }
    # The instance itself is not inspected; only its effect on `defaults` is.
    _cm = ConfigManager(**manager_kwargs)

    assert defaults == defaults_before

0 commit comments

Comments
 (0)