diff --git a/docs/source/users/index.md b/docs/source/users/index.md
index 1ef341e22..750960332 100644
--- a/docs/source/users/index.md
+++ b/docs/source/users/index.md
@@ -762,7 +762,7 @@ We currently support the following language model providers:
 To configure a default model you can use the IPython `%config` magic:
 
 ```python
-%config AiMagics.default_language_model = "anthropic:claude-v1.2"
+%config AiMagics.initial_language_model = "anthropic:claude-v1.2"
 ```
 
 Then subsequent magics can be invoked without typing in the model:
@@ -772,10 +772,10 @@ Then subsequent magics can be invoked without typing in the model:
 Write a poem about C++.
 ```
 
-You can configure the default model for all notebooks by specifying `c.AiMagics.default_language_model` tratilet in `ipython_config.py`, for example:
+You can configure the default model for all notebooks by specifying the `c.AiMagics.initial_language_model` traitlet in `ipython_config.py`, for example:
 
 ```python
-c.AiMagics.default_language_model = "anthropic:claude-v1.2"
+c.AiMagics.initial_language_model = "anthropic:claude-v1.2"
 ```
 
 The location of `ipython_config.py` file is documented in [IPython configuration reference](https://ipython.readthedocs.io/en/stable/config/intro.html).
@@ -965,18 +965,18 @@ produced the following Python error:
 Write a new version of this code that does not produce that error.
 ```
 
-As a shortcut for explaining errors, you can use the `%ai error` command, which will explain the most recent error using the model of your choice.
+As a shortcut for explaining and fixing errors, you can use the `%ai fix` command, which will explain the most recent error and suggest a fix using the model of your choice.
 
 ```
-%ai error anthropic:claude-v1.2
+%ai fix anthropic:claude-v1.2
 ```
 
 ### Creating and managing aliases
 
-You can create an alias for a model using the `%ai register` command. For example, the command:
+You can create an alias for a model using the `%ai alias` command. For example, the command:
 
 ```
-%ai register claude anthropic:claude-v1.2
+%ai alias claude anthropic:claude-v1.2
 ```
 
 will register the alias `claude` as pointing to the `anthropic` provider's `claude-v1.2` model. You can then use this alias as you would use any other model name:
@@ -1001,10 +1001,10 @@ prompt = PromptTemplate(
 chain = LLMChain(llm=llm, prompt=prompt)
 ```
 
-… and then use `%ai register` to give it a name:
+… and then use `%ai alias` to give it a name:
 
 ```
-%ai register companyname chain
+%ai alias companyname chain
 ```
 
 You can change an alias's target using the `%ai update` command:
@@ -1013,10 +1013,10 @@
 %ai update claude anthropic:claude-instant-v1.0
 ```
 
-You can delete an alias using the `%ai delete` command:
+You can delete an alias using the `%ai dealias` command:
 
 ```
-%ai delete claude
+%ai dealias claude
 ```
 
 You can see a list of all aliases by running the `%ai list` command.
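Taken together, the renamed commands above support a short end-to-end workflow. A possible session, with each line run in its own cell, assuming the `jupyter_ai_magics` extension is loaded and an Anthropic key is configured (the model name is only the same example used above):

```
%config AiMagics.initial_language_model = "anthropic:claude-v1.2"
%ai alias claude anthropic:claude-v1.2
%ai list
%ai dealias claude
```

Because `initial_language_model` is set, a `%%ai` cell with no model argument will use `anthropic:claude-v1.2`.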
@@ -1103,7 +1103,7 @@ the selections they make in the settings panel will take precedence over these v Specify default language model ```bash -jupyter lab --AiExtension.default_language_model=bedrock-chat:anthropic.claude-v2 +jupyter lab --AiExtension.initial_language_model=bedrock-chat:anthropic.claude-v2 ``` Specify default embedding model diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py b/packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py index 0a066bf55..9d0529e0b 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py @@ -1,8 +1,9 @@ from __future__ import annotations -from ._version import __version__ from typing import TYPE_CHECKING +from ._version import __version__ + if TYPE_CHECKING: from IPython.core.interactiveshell import InteractiveShell diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py b/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py index dc9871cc4..2ef4ae6ee 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py @@ -3,17 +3,20 @@ import re import sys import warnings +from typing import Any, Optional import click +import litellm import traitlets from IPython.core.magic import Magics, line_cell_magic, magics_class from IPython.display import HTML, JSON, Markdown, Math +from jupyter_ai.model_providers.model_list import CHAT_MODELS from ._version import __version__ from .parsers import ( CellArgs, DeleteArgs, - ErrorArgs, + FixArgs, HelpArgs, ListArgs, RegisterArgs, @@ -24,10 +27,6 @@ line_magic_parser, ) -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from typing import Any, Dict, Optional class TextOrMarkdown: def __init__(self, text, markdown): @@ -85,7 +84,7 @@ def _repr_mimebundle_(self, include=None, exclude=None): To see a list of models you can use, run `%ai list`""" -AI_COMMANDS = {"delete", "error", "help", "list", "register", "update"} +AI_COMMANDS = {"dealias", "fix", "help", "list", "alias", "update"} # Strings for listing providers and models # Avoid composing strings, to make localization easier in the future @@ -122,7 +121,7 @@ class AiMagics(Magics): # TODO: rename this to initial_aliases # This should only set the "starting set" of aliases - aliases = traitlets.Dict( + initial_aliases = traitlets.Dict( default_value={}, value_trait=traitlets.Unicode(), key_trait=traitlets.Unicode(), @@ -134,8 +133,7 @@ class AiMagics(Magics): config=True, ) - # TODO: rename this to initial_language_model - default_language_model = traitlets.Unicode( + initial_language_model = traitlets.Unicode( default_value=None, allow_none=True, help="""Default language model to use, as string in the format @@ -179,7 +177,7 @@ def __init__(self, shell): # TODO: use LiteLLM aliases to provide this # https://docs.litellm.ai/docs/completion/model_alias # initialize a registry of custom model/chain names - self.custom_model_registry = self.aliases + self.aliases = self.initial_aliases.copy() @line_cell_magic def ai(self, line: str, cell: Optional[str] = None) -> Any: @@ -190,7 +188,7 @@ def ai(self, line: str, cell: Optional[str] = None) -> Any: - `%ai` is a "line magic command" that only accepts a single line of input. This is used to provide access to sub-commands like `%ai - register`. + alias`. - `%%ai` is a "cell magic command" that accepts an entire cell of input (i.e. multiple lines). This is used to invoke a language model. 
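For completeness, here is a possible `ipython_config.py` fragment using the renamed traitlets above (`initial_aliases`, `initial_language_model`); the alias target is illustrative:

```python
# Hypothetical ipython_config.py fragment exercising the renamed traits.
c = get_config()  # noqa: F821 -- injected by IPython when the config file is loaded

# Previously `AiMagics.aliases` and `AiMagics.default_language_model`.
c.AiMagics.initial_aliases = {"claude": "anthropic:claude-v1.2"}
c.AiMagics.initial_language_model = "anthropic:claude-v1.2"
```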
@@ -200,7 +198,7 @@ def ai(self, line: str, cell: Optional[str] = None) -> Any: method; `%%ai` was run if and only if `cell is not None`. """ raw_args = line.split(" ") - default_map = {"model_id": self.default_language_model} + default_map = {"model_id": self.initial_language_model} # parse arguments if cell: @@ -215,28 +213,26 @@ def ai(self, line: str, cell: Optional[str] = None) -> Any: raw_args, prog_name=r"%ai", standalone_mode=False, - default_map={"error": default_map}, + default_map={"fix": default_map}, ) - if args == 0 and self.default_language_model is None: + if args == 0 and self.initial_language_model is None: # this happens when `--help` is called on the root command, in which # case we want to exit early. return # If a value error occurs, don't print the full stacktrace try: - if args.type == "error": - return self.handle_error(args) + if args.type == "fix": + return self.handle_fix(args) if args.type == "help": return self.handle_help(args) if args.type == "list": return self.handle_list(args) - if args.type == "register": - return self.handle_register(args) - if args.type == "delete": - return self.handle_delete(args) - if args.type == "update": - return self.handle_update(args) + if args.type == "alias": + return self.handle_alias(args) + if args.type == "dealias": + return self.handle_dealias(args) if args.type == "version": return self.handle_version(args) if args.type == "reset": @@ -264,42 +260,57 @@ def run_ai_cell(self, args: CellArgs, prompt: str): Handles the `%%ai` cell magic. This is the main method that invokes the language model. """ - # Apply a prompt template. - # The LLM needs to be given instructions based on `args.format`. See - # `old_prompt_templates.txt`. - # - # We may want to drop some of these formats for simplicity, since they - # don't all seem that helpful. - # - # TODO - prompt = prompt - # Interpolate local variables into prompt. # For example, if a user runs `a = "hello"` and then runs `%%ai {a}`, it # should be equivalent to running `%%ai hello`. ip = self.shell prompt = prompt.format_map(FormatDict(ip.user_ns)) - # TODO: generate the output using LiteLLM - # include `self.transcript` for conversation history - user_message = { - "role": "user", - "content": prompt - } - # ... call litellm.acompletion(, [*self.transcript, user_message]) - # store response as a string in `output` - output: str = "TODO" + # Prepare messages for the model + messages = [] + + # Add conversation history if available + if self.transcript: + messages.extend(self.transcript[-2 * self.max_history :]) - # append exchange to transcript - self._append_exchange(prompt, output) + # Add current prompt + messages.append({"role": "user", "content": prompt}) - # TODO: set model ID in metadata - metadata = {"jupyter_ai_v3": {"model_id": "TODO"}} + # Resolve model_id: check if it's in CHAT_MODELS or an alias + model_id = args.model_id + if model_id not in CHAT_MODELS: + # Check if it's an alias + if model_id in self.aliases: + model_id = self.aliases[model_id] + else: + raise ValueError( + f"Model ID '{model_id}' is not a known model or alias. " + "Run '%ai list' to see available models and aliases." 
+ ) + try: + # Call litellm completion + response = litellm.completion( + model=model_id, messages=messages, stream=False + ) - # Return output given the format - return self.display_output(output, args.format, metadata) + # Extract output text from response + output = response.choices[0].message.content - def display_output(self, output, display_format, metadata: Dict[str, Any]) -> Any: + # Append exchange to transcript + self._append_exchange(prompt, output) + + # Set model ID in metadata + metadata = {"jupyter_ai_v3": {"model_id": args.model_id}} + + # Return output given the format + return self.display_output(output, args.format, metadata) + + except Exception as e: + error_msg = f"Error calling language model: {str(e)}" + print(error_msg, file=sys.stderr) + return error_msg + + def display_output(self, output, display_format, metadata: dict[str, Any]) -> Any: """ Returns an IPython 'display object' that determines how an output is rendered. This is complex, so here are some notes: @@ -315,7 +326,7 @@ def display_output(self, output, display_format, metadata: Dict[str, Any]) -> An Markdown when viewed from a web browser. - See `DISPLAYS_BY_FORMAT` for the list of display objects that can be - returned by `jupyter_ai_magics`. + returned by `jupyter_ai_magics`. TODO: Use a string enum to store the list of valid formats. @@ -352,17 +363,13 @@ def _append_exchange(self, prompt: str, output: str): Appends an exchange between a user and a language model to `self.transcript`. This transcript will be included in future `%ai` calls to preserve conversation history. - - TODO: bound this list to length `self.max_history * 2`. """ - self.transcript.append({ - "role": "user", - "content": prompt - }) - self.transcript.append({ - "role": "assistant", - "content": output - }) + self.transcript.append({"role": "user", "content": prompt}) + self.transcript.append({"role": "assistant", "content": output}) + # Keep only the most recent `self.max_history * 2` messages + max_len = self.max_history * 2 + if len(self.transcript) > max_len: + self.transcript = self.transcript[-max_len:] def handle_help(self, _: HelpArgs) -> None: """ @@ -372,12 +379,9 @@ def handle_help(self, _: HelpArgs) -> None: with click.Context(line_magic_parser, info_name=r"%ai") as ctx: click.echo(line_magic_parser.get_help(ctx)) - - def handle_delete(self, args: DeleteArgs) -> TextOrMarkdown: + def handle_dealias(self, args: DeleteArgs) -> TextOrMarkdown: """ - Handles `%ai delete`. Deletes a model alias. - - TODO: rename the command to `%ai dealias`? + Handles `%ai dealias`. Deletes a model alias. """ if args.name in AI_COMMANDS: @@ -385,10 +389,10 @@ def handle_delete(self, args: DeleteArgs) -> TextOrMarkdown: f"Reserved command names, including {args.name}, cannot be deleted" ) - if args.name not in self.custom_model_registry: + if args.name not in self.aliases: raise ValueError(f"There is no alias called {args.name}") - del self.custom_model_registry[args.name] + del self.aliases[args.name] output = f"Deleted alias `{args.name}`" return TextOrMarkdown(output, output) @@ -398,13 +402,11 @@ def handle_reset(self, args: ResetArgs) -> None: """ self.transcript = [] - def handle_error(self, args: ErrorArgs) -> Any: + def handle_fix(self, args: FixArgs) -> Any: """ - Handles `%ai error`. Meant to provide fixes for any exceptions raised in + Handles `%ai fix`. Meant to provide fixes for any exceptions raised in the kernel while running cells. - TODO: rename this to `%ai fix`? 
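The completion path added in `run_ai_cell()` and `_append_exchange()` — resolve aliases, prepend bounded history, call LiteLLM, record the exchange — can be read in isolation as roughly the following sketch. The helper name, alias table, and model ID are illustrative, not part of the module:

```python
# Rough, standalone sketch of the flow in `run_ai_cell()` / `_append_exchange()`.
import litellm

aliases = {"claude": "anthropic/claude-3-haiku-20240307"}  # illustrative target
transcript: list[dict] = []
MAX_HISTORY = 2  # number of past exchanges to keep and send


def complete(model_id: str, prompt: str) -> str:
    # Resolve an alias to its underlying model ID, if one is registered.
    model_id = aliases.get(model_id, model_id)

    # Bounded conversation history plus the new user message.
    messages = [*transcript[-2 * MAX_HISTORY :], {"role": "user", "content": prompt}]

    # LiteLLM returns an OpenAI-style response object.
    response = litellm.completion(model=model_id, messages=messages, stream=False)
    output = response.choices[0].message.content

    # Record the exchange, then trim the transcript to the history bound.
    transcript.append({"role": "user", "content": prompt})
    transcript.append({"role": "assistant", "content": output})
    if len(transcript) > 2 * MAX_HISTORY:
        del transcript[: len(transcript) - 2 * MAX_HISTORY]
    return output
```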
- TODO: annotate a valid return type when we find a type that is shared by all display objects. """ @@ -428,33 +430,28 @@ def handle_error(self, args: ErrorArgs) -> Any: if last_error is None: return TextOrMarkdown(no_errors_message, no_errors_message) - prompt = f"Explain the following error:\n\n{last_error}" - # Set CellArgs based on ErrorArgs + prompt = f"Explain the following error and propose a fix:\n\n{last_error}" + # Set CellArgs based on FixArgs values = args.model_dump() values["type"] = "root" cell_args = CellArgs(**values) + print("I will attempt to explain and fix the error. ") return self.run_ai_cell(cell_args, prompt) - def handle_register(self, args: RegisterArgs) -> TextOrMarkdown: + def handle_alias(self, args: RegisterArgs) -> TextOrMarkdown: """ - Handles `%ai register`. Adds an alias for a model ID for future calls. - - TODO: Use LiteLLM to manage aliases. See - https://docs.litellm.ai/docs/completion/model_alias - - TODO: rename this to `%ai alias`? + Handles `%ai alias`. Adds an alias for a model ID for future calls. """ - pass + # Existing command names are not allowed + if args.name in AI_COMMANDS: + raise ValueError(f"The name {args.name} is reserved for a command") - def handle_update(self, args: UpdateArgs) -> TextOrMarkdown: - """ - Handles `%ai update`. Updates a model alias. + # Store the alias + self.aliases[args.name] = args.target - TODO: remove this command. Users can just delete a model alias and add a - new one. - """ - pass + output = f"Registered new alias `{args.name}`" + return TextOrMarkdown(output, output) def handle_version(self, args: VersionArgs) -> str: """ @@ -466,9 +463,24 @@ def handle_version(self, args: VersionArgs) -> str: def handle_list(self, args: ListArgs): """ Handles `%ai list`. Lists all LiteLLM models. - - The old implementation has been deleted because it was far too complex. - - TODO """ - pass + # Get list of available models from litellm + models = CHAT_MODELS + + # Format output for both text and markdown + text_output = "Available models:\n\n" + markdown_output = "## Available models\n\n" + + for model in models: + text_output += f"* {model}\n" + markdown_output += f"* `{model}`\n" + + # Also list any custom aliases + if len(self.aliases) > 0: + text_output += "\nAliases:\n" + markdown_output += "\n### Aliases\n\n" + for alias, target in self.aliases.items(): + text_output += f"* {alias} -> {target}\n" + markdown_output += f"* `{alias}` -> `{target}`\n" + + return TextOrMarkdown(text_output, markdown_output) diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/old_help_strings.txt b/packages/jupyter-ai-magics/jupyter_ai_magics/old_help_strings.txt index 7ad3ec82a..795a49a67 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/old_help_strings.txt +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/old_help_strings.txt @@ -1,4 +1,4 @@ -Bedrock: +Bedrock: "- For Cross-Region Inference use the appropriate `Inference profile ID` (Model ID with a region prefix, e.g., `us.meta.llama3-2-11b-instruct-v1:0`). See the [inference profiles documentation](https://docs.aws.amazon.com/bedrock/latest/userguide/inference-profiles-support.html). \n" "- For custom/provisioned models, specify the model ARN (Amazon Resource Name) as the model ID. 
For more information, see the [Amazon Bedrock model IDs documentation](https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html).\n\n" diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/old_prompt_templates.txt b/packages/jupyter-ai-magics/jupyter_ai_magics/old_prompt_templates.txt index cfb9ff949..8ed599766 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/old_prompt_templates.txt +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/old_prompt_templates.txt @@ -29,4 +29,3 @@ model_kwargs["prompt_templates"] = { ), "text": PromptTemplate.from_template("{prompt}"), # No customization } - diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/parsers.py b/packages/jupyter-ai-magics/jupyter_ai_magics/parsers.py index de99fc8bd..9f8d6254a 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/parsers.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/parsers.py @@ -54,8 +54,8 @@ class CellArgs(BaseModel): # Should match CellArgs -class ErrorArgs(BaseModel): - type: Literal["error"] = "error" +class FixArgs(BaseModel): + type: Literal["fix"] = "fix" model_id: str format: FORMAT_CHOICES_TYPE model_parameters: Optional[str] = None @@ -79,13 +79,13 @@ class ListArgs(BaseModel): class RegisterArgs(BaseModel): - type: Literal["register"] = "register" + type: Literal["alias"] = "alias" name: str target: str class DeleteArgs(BaseModel): - type: Literal["delete"] = "delete" + type: Literal["dealias"] = "dealias" name: str @@ -182,7 +182,7 @@ def line_magic_parser(): """ -@line_magic_parser.command(name="error") +@line_magic_parser.command(name="fix") @click.argument("model_id", required=False) @click.option( "-f", @@ -219,14 +219,14 @@ def line_magic_parser(): default="{}", ) @click.pass_context -def error_subparser(context: click.Context, **kwargs): +def fix_subparser(context: click.Context, **kwargs): """ - Explains the most recent error. Takes the same options (except -r) as + Explains and fixes the most recent error. Takes the same options (except -r) as the basic `%%ai` command. """ if not kwargs["model_id"] and context.default_map: - kwargs["model_id"] = context.default_map["error_subparser"]["model_id"] - return ErrorArgs(**kwargs) + kwargs["model_id"] = context.default_map["fix_subparser"]["model_id"] + return FixArgs(**kwargs) @line_magic_parser.command(name="version") @@ -253,8 +253,8 @@ def list_subparser(**kwargs): @line_magic_parser.command( - name="register", - short_help="Register a new alias. See `%ai register --help` for options.", + name="alias", + short_help="Register a new alias. See `%ai alias --help` for options.", ) @click.argument("name") @click.argument("target") @@ -264,7 +264,7 @@ def register_subparser(**kwargs): @line_magic_parser.command( - name="delete", short_help="Delete an alias. See `%ai delete --help` for options." + name="dealias", short_help="Delete an alias. See `%ai dealias --help` for options." ) @click.argument("name") def register_subparser(**kwargs): @@ -272,10 +272,6 @@ def register_subparser(**kwargs): return DeleteArgs(**kwargs) -@line_magic_parser.command( - name="update", - short_help="Update the target of an alias. 
See `%ai update --help` for options.", -) @click.argument("name") @click.argument("target") def register_subparser(**kwargs): diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/tests/test_magics.py b/packages/jupyter-ai-magics/jupyter_ai_magics/tests/test_magics.py index 7d27e3ffd..a393029d6 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/tests/test_magics.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/tests/test_magics.py @@ -14,14 +14,14 @@ def ip() -> InteractiveShell: def test_aliases_config(ip): - ip.config.AiMagics.aliases = {"my_custom_alias": "my_provider:my_model"} + ip.config.AiMagics.initial_aliases = {"my_custom_alias": "my_provider:my_model"} ip.extension_manager.load_extension("jupyter_ai_magics") providers_list = ip.run_line_magic("ai", "list").text assert "my_custom_alias" in providers_list def test_default_model_cell(ip): - ip.config.AiMagics.default_language_model = "my-favourite-llm" + ip.config.AiMagics.initial_language_model = "my-favourite-llm" ip.extension_manager.load_extension("jupyter_ai_magics") with patch.object(AiMagics, "run_ai_cell", return_value=None) as mock_run: ip.run_cell_magic("ai", "", cell="Write code for me please") @@ -31,7 +31,7 @@ def test_default_model_cell(ip): def test_non_default_model_cell(ip): - ip.config.AiMagics.default_language_model = "my-favourite-llm" + ip.config.AiMagics.initial_language_model = "my-favourite-llm" ip.extension_manager.load_extension("jupyter_ai_magics") with patch.object(AiMagics, "run_ai_cell", return_value=None) as mock_run: ip.run_cell_magic("ai", "some-different-llm", cell="Write code for me please") @@ -41,10 +41,10 @@ def test_non_default_model_cell(ip): def test_default_model_error_line(ip): - ip.config.AiMagics.default_language_model = "my-favourite-llm" + ip.config.AiMagics.initial_language_model = "my-favourite-llm" ip.extension_manager.load_extension("jupyter_ai_magics") - with patch.object(AiMagics, "handle_error", return_value=None) as mock_run: - ip.run_cell_magic("ai", "error", cell=None) + with patch.object(AiMagics, "handle_fix", return_value=None) as mock_run: + ip.run_cell_magic("ai", "fix", cell=None) assert mock_run.called cell_args = mock_run.call_args.args[0] assert cell_args.model_id == "my-favourite-llm" diff --git a/packages/jupyter-ai/jupyter_ai/__init__.py b/packages/jupyter-ai/jupyter_ai/__init__.py index 973456b42..830a04627 100644 --- a/packages/jupyter-ai/jupyter_ai/__init__.py +++ b/packages/jupyter-ai/jupyter_ai/__init__.py @@ -5,8 +5,10 @@ # expose jupyter_ai_magics ipython extension # DO NOT REMOVE. 
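The reworked `parsers.py` keeps one click subcommand per `%ai` action, with each callback returning a Pydantic args object and the caller invoking the group with `standalone_mode=False` so the parsed object is returned instead of click exiting the process. A minimal, self-contained sketch of that pattern (the names `cli` and `AliasArgs` here are illustrative, not the module's real symbols):

```python
from typing import Literal

import click
from pydantic import BaseModel


class AliasArgs(BaseModel):
    type: Literal["alias"] = "alias"
    name: str
    target: str


@click.group()
def cli():
    """Toy stand-in for the `%ai` line magic parser."""


@cli.command(name="alias", short_help="Register a new alias.")
@click.argument("name")
@click.argument("target")
def alias_subparser(**kwargs):
    # click passes NAME and TARGET as keyword arguments matching the argument names.
    return AliasArgs(**kwargs)


if __name__ == "__main__":
    args = cli(["alias", "claude", "anthropic:claude-v1.2"], standalone_mode=False)
    print(args)  # type='alias' name='claude' target='anthropic:claude-v1.2'
```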
-from jupyter_ai_magics import load_ipython_extension, unload_ipython_extension - +from jupyter_ai_magics import ( # type: ignore[import-untyped] + load_ipython_extension, + unload_ipython_extension, +) from ._version import __version__ from .extension import AiExtension diff --git a/packages/jupyter-ai/jupyter_ai/completions/completion_utils.py b/packages/jupyter-ai/jupyter_ai/completions/completion_utils.py index e9caf5f4f..150ac40c0 100644 --- a/packages/jupyter-ai/jupyter_ai/completions/completion_utils.py +++ b/packages/jupyter-ai/jupyter_ai/completions/completion_utils.py @@ -49,4 +49,3 @@ def post_process_suggestion(suggestion: str, request: InlineCompletionRequest) - suggestion = suggestion.rstrip()[:-3].rstrip() return suggestion - diff --git a/packages/jupyter-ai/jupyter_ai/completions/handlers/base.py b/packages/jupyter-ai/jupyter_ai/completions/handlers/base.py index 6bf47503b..fd83f43d8 100644 --- a/packages/jupyter-ai/jupyter_ai/completions/handlers/base.py +++ b/packages/jupyter-ai/jupyter_ai/completions/handlers/base.py @@ -5,6 +5,9 @@ from typing import Union import tornado +from jupyter_server.base.handlers import JupyterHandler +from pydantic import ValidationError + from ..completion_types import ( CompletionError, InlineCompletionList, @@ -12,13 +15,9 @@ InlineCompletionRequest, InlineCompletionStreamChunk, ) -from jupyter_server.base.handlers import JupyterHandler -from pydantic import ValidationError -class BaseInlineCompletionHandler( - JupyterHandler, tornado.websocket.WebSocketHandler -): +class BaseInlineCompletionHandler(JupyterHandler, tornado.websocket.WebSocketHandler): """A Tornado WebSocket handler that receives inline completion requests and fulfills them accordingly. This class is instantiated once per WebSocket connection.""" diff --git a/packages/jupyter-ai/jupyter_ai/completions/handlers/default.py b/packages/jupyter-ai/jupyter_ai/completions/handlers/default.py index 1967186a6..93c496a8e 100644 --- a/packages/jupyter-ai/jupyter_ai/completions/handlers/default.py +++ b/packages/jupyter-ai/jupyter_ai/completions/handlers/default.py @@ -15,7 +15,6 @@ async def handle_request(self, request: InlineCompletionRequest): # reply = await llm.generate_inline_completions(request) # self.reply(reply) - pass async def handle_stream_request(self, request: InlineCompletionRequest): # TODO: migrate this to use LiteLLM @@ -27,6 +26,7 @@ async def handle_stream_request(self, request: InlineCompletionRequest): # self.reply(reply) pass + # old methods on BaseProvider, for reference when migrating this to LiteLLM # # async def generate_inline_completions( diff --git a/packages/jupyter-ai/jupyter_ai/config_manager.py b/packages/jupyter-ai/jupyter_ai/config_manager.py index 652034d7a..05bb1efe0 100644 --- a/packages/jupyter-ai/jupyter_ai/config_manager.py +++ b/packages/jupyter-ai/jupyter_ai/config_manager.py @@ -108,6 +108,9 @@ def __init__( self._allowed_models = allowed_models self._blocked_models = blocked_models + self._lm_providers: dict[str, Any] = ( + {} + ) # Placeholder: should be set to actual language model providers self._defaults = remove_none_entries(defaults) self._last_read: Optional[int] = None @@ -355,6 +358,19 @@ def chat_model(self) -> str | None: def chat_model_params(self) -> dict[str, Any]: return self._provider_params("model_provider_id", self._lm_providers) + def _provider_params( + self, provider_id_attr: str, providers: dict + ) -> dict[str, Any]: + """ + Returns the parameters for the provider specified by the given attribute. 
+ """ + config = self._read_config() + provider_id = getattr(config, provider_id_attr, None) + if not provider_id or provider_id not in providers: + return {} + return providers[provider_id].get("params", {}) + return self._provider_params("model_provider_id", self._lm_providers) + @property def embedding_model(self) -> str | None: """ @@ -367,7 +383,7 @@ def embedding_model(self) -> str | None: def embedding_model_params(self) -> dict[str, Any]: # TODO return {} - + @property def completion_model(self) -> str | None: """ diff --git a/packages/jupyter-ai/jupyter_ai/extension.py b/packages/jupyter-ai/jupyter_ai/extension.py index 0657e39f1..29afdd192 100644 --- a/packages/jupyter-ai/jupyter_ai/extension.py +++ b/packages/jupyter-ai/jupyter_ai/extension.py @@ -18,8 +18,6 @@ from traitlets import Integer, List, Type, Unicode from traitlets.config import Config -from .secrets.secrets_rest_api import SecretsRestAPI -from .secrets.secrets_manager import EnvSecretsManager from .completions.handlers import DefaultInlineCompletionHandler from .config_manager import ConfigManager from .handlers import ( @@ -27,6 +25,8 @@ InterruptStreamingHandler, ) from .personas import PersonaManager +from .secrets.secrets_manager import EnvSecretsManager +from .secrets.secrets_rest_api import SecretsRestAPI if TYPE_CHECKING: from asyncio import AbstractEventLoop @@ -51,7 +51,6 @@ JUPYTER_COLLABORATION_EVENTS_URI, ) - from .model_providers.model_handlers import ChatModelEndpoint @@ -141,7 +140,17 @@ class AiExtension(ExtensionApp): config=True, ) - default_language_model = Unicode( + initial_chat_model = Unicode( + default_value=None, + allow_none=True, + help=""" + Default language model to use, as string in the format + :, defaults to None. + """, + config=True, + ) + + initial_language_model = Unicode( default_value=None, allow_none=True, help=""" @@ -302,7 +311,7 @@ def initialize_settings(self): self.log.info(f"Configured model blocklist: {self.blocked_models}") self.log.info(f"Configured model parameters: {self.model_parameters}") defaults = { - "model_provider_id": self.default_language_model, + "model_provider_id": self.initial_language_model, "embeddings_provider_id": self.default_embeddings_model, "completions_model_provider_id": self.default_completions_model, "api_keys": self.default_api_keys, @@ -325,7 +334,7 @@ def initialize_settings(self): # Initialize SecretsManager self.settings["jai_secrets_manager"] = EnvSecretsManager(parent=self) - # Bind event loop to settings dictionary + # Bind event loop to settings dictionary self.settings["jai_event_loop"] = self.event_loop # Bind dictionary of interrupts to settings dictionary. 
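The renamed `AiExtension.initial_language_model` trait can also be set from a config file rather than the CLI flag shown in the docs. A possible `jupyter_server_config.py` fragment, reusing the same example model ID from the documentation above:

```python
# Hypothetical jupyter_server_config.py fragment; equivalent to
# `jupyter lab --AiExtension.initial_language_model=bedrock-chat:anthropic.claude-v2`.
c = get_config()  # noqa: F821 -- injected by Jupyter when the config file is loaded

c.AiExtension.initial_language_model = "bedrock-chat:anthropic.claude-v2"
```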
diff --git a/packages/jupyter-ai/jupyter_ai/models.py b/packages/jupyter-ai/jupyter_ai/models.py index 262e79ffe..ff9dae735 100644 --- a/packages/jupyter-ai/jupyter_ai/models.py +++ b/packages/jupyter-ai/jupyter_ai/models.py @@ -23,4 +23,3 @@ class ListProvidersEntry(BaseModel): # fields: list[Field] chat_models: Optional[list[str]] = None completion_models: Optional[list[str]] = None - diff --git a/packages/jupyter-ai/jupyter_ai/secrets/secrets_manager.py b/packages/jupyter-ai/jupyter_ai/secrets/secrets_manager.py index 461713132..782dd4533 100644 --- a/packages/jupyter-ai/jupyter_ai/secrets/secrets_manager.py +++ b/packages/jupyter-ai/jupyter_ai/secrets/secrets_manager.py @@ -1,22 +1,27 @@ from __future__ import annotations -from traitlets.config import LoggingConfigurable -from typing import TYPE_CHECKING -from dotenv import load_dotenv, dotenv_values -from io import StringIO + import asyncio -from tornado.web import HTTPError import os from datetime import datetime +from io import StringIO +from typing import TYPE_CHECKING + +from dotenv import dotenv_values, load_dotenv +from tornado.web import HTTPError +from traitlets.config import LoggingConfigurable -from .secrets_utils import build_updated_dotenv from .secrets_types import SecretsList +from .secrets_utils import build_updated_dotenv if TYPE_CHECKING: - from typing import Any, Optional import logging - from ..extension import AiExtension + from typing import Any + from jupyter_server.services.contents.filemanager import AsyncFileContentsManager + from ..extension import AiExtension + + class EnvSecretsManager(LoggingConfigurable): """ The default secrets manager implementation. @@ -50,7 +55,7 @@ class EnvSecretsManager(LoggingConfigurable): parent class. This annotation exists only to help type checkers like `mypy`. """ - _last_modified: Optional[datetime] + _last_modified: datetime | None """ The 'last modified' timestamp on the '.env' file retrieved in the previous tick of the `_watch_dotenv()` background task. @@ -81,7 +86,7 @@ class EnvSecretsManager(LoggingConfigurable): @property def contents_manager(self) -> AsyncFileContentsManager: return self.parent.serverapp.contents_manager - + @property def event_loop(self) -> asyncio.AbstractEventLoop: return self.parent.event_loop @@ -98,7 +103,6 @@ def __init__(self, *args, **kwargs): # Start `_watch_dotenv()` task to automatically update the environment # variables when `.env` is modified self._watch_dotenv_task = self.event_loop.create_task(self._watch_dotenv()) - async def _watch_dotenv(self) -> None: """ @@ -118,13 +122,13 @@ async def _watch_dotenv(self) -> None: if e.status_code == 404: self._handle_dotenv_notfound() continue - except Exception as e: + except Exception: self.log.exception("Unknown exception in `_watch_dotenv()`:") continue # Continue if the `.env` file was already processed and its content # is unchanged. - if self._last_modified == dotenv_file['last_modified']: + if self._last_modified == dotenv_file["last_modified"]: continue # When this line is reached, the .env file needs to be applied. @@ -132,18 +136,21 @@ async def _watch_dotenv(self) -> None: # in `.gitignore`, and store the latest last modified timestamp. if self._last_modified: # Statement when .env file was modified: - self.log.info("Detected changes to the '.env' file. Re-applying '.env' to the environment...") + self.log.info( + "Detected changes to the '.env' file. Re-applying '.env' to the environment..." 
+ ) else: # Statement when the .env file was just created, or when this is # the first iteration and a .env file already exists: - self.log.info("Detected '.env' file at the workspace root. Applying '.env' to the environment...") + self.log.info( + "Detected '.env' file at the workspace root. Applying '.env' to the environment..." + ) self.event_loop.create_task(self._ensure_dotenv_gitignored()) - self._last_modified = dotenv_file['last_modified'] + self._last_modified = dotenv_file["last_modified"] # Apply the latest `.env` file to the environment. # See `self._apply_dotenv()` for more info. self._apply_dotenv(dotenv_content) - def _apply_dotenv(self, content: str) -> None: """ @@ -160,9 +167,7 @@ def _apply_dotenv(self, content: str) -> None: # Parse the latest `.env` file and store it in `self._dotenv_env`, # tracking deleted environment variables in `deleted_envvars`. new_dotenv_env = dotenv_values(stream=StringIO(content)) - new_dotenv_env = { - k: v for k, v in new_dotenv_env.items() if v != None - } + new_dotenv_env = {k: v for k, v in new_dotenv_env.items() if v != None} deleted_envvars = [k for k in self._dotenv_env if k not in new_dotenv_env] self._dotenv_env = new_dotenv_env @@ -176,8 +181,6 @@ def _apply_dotenv(self, content: str) -> None: load_dotenv(stream=StringIO(content), override=True) self.log.info("Applied '.env' to the environment.") - - async def _ensure_dotenv_gitignored(self) -> bool: """ Ensures the `.env` file is listed in the `.gitignore` file at the @@ -198,10 +201,9 @@ async def _ensure_dotenv_gitignored(self) -> bool: pass else: raise e - except Exception as e: + except Exception: self.log.exception("Unknown exception raised when fetching `.gitignore`:") - pass - + # Return early if the `.gitignore` file exists and already lists `.env`. old_content: str = (gitignore_file or {}).get("content", "") if ".env\n" in old_content: @@ -213,19 +215,18 @@ async def _ensure_dotenv_gitignored(self) -> bool: new_lines = "# Ignore secrets in '.env'\n.env\n" new_content = old_content + "\n" + new_lines if old_content else new_lines try: - gitignore_file = await self.contents_manager.save({ + gitignore_file = await self.contents_manager.save( + { "type": "file", "format": "text", "mimetype": "text/plain", - "content": new_content + "content": new_content, }, - ".gitignore" + ".gitignore", ) - except Exception as e: + except Exception: self.log.exception("Unknown exception raised when updating `.gitignore`:") - pass self.log.info("Updated `.gitignore` file to include `.env`.") - def _reset_envvars(self, names: list[str]) -> None: """ @@ -238,7 +239,6 @@ def _reset_envvars(self, names: list[str]) -> None: os.environ[ev_name] = self._initial_env.get(ev_name) else: del os.environ[ev_name] - def _handle_dotenv_notfound(self) -> None: """ @@ -251,7 +251,6 @@ def _handle_dotenv_notfound(self) -> None: self._reset_envvars(list(self._dotenv_env.keys())) self._dotenv_env = {} - def list_secrets(self) -> SecretsList: """ Lists the names of each environment variable from the workspace `.env` @@ -270,26 +269,25 @@ def list_secrets(self) -> SecretsList: for name in self._initial_env.keys(): if "KEY" in name or "TOKEN" in name or "SECRET" in name: process_secrets_names.add(name) - + # Add secrets from .env, if any for name in self._dotenv_env: dotenv_secrets_names.add(name) - + # Remove `TIKTOKEN_CACHE_DIR`, which is set in the initial environment # by some other package and is not a secret. # This gets included otherwise since it contains 'TOKEN' in its name. 
process_secrets_names.discard("TIKTOKEN_CACHE_DIR") - + return SecretsList( editable_secrets=sorted(list(dotenv_secrets_names)), - static_secrets=sorted(list(process_secrets_names)) + static_secrets=sorted(list(process_secrets_names)), ) - async def update_secrets( - self, - updated_secrets: dict[str, str | None], - ) -> None: + self, + updated_secrets: dict[str, str | None], + ) -> None: """ Accepts a dictionary of secrets to update, adds/updates/deletes them from `.env` accordingly, and applies the updated `.env` file to the @@ -326,9 +324,11 @@ async def update_secrets( pass else: raise e - except Exception as e: - self.log.exception("Unknown exception raised when reading `.env` in response to an update:") - + except Exception: + self.log.exception( + "Unknown exception raised when reading `.env` in response to an update:" + ) + # Build the new `.env` file using these variables. # See `build_updated_dotenv()` for more info on how this is done. new_dotenv_content = build_updated_dotenv(dotenv_content, updated_secrets) @@ -336,18 +336,21 @@ async def update_secrets( # Return early if no changes are needed in `.env`. if new_dotenv_content is None: return - + # Save new content try: - dotenv_file = await self.contents_manager.save({ - "type": "file", - "format": "text", - "mimetype": "text/plain", - "content": new_dotenv_content - }, ".env") - last_modified = dotenv_file.get('last_modified') + dotenv_file = await self.contents_manager.save( + { + "type": "file", + "format": "text", + "mimetype": "text/plain", + "content": new_dotenv_content, + }, + ".env", + ) + last_modified = dotenv_file.get("last_modified") assert isinstance(last_modified, datetime) - except Exception as e: + except Exception: self.log.exception("Unknown exception raised when updating `.env`:") # If this is a new file, ensure the `.env` file is listed in `.gitignore`. @@ -360,17 +363,14 @@ async def update_secrets( # This automatically sets `self._dotenv_env`. self._apply_dotenv(new_dotenv_content) self.log.info("Updated secrets in `.env`.") - - def get_secret(self, secret_name: str) -> Optional[str]: + def get_secret(self, secret_name: str) -> str | None: """ Returns the value of a secret given its name. The returned secret must NEVER be shared with frontend clients! """ # TODO - pass - def stop(self) -> None: """ Stops this instance and any background tasks spawned by this instance. diff --git a/packages/jupyter-ai/jupyter_ai/secrets/secrets_rest_api.py b/packages/jupyter-ai/jupyter_ai/secrets/secrets_rest_api.py index bfa082545..c2492f315 100644 --- a/packages/jupyter-ai/jupyter_ai/secrets/secrets_rest_api.py +++ b/packages/jupyter-ai/jupyter_ai/secrets/secrets_rest_api.py @@ -1,12 +1,16 @@ from __future__ import annotations -from jupyter_server.base.handlers import APIHandler as BaseAPIHandler -from tornado.web import authenticated, HTTPError + from typing import TYPE_CHECKING + +from jupyter_server.base.handlers import APIHandler as BaseAPIHandler +from tornado.web import HTTPError, authenticated + from .secrets_types import UpdateSecretsRequest if TYPE_CHECKING: from .secrets_manager import EnvSecretsManager + class SecretsRestAPI(BaseAPIHandler): """ Defines the REST API served at the `/api/ai/secrets` endpoint. 
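The `.env` handling above leans on two python-dotenv entry points: `dotenv_values(stream=...)` to parse without side effects, and `load_dotenv(stream=..., override=True)` to apply values to the process environment; keys declared without a value come back as `None`, which the manager filters out. A minimal standalone sketch of that pattern (the variable name and value are illustrative):

```python
# Minimal sketch of the parse-then-apply pattern used by `_apply_dotenv()`.
import os
from io import StringIO

from dotenv import dotenv_values, load_dotenv

content = 'MY_PROVIDER_API_KEY="sk-illustrative"\n# comments and blank lines are allowed\n'

# Parse without touching os.environ; drop keys that have no value.
parsed = {k: v for k, v in dotenv_values(stream=StringIO(content)).items() if v is not None}
assert parsed == {"MY_PROVIDER_API_KEY": "sk-illustrative"}

# Apply to the current process, overriding any existing value.
load_dotenv(stream=StringIO(content), override=True)
assert os.environ["MY_PROVIDER_API_KEY"] == "sk-illustrative"
```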
diff --git a/packages/jupyter-ai/jupyter_ai/secrets/secrets_types.py b/packages/jupyter-ai/jupyter_ai/secrets/secrets_types.py index b0981286a..8959ee8dc 100644 --- a/packages/jupyter-ai/jupyter_ai/secrets/secrets_types.py +++ b/packages/jupyter-ai/jupyter_ai/secrets/secrets_types.py @@ -1,6 +1,7 @@ -from pydantic import BaseModel from typing import Optional +from pydantic import BaseModel + class SecretsList(BaseModel): """ @@ -30,4 +31,5 @@ class UpdateSecretsRequest(BaseModel): """ The request body expected by `PUT /api/ai/secrets`. """ + updated_secrets: dict[str, Optional[str]] diff --git a/packages/jupyter-ai/jupyter_ai/secrets/secrets_utils.py b/packages/jupyter-ai/jupyter_ai/secrets/secrets_utils.py index c01a6d897..f4887e708 100644 --- a/packages/jupyter-ai/jupyter_ai/secrets/secrets_utils.py +++ b/packages/jupyter-ai/jupyter_ai/secrets/secrets_utils.py @@ -1,11 +1,12 @@ from __future__ import annotations -from dotenv import dotenv_values -from dotenv.parser import parse_stream + from io import StringIO from typing import TYPE_CHECKING +from dotenv import dotenv_values +from dotenv.parser import parse_stream + if TYPE_CHECKING: - from typing import Optional import logging ENVIRONMENT_VAR_REGEX = "^([a-zA-Z_][a-zA-Z_0-9]*)=?(['\"])?" @@ -13,10 +14,11 @@ Regex that matches a environment variable definition. """ + def build_updated_dotenv( dotenv_content: str, updated_secrets: dict[str, str | None], - log: Optional[logging.Logger] = None + log: logging.Logger | None = None, ) -> str | None: """ Accepts the existing `.env` file as a parsed dictionary of environment @@ -59,11 +61,8 @@ def build_updated_dotenv( else: # Case 4: keys can only be added when a `.env` file is not # present. - secrets_to_add = { - k: v for k, v in updated_secrets.items() - if v is not None - } - + secrets_to_add = {k: v for k, v in updated_secrets.items() if v is not None} + # Return early if update has effect. if not (secrets_to_add or secrets_to_update or secrets_to_remove): return None @@ -89,7 +88,7 @@ def build_updated_dotenv( # 1. The `parse_stream()` function returns an Iterator that yields 'Binding' # objects that represent 'parsed chunks' of a `.env` file. Each chunk may # contain: - # + # # - An environment variable definition (`Binding.key is not None`), # - An invalid line (`Binding.error == True`), # - A standalone comment (if neither condition applies). @@ -113,7 +112,9 @@ def build_updated_dotenv( if binding.key in secrets_to_update: name = binding.key # extra logic to preserve formatting as best as we can - whitespace_before, whitespace_after = get_whitespace_around(binding.original.string) + whitespace_before, whitespace_after = get_whitespace_around( + binding.original.string + ) value = secrets_to_update[name] new_content += whitespace_before new_content += f'{name}="{value}"' @@ -124,9 +125,9 @@ def build_updated_dotenv( if secrets_to_add: # Ensure new secrets get put at least 2 lines below the rest - if not new_content.endswith('\n'): + if not new_content.endswith("\n"): new_content += "\n\n" - elif not new_content.endswith('\n\n'): + elif not new_content.endswith("\n\n"): new_content += "\n" max_i = len(secrets_to_add) - 1 @@ -141,17 +142,17 @@ def build_updated_dotenv( def get_whitespace_around(text: str) -> tuple[str, str]: """ Extract whitespace prefix and suffix from a string. 
- + Args: text: The input string - + Returns: A tuple of (prefix, suffix) where prefix is the leading whitespace and suffix is the trailing whitespace """ if not text: return ("", "") - + # Find prefix (leading whitespace) prefix_end = 0 for i, char in enumerate(text): @@ -161,15 +162,15 @@ def get_whitespace_around(text: str) -> tuple[str, str]: else: # String is all whitespace return (text, "") - + # Find suffix (trailing whitespace) suffix_start = len(text) for i in range(len(text) - 1, -1, -1): if not text[i].isspace(): suffix_start = i + 1 break - + prefix = text[:prefix_end] suffix = text[suffix_start:] - + return (prefix, suffix) diff --git a/packages/jupyter-ai/jupyter_ai/secrets/test_secrets_utils.py b/packages/jupyter-ai/jupyter_ai/secrets/test_secrets_utils.py index efa480988..2fb4ef2e9 100644 --- a/packages/jupyter-ai/jupyter_ai/secrets/test_secrets_utils.py +++ b/packages/jupyter-ai/jupyter_ai/secrets/test_secrets_utils.py @@ -3,55 +3,55 @@ class TestGetWhitespaceAround: """Test cases for get_whitespace_around function.""" - + def test_empty_string(self): """Test with empty string.""" prefix, suffix = get_whitespace_around("") assert prefix == "" assert suffix == "" - + def test_no_whitespace(self): """Test with string containing no whitespace.""" prefix, suffix = get_whitespace_around("hello") assert prefix == "" assert suffix == "" - + def test_only_prefix_whitespace(self): """Test with string containing only leading whitespace.""" prefix, suffix = get_whitespace_around(" hello") assert prefix == " " assert suffix == "" - + def test_only_suffix_whitespace(self): """Test with string containing only trailing whitespace.""" prefix, suffix = get_whitespace_around("hello ") assert prefix == "" assert suffix == " " - + def test_both_prefix_and_suffix_whitespace(self): """Test with string containing both leading and trailing whitespace.""" prefix, suffix = get_whitespace_around(" hello ") assert prefix == " " assert suffix == " " - + def test_mixed_whitespace_types(self): """Test with mixed whitespace types (spaces, tabs, newlines).""" prefix, suffix = get_whitespace_around(" \t\nhello\n\t ") assert prefix == " \t\n" assert suffix == "\n\t " - + def test_all_whitespace(self): """Test with string containing only whitespace.""" prefix, suffix = get_whitespace_around(" ") assert prefix == " " assert suffix == "" - + def test_single_character(self): """Test with single non-whitespace character.""" prefix, suffix = get_whitespace_around("x") assert prefix == "" assert suffix == "" - + def test_single_whitespace_character(self): """Test with single whitespace character.""" prefix, suffix = get_whitespace_around(" ") @@ -61,74 +61,67 @@ def test_single_whitespace_character(self): class TestBuildUpdatedDotenv: """Test cases for build_updated_dotenv function.""" - + def test_empty_updates(self): """Test with no updates to make.""" result = build_updated_dotenv("KEY=value", {}) assert result is None - + def test_add_to_empty_dotenv(self): """Test adding secrets to empty dotenv content.""" result = build_updated_dotenv("", {"NEW_KEY": "new_value"}) assert result == 'NEW_KEY="new_value"\n' - + def test_add_multiple_to_empty_dotenv(self): """Test adding multiple secrets to empty dotenv content.""" - result = build_updated_dotenv("", { - "KEY1": "value1", - "KEY2": "value2" - }) - expected_lines = result.strip().split('\n') + result = build_updated_dotenv("", {"KEY1": "value1", "KEY2": "value2"}) + expected_lines = result.strip().split("\n") assert len(expected_lines) == 3 # Two keys plus one 
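For comparison only, the same prefix/suffix behavior can be expressed with `str.lstrip`/`str.rstrip`; this is a sketch of an equivalent formulation rather than a change to the function above, and it matches the cases exercised by the tests in this file:

```python
# Equivalent, more compact formulation of `get_whitespace_around()`.
def get_whitespace_around_compact(text: str) -> tuple[str, str]:
    if not text.strip():
        # Empty or all-whitespace input: return everything as the prefix.
        return (text, "")
    prefix = text[: len(text) - len(text.lstrip())]
    suffix = text[len(text.rstrip()) :]
    return (prefix, suffix)
```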
empty line assert 'KEY1="value1"' in expected_lines assert 'KEY2="value2"' in expected_lines - + def test_update_existing_key(self): """Test updating an existing key.""" dotenv_content = 'EXISTING_KEY="old_value"\n' result = build_updated_dotenv(dotenv_content, {"EXISTING_KEY": "new_value"}) assert 'EXISTING_KEY="new_value"' in result - + def test_add_new_key_to_existing_dotenv(self): """Test adding a new key to existing dotenv content.""" dotenv_content = 'EXISTING_KEY="existing_value"\n' result = build_updated_dotenv(dotenv_content, {"NEW_KEY": "new_value"}) assert 'EXISTING_KEY="existing_value"' in result assert 'NEW_KEY="new_value"' in result - + def test_remove_existing_key(self): """Test removing an existing key.""" dotenv_content = 'KEY_TO_REMOVE="value"\nKEY_TO_KEEP="value"\n' result = build_updated_dotenv(dotenv_content, {"KEY_TO_REMOVE": None}) assert "KEY_TO_REMOVE" not in result assert 'KEY_TO_KEEP="value"' in result - + def test_mixed_operations(self): """Test adding, updating, and removing keys in one operation.""" dotenv_content = 'UPDATE_ME="old"\nREMOVE_ME="gone"\nKEEP_ME="same"\n' - updates = { - "UPDATE_ME": "new", - "REMOVE_ME": None, - "ADD_ME": "added" - } + updates = {"UPDATE_ME": "new", "REMOVE_ME": None, "ADD_ME": "added"} result = build_updated_dotenv(dotenv_content, updates) - + assert 'UPDATE_ME="new"' in result assert "REMOVE_ME" not in result assert 'KEEP_ME="same"' in result assert 'ADD_ME="added"' in result - + def test_preserve_comments_and_empty_lines(self): """Test that comments and empty lines are preserved.""" dotenv_content = '# This is a comment\nKEY="value"\n\n# Another comment\n' result = build_updated_dotenv(dotenv_content, {"NEW_KEY": "new_value"}) - + assert "# This is a comment" in result assert "# Another comment" in result assert 'KEY="value"' in result assert 'NEW_KEY="new_value"' in result - + def test_delete_last_secret(self): - dotenv_content="KEY='value'" + dotenv_content = "KEY='value'" result = build_updated_dotenv(dotenv_content, {"KEY": None}) - assert isinstance(result, str) and result.strip() == "" \ No newline at end of file + assert isinstance(result, str) and result.strip() == "" diff --git a/packages/jupyter-ai/jupyter_ai/tests/completions/test_handlers.py b/packages/jupyter-ai/jupyter_ai/tests/completions/test_handlers.py index bde013430..b8b33045d 100644 --- a/packages/jupyter-ai/jupyter_ai/tests/completions/test_handlers.py +++ b/packages/jupyter-ai/jupyter_ai/tests/completions/test_handlers.py @@ -1,14 +1,15 @@ import json + # from types import SimpleNamespace from typing import Union import pytest -from jupyter_ai.completions.handlers.default import DefaultInlineCompletionHandler from jupyter_ai.completions.completion_types import ( InlineCompletionReply, InlineCompletionRequest, InlineCompletionStreamChunk, ) +from jupyter_ai.completions.handlers.default import DefaultInlineCompletionHandler from pytest import fixture from tornado.httputil import HTTPServerRequest from tornado.web import Application diff --git a/packages/jupyter-ai/jupyter_ai/tests/test_extension.py b/packages/jupyter-ai/jupyter_ai/tests/test_extension.py index 2b9ae3bcd..16948b663 100644 --- a/packages/jupyter-ai/jupyter_ai/tests/test_extension.py +++ b/packages/jupyter-ai/jupyter_ai/tests/test_extension.py @@ -1,6 +1,5 @@ # Copyright (c) Jupyter Development Team. # Distributed under the terms of the Modified BSD License. 
-from unittest import mock import pytest from jupyter_ai.extension import AiExtension @@ -55,4 +54,4 @@ def jp_server_config(jp_server_config): @pytest.fixture def ai_extension(jp_serverapp): - ai = AiExtension() + AiExtension()
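A possible follow-up test for the renamed extension trait, sketched against the public traitlets API; it assumes, as the fixture above does, that `AiExtension` can be instantiated without a running server:

```python
# Hypothetical additional test (not part of this diff) for the renamed trait.
from jupyter_ai.extension import AiExtension


def test_initial_language_model_trait():
    ext = AiExtension()
    # Renamed from `default_language_model`; still defaults to None.
    assert ext.initial_language_model is None
    ext.initial_language_model = "bedrock-chat:anthropic.claude-v2"
    assert ext.initial_language_model == "bedrock-chat:anthropic.claude-v2"
```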