Skip to content

Commit c4859ae

Browse files
authored
[magics] Add options to include the API url & key with alias (#1477)
* [magics] Enhances the `alias` option to include the base API url and key * update tests
1 parent ca34ade commit c4859ae

File tree

3 files changed

+136
-49
lines changed

3 files changed

+136
-49
lines changed

packages/jupyter-ai-magics/jupyter_ai_magics/magics.py

Lines changed: 81 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,15 @@
11
import base64
22
import json
3+
import os
34
import re
45
import sys
56
import warnings
67
from typing import Any, Optional
7-
import os
8-
from dotenv import load_dotenv
98

109
import click
1110
import litellm
1211
import traitlets
13-
from typing import Optional
12+
from dotenv import load_dotenv
1413
from IPython.core.magic import Magics, line_cell_magic, magics_class
1514
from IPython.display import HTML, JSON, Markdown, Math
1615
from jupyter_ai.model_providers.model_list import CHAT_MODELS
@@ -33,6 +32,7 @@
3332
# Load the .env file from the workspace root
3433
dotenv_path = os.path.join(os.getcwd(), ".env")
3534

35+
3636
class TextOrMarkdown:
3737
def __init__(self, text, markdown):
3838
self.text = text
@@ -128,12 +128,14 @@ class AiMagics(Magics):
128128
# This should only set the "starting set" of aliases
129129
initial_aliases = traitlets.Dict(
130130
default_value={},
131-
value_trait=traitlets.Unicode(),
131+
value_trait=traitlets.Dict(),
132132
key_trait=traitlets.Unicode(),
133133
help="""Aliases for model identifiers.
134134
135-
Keys define aliases, values define the provider and the model to use.
136-
The values should include identifiers in in the `provider:model` format.
135+
Keys define aliases, values define a dictionary containing:
136+
- target: The provider and model to use in the `provider:model` format
137+
- api_base: Optional base URL for the API endpoint
138+
- api_key_name: Optional name of the environment variable containing the API key
137139
""",
138140
config=True,
139141
)
@@ -183,8 +185,11 @@ def __init__(self, shell):
183185
# This is useful for users to know that they can set API keys in the JupyterLab
184186
# UI, but it is not always required to run the extension.
185187
if not os.path.isfile(dotenv_path):
186-
print(f"No `.env` file containing provider API keys found at {dotenv_path}. \
187-
You can add API keys to the `.env` file via the AI Settings in the JupyterLab UI.", file=sys.stderr)
188+
print(
189+
f"No `.env` file containing provider API keys found at {dotenv_path}. \
190+
You can add API keys to the `.env` file via the AI Settings in the JupyterLab UI.",
191+
file=sys.stderr,
192+
)
188193

189194
# TODO: use LiteLLM aliases to provide this
190195
# https://docs.litellm.ai/docs/completion/model_alias
@@ -240,7 +245,10 @@ def ai(self, line: str, cell: Optional[str] = None) -> Any:
240245
print(error_msg, file=sys.stderr)
241246
return
242247
if not args:
243-
print("No valid %ai magics arguments given, run `%ai help` for all options.", file=sys.stderr)
248+
print(
249+
"No valid %ai magics arguments given, run `%ai help` for all options.",
250+
file=sys.stderr,
251+
)
244252
return
245253
raise e
246254

@@ -306,21 +314,23 @@ def run_ai_cell(self, args: CellArgs, prompt: str):
306314

307315
# Resolve model_id: check if it's in CHAT_MODELS or an alias
308316
model_id = args.model_id
309-
if model_id not in CHAT_MODELS:
310-
# Check if it's an alias
311-
if model_id in self.aliases:
312-
model_id = self.aliases[model_id]
313-
else:
314-
error_msg = f"Model ID '{model_id}' is not a known model or alias. Run '%ai list' to see available models and aliases."
315-
print(error_msg, file=sys.stderr) # Log to stderr
316-
return
317+
# Check if model_id is an alias and get stored configuration
318+
alias_config = None
319+
if model_id not in CHAT_MODELS and model_id in self.aliases:
320+
alias_config = self.aliases[model_id]
321+
model_id = alias_config["target"]
322+
# Use stored api_base and api_key_name if not provided in current call
323+
if not args.api_base and alias_config["api_base"]:
324+
args.api_base = alias_config["api_base"]
325+
if not args.api_key_name and alias_config["api_key_name"]:
326+
args.api_key_name = alias_config["api_key_name"]
327+
elif model_id not in CHAT_MODELS:
328+
error_msg = f"Model ID '{model_id}' is not a known model or alias. Run '%ai list' to see available models and aliases."
329+
print(error_msg, file=sys.stderr) # Log to stderr
330+
return
317331
try:
318332
# Prepare litellm completion arguments
319-
completion_args = {
320-
"model": model_id,
321-
"messages": messages,
322-
"stream": False
323-
}
333+
completion_args = {"model": model_id, "messages": messages, "stream": False}
324334

325335
# Add api_base if provided
326336
if args.api_base:
@@ -493,8 +503,12 @@ def handle_alias(self, args: RegisterArgs) -> TextOrMarkdown:
493503
if args.name in AI_COMMANDS:
494504
raise ValueError(f"The name {args.name} is reserved for a command")
495505

496-
# Store the alias
497-
self.aliases[args.name] = args.target
506+
# Store the alias with its configuration
507+
self.aliases[args.name] = {
508+
"target": args.target,
509+
"api_base": args.api_base,
510+
"api_key_name": args.api_key_name,
511+
}
498512

499513
output = f"Registered new alias `{args.name}`"
500514
return TextOrMarkdown(output, output)
@@ -508,7 +522,7 @@ def handle_version(self, args: VersionArgs) -> str:
508522

509523
def handle_list(self, args: ListArgs):
510524
"""
511-
Handles `%ai list`.
525+
Handles `%ai list`.
512526
- `%ai list` shows all providers by default, and ask the user to run %ai list <provider-name>.
513527
- `%ai list <provider-name>` shows all models available from one provider. It should also note that the list is not comprehensive, and include a reference to the upstream LiteLLM docs.
514528
- `%ai list all` should list all models.
@@ -517,12 +531,12 @@ def handle_list(self, args: ListArgs):
517531
models = CHAT_MODELS
518532

519533
# If provider_id is None, only return provider IDs
520-
if getattr(args, 'provider_id', None) is None:
534+
if getattr(args, "provider_id", None) is None:
521535
# Extract unique provider IDs from model IDs
522536
provider_ids = set()
523537
for model in models:
524-
if '/' in model:
525-
provider_ids.add(model.split('/')[0])
538+
if "/" in model:
539+
provider_ids.add(model.split("/")[0])
526540

527541
# Format output for both text and markdown
528542
text_output = "Available providers\n\n (Run `%ai list <provider_name>` to see models for a specific provider)\n\n"
@@ -533,9 +547,9 @@ def handle_list(self, args: ListArgs):
533547
markdown_output += f"* `{provider_id}`\n"
534548

535549
return TextOrMarkdown(text_output, markdown_output)
536-
537-
elif getattr(args, 'provider_id', None) == 'all':
538-
# Otherwise show all models and aliases
550+
551+
elif getattr(args, "provider_id", None) == "all":
552+
# Otherwise show all models and aliases
539553
text_output = "All available models\n\n (The list is not comprehensive, a list of models is available at https://docs.litellm.ai/docs/providers)\n\n"
540554
markdown_output = "## All available models \n\n (The list is not comprehensive, a list of models is available at https://docs.litellm.ai/docs/providers)\n\n"
541555

@@ -547,12 +561,25 @@ def handle_list(self, args: ListArgs):
547561
if len(self.aliases) > 0:
548562
text_output += "\nAliases:\n"
549563
markdown_output += "\n### Aliases\n\n"
550-
for alias, target in self.aliases.items():
551-
text_output += f"* {alias} -> {target}\n"
552-
markdown_output += f"* `{alias}` -> `{target}`\n"
564+
for alias, config in self.aliases.items():
565+
text_output += f"* {alias}:\n"
566+
text_output += f" - target: {config['target']}\n"
567+
if config["api_base"]:
568+
text_output += f" - api_base: {config['api_base']}\n"
569+
if config["api_key_name"]:
570+
text_output += f" - api_key_name: {config['api_key_name']}\n"
571+
572+
markdown_output += f"* `{alias}`:\n"
573+
markdown_output += f" - target: `{config['target']}`\n"
574+
if config["api_base"]:
575+
markdown_output += f" - api_base: `{config['api_base']}`\n"
576+
if config["api_key_name"]:
577+
markdown_output += (
578+
f" - api_key_name: `{config['api_key_name']}`\n"
579+
)
553580

554581
return TextOrMarkdown(text_output, markdown_output)
555-
582+
556583
else:
557584
# If a specific provider_id is given, filter models by that provider
558585
provider_id = args.provider_id
@@ -575,10 +602,24 @@ def handle_list(self, args: ListArgs):
575602
if len(self.aliases) > 0:
576603
text_output += "\nAliases:\n"
577604
markdown_output += "\n### Aliases\n\n"
578-
for alias, target in self.aliases.items():
579-
if target.startswith(provider_id + "/"):
580-
text_output += f"* {alias} -> {target}\n"
581-
markdown_output += f"* `{alias}` -> `{target}`\n"
582-
605+
for alias, config in self.aliases.items():
606+
if config["target"].startswith(provider_id + "/"):
607+
text_output += f"* {alias}:\n"
608+
text_output += f" - target: {config['target']}\n"
609+
if config["api_base"]:
610+
text_output += f" - api_base: {config['api_base']}\n"
611+
if config["api_key_name"]:
612+
text_output += (
613+
f" - api_key_name: {config['api_key_name']}\n"
614+
)
615+
616+
markdown_output += f"* `{alias}`:\n"
617+
markdown_output += f" - target: `{config['target']}`\n"
618+
if config["api_base"]:
619+
markdown_output += f" - api_base: `{config['api_base']}`\n"
620+
if config["api_key_name"]:
621+
markdown_output += (
622+
f" - api_key_name: `{config['api_key_name']}`\n"
623+
)
583624

584625
return TextOrMarkdown(text_output, markdown_output)

packages/jupyter-ai-magics/jupyter_ai_magics/parsers.py

Lines changed: 43 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,8 @@ class RegisterArgs(BaseModel):
8888
type: Literal["alias"] = "alias"
8989
name: str
9090
target: str
91+
api_base: Optional[str] = None
92+
api_key_name: Optional[str] = None
9193

9294

9395
class DeleteArgs(BaseModel):
@@ -99,6 +101,8 @@ class UpdateArgs(BaseModel):
99101
type: Literal["update"] = "update"
100102
name: str
101103
target: str
104+
api_base: Optional[str] = None
105+
api_key_name: Optional[str] = None
102106

103107

104108
class ResetArgs(BaseModel):
@@ -292,31 +296,65 @@ def list_subparser(**kwargs):
292296
)
293297
@click.argument("name")
294298
@click.argument("target")
299+
@click.option(
300+
"--api-base",
301+
required=False,
302+
help="Base URL for the API endpoint.",
303+
)
304+
@click.option(
305+
"--api-key-name",
306+
required=False,
307+
help="Name of the environment variable containing the API key.",
308+
)
295309
def register_subparser(**kwargs):
296-
"""Register a new alias called NAME for the model or chain named TARGET."""
310+
"""Register a new alias called NAME for the model or chain named TARGET.
311+
312+
Optional parameters:
313+
--api-base: Base URL for the API endpoint
314+
--api-key-name: Name of the environment variable containing the API key
315+
"""
297316
return RegisterArgs(**kwargs)
298317

299318

300319
@line_magic_parser.command(
301320
name="dealias", short_help="Delete an alias. See `%ai dealias --help` for options."
302321
)
303322
@click.argument("name")
304-
def register_subparser(**kwargs):
323+
def dealias_subparser(**kwargs):
305324
"""Delete an alias called NAME."""
306325
return DeleteArgs(**kwargs)
307326

308327

328+
@line_magic_parser.command(
329+
name="update",
330+
short_help="Update an alias. See `%ai update --help` for options.",
331+
)
309332
@click.argument("name")
310333
@click.argument("target")
311-
def register_subparser(**kwargs):
312-
"""Update an alias called NAME to refer to the model or chain named TARGET."""
334+
@click.option(
335+
"--api-base",
336+
required=False,
337+
help="Base URL for the API endpoint.",
338+
)
339+
@click.option(
340+
"--api-key-name",
341+
required=False,
342+
help="Name of the environment variable containing the API key.",
343+
)
344+
def update_subparser(**kwargs):
345+
"""Update an alias called NAME to refer to the model or chain named TARGET.
346+
347+
Optional parameters:
348+
--api-base: Base URL for the API endpoint
349+
--api-key-name: Name of the environment variable containing the API key
350+
"""
313351
return UpdateArgs(**kwargs)
314352

315353

316354
@line_magic_parser.command(
317355
name="reset",
318356
short_help="Clear the conversation transcript.",
319357
)
320-
def register_subparser(**kwargs):
358+
def reset_subparser(**kwargs):
321359
"""Clear the conversation transcript."""
322360
return ResetArgs()

packages/jupyter-ai-magics/jupyter_ai_magics/tests/test_magics.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,20 @@ def ip() -> InteractiveShell:
1414

1515

1616
def test_aliases_config(ip):
17-
ip.config.AiMagics.initial_aliases = {"my_custom_alias": "my_provider:my_model"}
17+
ip.config.AiMagics.initial_aliases = {
18+
"my_custom_alias": {
19+
"target": "my_provider:my_model",
20+
"api_base": None,
21+
"api_key_name": None
22+
}
23+
}
1824
ip.extension_manager.load_extension("jupyter_ai_magics")
1925
# Use 'list all' to see all models and aliases
20-
providers_list = ip.run_line_magic("ai", "list all").text
21-
# Check that alias appears in the output
22-
assert "my_custom_alias -> my_provider:my_model" in providers_list
26+
providers_list = ip.run_line_magic("ai", "list all")
27+
# Check that alias appears in the markdown output with correct format
28+
assert "### Aliases" in providers_list.markdown
29+
assert "* `my_custom_alias`:" in providers_list.markdown
30+
assert " - target: `my_provider:my_model`" in providers_list.markdown
2331

2432

2533
def test_default_model_cell(ip):

0 commit comments

Comments
 (0)