1
1
import base64
2
2
import json
3
+ import os
3
4
import re
4
5
import sys
5
6
import warnings
6
7
from typing import Any , Optional
7
- import os
8
- from dotenv import load_dotenv
9
8
10
9
import click
11
10
import litellm
12
11
import traitlets
13
- from typing import Optional
12
+ from dotenv import load_dotenv
14
13
from IPython .core .magic import Magics , line_cell_magic , magics_class
15
14
from IPython .display import HTML , JSON , Markdown , Math
16
15
from jupyter_ai .model_providers .model_list import CHAT_MODELS
33
32
# Load the .env file from the workspace root
34
33
dotenv_path = os .path .join (os .getcwd (), ".env" )
35
34
35
+
36
36
class TextOrMarkdown :
37
37
def __init__ (self , text , markdown ):
38
38
self .text = text
@@ -128,12 +128,14 @@ class AiMagics(Magics):
128
128
# This should only set the "starting set" of aliases
129
129
initial_aliases = traitlets .Dict (
130
130
default_value = {},
131
- value_trait = traitlets .Unicode (),
131
+ value_trait = traitlets .Dict (),
132
132
key_trait = traitlets .Unicode (),
133
133
help = """Aliases for model identifiers.
134
134
135
- Keys define aliases, values define the provider and the model to use.
136
- The values should include identifiers in in the `provider:model` format.
135
+ Keys define aliases, values define a dictionary containing:
136
+ - target: The provider and model to use in the `provider:model` format
137
+ - api_base: Optional base URL for the API endpoint
138
+ - api_key_name: Optional name of the environment variable containing the API key
137
139
""" ,
138
140
config = True ,
139
141
)
@@ -183,8 +185,11 @@ def __init__(self, shell):
183
185
# This is useful for users to know that they can set API keys in the JupyterLab
184
186
# UI, but it is not always required to run the extension.
185
187
if not os .path .isfile (dotenv_path ):
186
- print (f"No `.env` file containing provider API keys found at { dotenv_path } . \
187
- You can add API keys to the `.env` file via the AI Settings in the JupyterLab UI." , file = sys .stderr )
188
+ print (
189
+ f"No `.env` file containing provider API keys found at { dotenv_path } . \
190
+ You can add API keys to the `.env` file via the AI Settings in the JupyterLab UI." ,
191
+ file = sys .stderr ,
192
+ )
188
193
189
194
# TODO: use LiteLLM aliases to provide this
190
195
# https://docs.litellm.ai/docs/completion/model_alias
@@ -240,7 +245,10 @@ def ai(self, line: str, cell: Optional[str] = None) -> Any:
240
245
print (error_msg , file = sys .stderr )
241
246
return
242
247
if not args :
243
- print ("No valid %ai magics arguments given, run `%ai help` for all options." , file = sys .stderr )
248
+ print (
249
+ "No valid %ai magics arguments given, run `%ai help` for all options." ,
250
+ file = sys .stderr ,
251
+ )
244
252
return
245
253
raise e
246
254
@@ -306,21 +314,23 @@ def run_ai_cell(self, args: CellArgs, prompt: str):
306
314
307
315
# Resolve model_id: check if it's in CHAT_MODELS or an alias
308
316
model_id = args .model_id
309
- if model_id not in CHAT_MODELS :
310
- # Check if it's an alias
311
- if model_id in self .aliases :
312
- model_id = self .aliases [model_id ]
313
- else :
314
- error_msg = f"Model ID '{ model_id } ' is not a known model or alias. Run '%ai list' to see available models and aliases."
315
- print (error_msg , file = sys .stderr ) # Log to stderr
316
- return
317
+ # Check if model_id is an alias and get stored configuration
318
+ alias_config = None
319
+ if model_id not in CHAT_MODELS and model_id in self .aliases :
320
+ alias_config = self .aliases [model_id ]
321
+ model_id = alias_config ["target" ]
322
+ # Use stored api_base and api_key_name if not provided in current call
323
+ if not args .api_base and alias_config ["api_base" ]:
324
+ args .api_base = alias_config ["api_base" ]
325
+ if not args .api_key_name and alias_config ["api_key_name" ]:
326
+ args .api_key_name = alias_config ["api_key_name" ]
327
+ elif model_id not in CHAT_MODELS :
328
+ error_msg = f"Model ID '{ model_id } ' is not a known model or alias. Run '%ai list' to see available models and aliases."
329
+ print (error_msg , file = sys .stderr ) # Log to stderr
330
+ return
317
331
try :
318
332
# Prepare litellm completion arguments
319
- completion_args = {
320
- "model" : model_id ,
321
- "messages" : messages ,
322
- "stream" : False
323
- }
333
+ completion_args = {"model" : model_id , "messages" : messages , "stream" : False }
324
334
325
335
# Add api_base if provided
326
336
if args .api_base :
@@ -493,8 +503,12 @@ def handle_alias(self, args: RegisterArgs) -> TextOrMarkdown:
493
503
if args .name in AI_COMMANDS :
494
504
raise ValueError (f"The name { args .name } is reserved for a command" )
495
505
496
- # Store the alias
497
- self .aliases [args .name ] = args .target
506
+ # Store the alias with its configuration
507
+ self .aliases [args .name ] = {
508
+ "target" : args .target ,
509
+ "api_base" : args .api_base ,
510
+ "api_key_name" : args .api_key_name ,
511
+ }
498
512
499
513
output = f"Registered new alias `{ args .name } `"
500
514
return TextOrMarkdown (output , output )
@@ -508,7 +522,7 @@ def handle_version(self, args: VersionArgs) -> str:
508
522
509
523
def handle_list (self , args : ListArgs ):
510
524
"""
511
- Handles `%ai list`.
525
+ Handles `%ai list`.
512
526
- `%ai list` shows all providers by default, and ask the user to run %ai list <provider-name>.
513
527
- `%ai list <provider-name>` shows all models available from one provider. It should also note that the list is not comprehensive, and include a reference to the upstream LiteLLM docs.
514
528
- `%ai list all` should list all models.
@@ -517,12 +531,12 @@ def handle_list(self, args: ListArgs):
517
531
models = CHAT_MODELS
518
532
519
533
# If provider_id is None, only return provider IDs
520
- if getattr (args , ' provider_id' , None ) is None :
534
+ if getattr (args , " provider_id" , None ) is None :
521
535
# Extract unique provider IDs from model IDs
522
536
provider_ids = set ()
523
537
for model in models :
524
- if '/' in model :
525
- provider_ids .add (model .split ('/' )[0 ])
538
+ if "/" in model :
539
+ provider_ids .add (model .split ("/" )[0 ])
526
540
527
541
# Format output for both text and markdown
528
542
text_output = "Available providers\n \n (Run `%ai list <provider_name>` to see models for a specific provider)\n \n "
@@ -533,9 +547,9 @@ def handle_list(self, args: ListArgs):
533
547
markdown_output += f"* `{ provider_id } `\n "
534
548
535
549
return TextOrMarkdown (text_output , markdown_output )
536
-
537
- elif getattr (args , ' provider_id' , None ) == ' all' :
538
- # Otherwise show all models and aliases
550
+
551
+ elif getattr (args , " provider_id" , None ) == " all" :
552
+ # Otherwise show all models and aliases
539
553
text_output = "All available models\n \n (The list is not comprehensive, a list of models is available at https://docs.litellm.ai/docs/providers)\n \n "
540
554
markdown_output = "## All available models \n \n (The list is not comprehensive, a list of models is available at https://docs.litellm.ai/docs/providers)\n \n "
541
555
@@ -547,12 +561,25 @@ def handle_list(self, args: ListArgs):
547
561
if len (self .aliases ) > 0 :
548
562
text_output += "\n Aliases:\n "
549
563
markdown_output += "\n ### Aliases\n \n "
550
- for alias , target in self .aliases .items ():
551
- text_output += f"* { alias } -> { target } \n "
552
- markdown_output += f"* `{ alias } ` -> `{ target } `\n "
564
+ for alias , config in self .aliases .items ():
565
+ text_output += f"* { alias } :\n "
566
+ text_output += f" - target: { config ['target' ]} \n "
567
+ if config ["api_base" ]:
568
+ text_output += f" - api_base: { config ['api_base' ]} \n "
569
+ if config ["api_key_name" ]:
570
+ text_output += f" - api_key_name: { config ['api_key_name' ]} \n "
571
+
572
+ markdown_output += f"* `{ alias } `:\n "
573
+ markdown_output += f" - target: `{ config ['target' ]} `\n "
574
+ if config ["api_base" ]:
575
+ markdown_output += f" - api_base: `{ config ['api_base' ]} `\n "
576
+ if config ["api_key_name" ]:
577
+ markdown_output += (
578
+ f" - api_key_name: `{ config ['api_key_name' ]} `\n "
579
+ )
553
580
554
581
return TextOrMarkdown (text_output , markdown_output )
555
-
582
+
556
583
else :
557
584
# If a specific provider_id is given, filter models by that provider
558
585
provider_id = args .provider_id
@@ -575,10 +602,24 @@ def handle_list(self, args: ListArgs):
575
602
if len (self .aliases ) > 0 :
576
603
text_output += "\n Aliases:\n "
577
604
markdown_output += "\n ### Aliases\n \n "
578
- for alias , target in self .aliases .items ():
579
- if target .startswith (provider_id + "/" ):
580
- text_output += f"* { alias } -> { target } \n "
581
- markdown_output += f"* `{ alias } ` -> `{ target } `\n "
582
-
605
+ for alias , config in self .aliases .items ():
606
+ if config ["target" ].startswith (provider_id + "/" ):
607
+ text_output += f"* { alias } :\n "
608
+ text_output += f" - target: { config ['target' ]} \n "
609
+ if config ["api_base" ]:
610
+ text_output += f" - api_base: { config ['api_base' ]} \n "
611
+ if config ["api_key_name" ]:
612
+ text_output += (
613
+ f" - api_key_name: { config ['api_key_name' ]} \n "
614
+ )
615
+
616
+ markdown_output += f"* `{ alias } `:\n "
617
+ markdown_output += f" - target: `{ config ['target' ]} `\n "
618
+ if config ["api_base" ]:
619
+ markdown_output += f" - api_base: `{ config ['api_base' ]} `\n "
620
+ if config ["api_key_name" ]:
621
+ markdown_output += (
622
+ f" - api_key_name: `{ config ['api_key_name' ]} `\n "
623
+ )
583
624
584
625
return TextOrMarkdown (text_output , markdown_output )
0 commit comments