3
3
import os
4
4
import time
5
5
from copy import deepcopy
6
- from typing import Optional , Union
6
+ from typing import Any , Optional , Union
7
7
8
8
from deepmerge import always_merger
9
9
from jupyter_ai_magics .utils import (
@@ -175,69 +175,10 @@ def _process_existing_config(self):
175
175
with open (self .config_path , encoding = "utf-8" ) as f :
176
176
existing_config = json .loads (f .read ())
177
177
config = JaiConfig (** existing_config )
178
- validated_config = self ._validate_model_ids (config )
179
178
180
179
# re-write to the file to validate the config and apply any
181
180
# updates to the config file immediately
182
- self ._write_config (validated_config )
183
-
184
- def _validate_model_ids (self , config ):
185
- lm_provider_keys = ["model_provider_id" , "completions_model_provider_id" ]
186
- em_provider_keys = ["embeddings_provider_id" ]
187
- clm_provider_keys = ["completions_model_provider_id" ]
188
-
189
- # if the currently selected language or embedding model are
190
- # forbidden, set them to `None` and log a warning.
191
- for lm_key in lm_provider_keys :
192
- lm_id = getattr (config , lm_key )
193
- if lm_id is not None and not self ._validate_model (lm_id , raise_exc = False ):
194
- self .log .warning (
195
- f"Language model { lm_id } is forbidden by current allow/blocklists. Setting to None."
196
- )
197
- setattr (config , lm_key , None )
198
- for em_key in em_provider_keys :
199
- em_id = getattr (config , em_key )
200
- if em_id is not None and not self ._validate_model (em_id , raise_exc = False ):
201
- self .log .warning (
202
- f"Embedding model { em_id } is forbidden by current allow/blocklists. Setting to None."
203
- )
204
- setattr (config , em_key , None )
205
- for clm_key in clm_provider_keys :
206
- clm_id = getattr (config , clm_key )
207
- if clm_id is not None and not self ._validate_model (clm_id , raise_exc = False ):
208
- self .log .warning (
209
- f"Completion model { clm_id } is forbidden by current allow/blocklists. Setting to None."
210
- )
211
- setattr (config , clm_key , None )
212
-
213
- # if the currently selected language or embedding model ids are
214
- # not associated with models, set them to `None` and log a warning.
215
- for lm_key in lm_provider_keys :
216
- lm_id = getattr (config , lm_key )
217
- if lm_id is not None and not get_lm_provider (lm_id , self ._lm_providers )[1 ]:
218
- self .log .warning (
219
- f"No language model is associated with '{ lm_id } '. Setting to None."
220
- )
221
- setattr (config , lm_key , None )
222
- for em_key in em_provider_keys :
223
- em_id = getattr (config , em_key )
224
- if em_id is not None and not get_em_provider (em_id , self ._em_providers )[1 ]:
225
- self .log .warning (
226
- f"No embedding model is associated with '{ em_id } '. Setting to None."
227
- )
228
- setattr (config , em_key , None )
229
- for clm_key in clm_provider_keys :
230
- clm_id = getattr (config , clm_key )
231
- if (
232
- clm_id is not None
233
- and not get_lm_provider (clm_id , self ._lm_providers )[1 ]
234
- ):
235
- self .log .warning (
236
- f"No completion model is associated with '{ clm_id } '. Setting to None."
237
- )
238
- setattr (config , clm_key , None )
239
-
240
- return config
181
+ self ._write_config (config )
241
182
242
183
def _read_config (self ) -> JaiConfig :
243
184
"""
@@ -268,78 +209,79 @@ def _validate_config(self, config: JaiConfig):
268
209
user has specified authentication for all configured models that require
269
210
it.
270
211
"""
212
+ # TODO: re-implement this w/ liteLLM
271
213
# validate language model config
272
- if config .model_provider_id :
273
- _ , lm_provider = get_lm_provider (
274
- config .model_provider_id , self ._lm_providers
275
- )
214
+ # if config.model_provider_id:
215
+ # _, lm_provider = get_lm_provider(
216
+ # config.model_provider_id, self._lm_providers
217
+ # )
276
218
277
- # verify model is declared by some provider
278
- if not lm_provider :
279
- raise ValueError (
280
- f"No language model is associated with '{ config .model_provider_id } '."
281
- )
219
+ # # verify model is declared by some provider
220
+ # if not lm_provider:
221
+ # raise ValueError(
222
+ # f"No language model is associated with '{config.model_provider_id}'."
223
+ # )
282
224
283
- # verify model is not blocked
284
- self ._validate_model (config .model_provider_id )
225
+ # # verify model is not blocked
226
+ # self._validate_model(config.model_provider_id)
285
227
286
- # verify model is authenticated
287
- _validate_provider_authn (config , lm_provider )
228
+ # # verify model is authenticated
229
+ # _validate_provider_authn(config, lm_provider)
288
230
289
- # verify fields exist for this model if needed
290
- if lm_provider .fields and config .model_provider_id not in config .fields :
291
- config .fields [config .model_provider_id ] = {}
231
+ # # verify fields exist for this model if needed
232
+ # if lm_provider.fields and config.model_provider_id not in config.fields:
233
+ # config.fields[config.model_provider_id] = {}
292
234
293
235
# validate completions model config
294
- if config .completions_model_provider_id :
295
- _ , completions_provider = get_lm_provider (
296
- config .completions_model_provider_id , self ._lm_providers
297
- )
298
-
299
- # verify model is declared by some provider
300
- if not completions_provider :
301
- raise ValueError (
302
- f"No language model is associated with '{ config .completions_model_provider_id } '."
303
- )
304
-
305
- # verify model is not blocked
306
- self ._validate_model (config .completions_model_provider_id )
307
-
308
- # verify model is authenticated
309
- _validate_provider_authn (config , completions_provider )
310
-
311
- # verify completions fields exist for this model if needed
312
- if (
313
- completions_provider .fields
314
- and config .completions_model_provider_id
315
- not in config .completions_fields
316
- ):
317
- config .completions_fields [config .completions_model_provider_id ] = {}
318
-
319
- # validate embedding model config
320
- if config .embeddings_provider_id :
321
- _ , em_provider = get_em_provider (
322
- config .embeddings_provider_id , self ._em_providers
323
- )
324
-
325
- # verify model is declared by some provider
326
- if not em_provider :
327
- raise ValueError (
328
- f"No embedding model is associated with '{ config .embeddings_provider_id } '."
329
- )
330
-
331
- # verify model is not blocked
332
- self ._validate_model (config .embeddings_provider_id )
333
-
334
- # verify model is authenticated
335
- _validate_provider_authn (config , em_provider )
336
-
337
- # verify embedding fields exist for this model if needed
338
- if (
339
- em_provider .fields
340
- and config .embeddings_provider_id not in config .embeddings_fields
341
- ):
342
- config .embeddings_fields [config .embeddings_provider_id ] = {}
236
+ # if config.completions_model_provider_id:
237
+ # _, completions_provider = get_lm_provider(
238
+ # config.completions_model_provider_id, self._lm_providers
239
+ # )
240
+
241
+ # # verify model is declared by some provider
242
+ # if not completions_provider:
243
+ # raise ValueError(
244
+ # f"No language model is associated with '{config.completions_model_provider_id}'."
245
+ # )
246
+
247
+ # # verify model is not blocked
248
+ # self._validate_model(config.completions_model_provider_id)
249
+
250
+ # # verify model is authenticated
251
+ # _validate_provider_authn(config, completions_provider)
252
+
253
+ # # verify completions fields exist for this model if needed
254
+ # if (
255
+ # completions_provider.fields
256
+ # and config.completions_model_provider_id
257
+ # not in config.completions_fields
258
+ # ):
259
+ # config.completions_fields[config.completions_model_provider_id] = {}
260
+
261
+ # # validate embedding model config
262
+ # if config.embeddings_provider_id:
263
+ # _, em_provider = get_em_provider(
264
+ # config.embeddings_provider_id, self._em_providers
265
+ # )
266
+
267
+ # # verify model is declared by some provider
268
+ # if not em_provider:
269
+ # raise ValueError(
270
+ # f"No embedding model is associated with '{config.embeddings_provider_id}'."
271
+ # )
272
+
273
+ # # verify model is not blocked
274
+ # self._validate_model(config.embeddings_provider_id)
275
+
276
+ # # verify model is authenticated
277
+ # _validate_provider_authn(config, em_provider)
278
+
279
+ # # verify embedding fields exist for this model if needed
280
+ # if (
281
+ # em_provider.fields
282
+ # and config.embeddings_provider_id not in config.embeddings_fields
283
+ # ):
284
+ # config.embeddings_fields[config.embeddings_provider_id] = {}
343
285
344
286
def _validate_model (self , model_id : str , raise_exc = True ):
345
287
"""
@@ -449,23 +391,30 @@ def get_config(self):
449
391
)
450
392
451
393
@property
452
- def lm_gid (self ):
394
+ def chat_model (self ) -> str | None :
395
+ """
396
+ Returns the model ID of the chat model from AI settings, if any.
397
+ """
453
398
config = self ._read_config ()
454
399
return config .model_provider_id
455
400
456
401
@property
457
- def em_gid (self ):
458
- config = self ._read_config ()
459
- return config .embeddings_provider_id
402
+ def chat_model_params (self ) -> dict [str , Any ]:
403
+ return self ._provider_params ("model_provider_id" , self ._lm_providers )
460
404
461
405
@property
462
- def lm_provider (self ):
463
- return self ._get_provider ("model_provider_id" , self ._lm_providers )
406
+ def embedding_model (self ) -> str | None :
407
+ """
408
+ Returns the model ID of the embedding model from AI settings, if any.
409
+ """
410
+ config = self ._read_config ()
411
+ return config .embeddings_provider_id
464
412
465
413
@property
466
- def em_provider (self ):
467
- return self ._get_provider ("embeddings_provider_id" , self ._em_providers )
414
+ def embedding_model_params (self ) -> dict [ str , Any ] :
415
+ return self ._provider_params ("embeddings_provider_id" , self ._em_providers )
468
416
417
+ # TODO: use LiteLLM in completions
469
418
@property
470
419
def completions_lm_provider (self ):
471
420
return self ._get_provider ("completions_model_provider_id" , self ._lm_providers )
@@ -479,14 +428,7 @@ def _get_provider(self, key, listing):
479
428
_ , Provider = get_lm_provider (gid , listing )
480
429
return Provider
481
430
482
- @property
483
- def lm_provider_params (self ):
484
- return self ._provider_params ("model_provider_id" , self ._lm_providers )
485
-
486
- @property
487
- def em_provider_params (self ):
488
- return self ._provider_params ("embeddings_provider_id" , self ._em_providers )
489
-
431
+ # TODO: use LiteLLM in completions
490
432
@property
491
433
def completions_lm_provider_params (self ):
492
434
return self ._provider_params (
0 commit comments