diff --git a/.gitignore b/.gitignore
index d156b06d5..d90ad5b40 100644
--- a/.gitignore
+++ b/.gitignore
@@ -142,3 +142,6 @@ packages/**/_version.py
# Ignore local chat files & local .jupyter/ dir
*.chat
.jupyter/
+
+# Ignore secrets in '.env'
+.env
diff --git a/docs/source/users/index.md b/docs/source/users/index.md
index 1ef341e22..750960332 100644
--- a/docs/source/users/index.md
+++ b/docs/source/users/index.md
@@ -762,7 +762,7 @@ We currently support the following language model providers:
To configure a default model you can use the IPython `%config` magic:
```python
-%config AiMagics.default_language_model = "anthropic:claude-v1.2"
+%config AiMagics.initial_language_model = "anthropic:claude-v1.2"
```
Then subsequent magics can be invoked without typing in the model:
@@ -772,10 +772,10 @@ Then subsequent magics can be invoked without typing in the model:
Write a poem about C++.
```
-You can configure the default model for all notebooks by specifying `c.AiMagics.default_language_model` tratilet in `ipython_config.py`, for example:
+You can configure the default model for all notebooks by setting the `c.AiMagics.initial_language_model` traitlet in `ipython_config.py`, for example:
```python
-c.AiMagics.default_language_model = "anthropic:claude-v1.2"
+c.AiMagics.initial_language_model = "anthropic:claude-v1.2"
```
The location of `ipython_config.py` file is documented in [IPython configuration reference](https://ipython.readthedocs.io/en/stable/config/intro.html).
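+
+If you have not yet created an IPython profile, one way to generate and locate `ipython_config.py` is with the standard IPython CLI (a minimal sketch; see the IPython configuration reference above for details):
+
+```bash
+# Create the default profile, including a stub ipython_config.py
+ipython profile create
+# Print the directory containing ipython_config.py
+ipython locate profile
+```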
@@ -965,18 +965,18 @@ produced the following Python error:
Write a new version of this code that does not produce that error.
```
-As a shortcut for explaining errors, you can use the `%ai error` command, which will explain the most recent error using the model of your choice.
+As a shortcut for explaining and fixing errors, you can use the `%ai fix` command, which will explain the most recent error and propose a fix, using the model of your choice.
```
-%ai error anthropic:claude-v1.2
+%ai fix anthropic:claude-v1.2
```
### Creating and managing aliases
-You can create an alias for a model using the `%ai register` command. For example, the command:
+You can create an alias for a model using the `%ai alias` command. For example, the command:
```
-%ai register claude anthropic:claude-v1.2
+%ai alias claude anthropic:claude-v1.2
```
will register the alias `claude` as pointing to the `anthropic` provider's `claude-v1.2` model. You can then use this alias as you would use any other model name:
@@ -1001,10 +1001,10 @@ prompt = PromptTemplate(
chain = LLMChain(llm=llm, prompt=prompt)
```
-… and then use `%ai register` to give it a name:
+… and then use `%ai alias` to give it a name:
```
-%ai register companyname chain
+%ai alias companyname chain
```
You can change an alias's target using the `%ai update` command:
@@ -1013,10 +1013,10 @@ You can change an alias's target using the `%ai update` command:
%ai update claude anthropic:claude-instant-v1.0
```
-You can delete an alias using the `%ai delete` command:
+You can delete an alias using the `%ai dealias` command:
```
-%ai delete claude
+%ai dealias claude
```
You can see a list of all aliases by running the `%ai list` command.
@@ -1103,7 +1103,7 @@ the selections they make in the settings panel will take precedence over these v
Specify default language model
```bash
-jupyter lab --AiExtension.default_language_model=bedrock-chat:anthropic.claude-v2
+jupyter lab --AiExtension.initial_language_model=bedrock-chat:anthropic.claude-v2
```
Specify default embedding model
diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py b/packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py
index b535d4a88..9d0529e0b 100644
--- a/packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py
+++ b/packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py
@@ -1,89 +1,20 @@
+from __future__ import annotations
+
from typing import TYPE_CHECKING
-from ._import_utils import import_attr as _import_attr
from ._version import __version__
if TYPE_CHECKING:
- # same as dynamic imports but understood by mypy
- from .embedding_providers import (
- BaseEmbeddingsProvider,
- GPT4AllEmbeddingsProvider,
- HfHubEmbeddingsProvider,
- QianfanEmbeddingsEndpointProvider,
- )
- from .exception import store_exception
- from .magics import AiMagics
- from .providers import (
- AI21Provider,
- BaseProvider,
- GPT4AllProvider,
- HfHubProvider,
- QianfanProvider,
- TogetherAIProvider,
- )
-else:
- _exports_by_module = {
- "embedding_providers": [
- "BaseEmbeddingsProvider",
- "GPT4AllEmbeddingsProvider",
- "HfHubEmbeddingsProvider",
- "QianfanEmbeddingsEndpointProvider",
- ],
- "exception": ["store_exception"],
- "magics": ["AiMagics"],
- # expose model providers on the package root
- "providers": [
- "AI21Provider",
- "BaseProvider",
- "GPT4AllProvider",
- "HfHubProvider",
- "QianfanProvider",
- "TogetherAIProvider",
- ],
- }
+ from IPython.core.interactiveshell import InteractiveShell
- _modules_by_export = {
- import_name: module
- for module, imports in _exports_by_module.items()
- for import_name in imports
- }
-
- def __getattr__(export_name: str) -> object:
- module_name = _modules_by_export.get(export_name)
- result = _import_attr(export_name, module_name, __spec__.parent)
- globals()[export_name] = result
- return result
+def load_ipython_extension(ipython: InteractiveShell):
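+    """
+    Entry point called by IPython when a user runs `%load_ext jupyter_ai_magics`.
+    Registers the `%ai`/`%%ai` magics and installs the custom exception hook
+    that records errors for `%ai fix`.
+    """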
+ from .exception import store_exception
+ from .magics import AiMagics
-def load_ipython_extension(ipython):
- ipython.register_magics(__getattr__("AiMagics"))
- ipython.set_custom_exc((BaseException,), __getattr__("store_exception"))
+ ipython.register_magics(AiMagics)
+ ipython.set_custom_exc((BaseException,), store_exception)
-def unload_ipython_extension(ipython):
+def unload_ipython_extension(ipython: InteractiveShell):
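+    """
+    Called by IPython on `%unload_ext jupyter_ai_magics`; restores the default
+    exception handler.
+    """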
ipython.set_custom_exc((BaseException,), ipython.CustomTB)
-
-
-# required to preserve backward compatibility with `from jupyter_ai_magics import *`
-__all__ = [
- "__version__",
- "load_ipython_extension",
- "unload_ipython_extension",
- "BaseEmbeddingsProvider",
- "GPT4AllEmbeddingsProvider",
- "HfHubEmbeddingsProvider",
- "QianfanEmbeddingsEndpointProvider",
- "store_exception",
- "AiMagics",
- "AI21Provider",
- "BaseProvider",
- "GPT4AllProvider",
- "HfHubProvider",
- "QianfanProvider",
- "TogetherAIProvider",
-]
-
-
-def __dir__():
- # Allows more editors (e.g. IPython) to complete on `jupyter_ai_magics.`
- return list(__all__)
diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/_import_utils.py b/packages/jupyter-ai-magics/jupyter_ai_magics/_import_utils.py
deleted file mode 100644
index f251e452c..000000000
--- a/packages/jupyter-ai-magics/jupyter_ai_magics/_import_utils.py
+++ /dev/null
@@ -1,58 +0,0 @@
-"""
-MIT License
-
-Copyright (c) LangChain, Inc.
-
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
-"""
-
-from importlib import import_module
-from typing import Union
-
-
-def import_attr(
- attr_name: str,
- module_name: Union[str, None],
- package: Union[str, None],
-) -> object:
- """Import an attribute from a module located in a package.
-
- This utility function is used in custom __getattr__ methods within __init__.py
- files to dynamically import attributes.
-
- Args:
- attr_name: The name of the attribute to import.
- module_name: The name of the module to import from. If None, the attribute
- is imported from the package itself.
- package: The name of the package where the module is located.
- """
- if module_name == "__module__" or module_name is None:
- try:
- result = import_module(f".{attr_name}", package=package)
- except ModuleNotFoundError:
- msg = f"module '{package!r}' has no attribute {attr_name!r}"
- raise AttributeError(msg) from None
- else:
- try:
- module = import_module(f".{module_name}", package=package)
- except ModuleNotFoundError:
- msg = f"module '{package!r}.{module_name!r}' not found"
- raise ImportError(msg) from None
- result = getattr(module, attr_name)
- return result
diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/aliases.py b/packages/jupyter-ai-magics/jupyter_ai_magics/aliases.py
deleted file mode 100644
index e34a31aa0..000000000
--- a/packages/jupyter-ai-magics/jupyter_ai_magics/aliases.py
+++ /dev/null
@@ -1,10 +0,0 @@
-MODEL_ID_ALIASES = {
- "gpt2": "huggingface_hub:gpt2",
- "gpt3": "openai:davinci-002",
- "chatgpt": "openai-chat:gpt-3.5-turbo",
- "gpt4": "openai-chat:gpt-4",
- "ernie-bot": "qianfan:ERNIE-Bot",
- "ernie-bot-4": "qianfan:ERNIE-Bot-4",
- "titan": "bedrock:amazon.titan-tg1-large",
- "openrouter-claude": "openrouter:anthropic/claude-3.5-sonnet:beta",
-}
diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/base_provider.py b/packages/jupyter-ai-magics/jupyter_ai_magics/base_provider.py
deleted file mode 100644
index 8884aa15d..000000000
--- a/packages/jupyter-ai-magics/jupyter_ai_magics/base_provider.py
+++ /dev/null
@@ -1,479 +0,0 @@
-import asyncio
-import functools
-from collections.abc import AsyncIterator, Coroutine
-from concurrent.futures import ThreadPoolExecutor
-from types import MappingProxyType
-from typing import (
- Any,
- ClassVar,
- Literal,
- Optional,
- Union,
-)
-
-from langchain.prompts import (
- ChatPromptTemplate,
- HumanMessagePromptTemplate,
- MessagesPlaceholder,
- PromptTemplate,
- SystemMessagePromptTemplate,
-)
-from langchain.schema import LLMResult
-from langchain.schema.output_parser import StrOutputParser
-from langchain.schema.runnable import Runnable
-from langchain_core.language_models.chat_models import BaseChatModel
-from langchain_core.language_models.llms import BaseLLM
-from pydantic import BaseModel, ConfigDict
-
-from . import completion_utils as completion
-from .models.completion import (
- InlineCompletionList,
- InlineCompletionReply,
- InlineCompletionRequest,
- InlineCompletionStreamChunk,
-)
-
-CHAT_SYSTEM_PROMPT = """
-You are Jupyternaut, a conversational assistant living in JupyterLab to help users.
-You are not a language model, but rather an application built on a foundation model from {provider_name} called {local_model_id}.
-You are talkative and you provide lots of specific details from the foundation model's context.
-You may use Markdown to format your response.
-If your response includes code, they must be enclosed in Markdown fenced code blocks (with triple backticks before and after).
-If your response includes mathematical notation, they must be expressed in LaTeX markup and enclosed in LaTeX delimiters.
-All dollar quantities (of USD) must be formatted in LaTeX, with the `$` symbol escaped by a single backslash `\\`.
-- Example prompt: `If I have \\\\$100 and spend \\\\$20, how much money do I have left?`
-- **Correct** response: `You have \\(\\$80\\) remaining.`
-- **Incorrect** response: `You have $80 remaining.`
-If you do not know the answer to a question, answer truthfully by responding that you do not know.
-The following is a friendly conversation between you and a human.
-""".strip()
-
-CHAT_DEFAULT_TEMPLATE = """
-{% if context %}
-Context:
-{{context}}
-
-{% endif %}
-Current conversation:
-{{history}}
-Human: {{input}}
-AI:"""
-
-HUMAN_MESSAGE_TEMPLATE = """
-{% if context %}
-Context:
-{{context}}
-
-{% endif %}
-{{input}}
-"""
-
-COMPLETION_SYSTEM_PROMPT = """
-You are an application built to provide helpful code completion suggestions.
-You should only produce code. Keep comments to minimum, use the
-programming language comment syntax. Produce clean code.
-The code is written in JupyterLab, a data analysis and code development
-environment which can execute code extended with additional syntax for
-interactive features, such as magics.
-""".strip()
-
-# only add the suffix bit if present to save input tokens/computation time
-COMPLETION_DEFAULT_TEMPLATE = """
-The document is called `{{filename}}` and written in {{language}}.
-{% if suffix %}
-The code after the completion request is:
-
-```
-{{suffix}}
-```
-{% endif %}
-
-Complete the following code:
-
-```
-{{prefix}}"""
-
-
-class EnvAuthStrategy(BaseModel):
- """
- Describes a provider that uses a single authentication token, which is
- passed either as an environment variable or as a keyword argument.
- """
-
- type: Literal["env"] = "env"
-
- name: str
- """The name of the environment variable, e.g. `'ANTHROPIC_API_KEY'`."""
-
- keyword_param: Optional[str] = None
- """
- If unset (default), the authentication token is provided as a keyword
- argument with the parameter equal to the environment variable name in
- lowercase. If set to some string `k`, the authentication token will be
- passed using the keyword parameter `k`.
- """
-
-
-class MultiEnvAuthStrategy(BaseModel):
- """Require multiple auth tokens via multiple environment variables."""
-
- type: Literal["multienv"] = "multienv"
- names: list[str]
-
-
-class AwsAuthStrategy(BaseModel):
- """Require AWS authentication via Boto3"""
-
- type: Literal["aws"] = "aws"
-
-
-AuthStrategy = Optional[
- Union[
- EnvAuthStrategy,
- MultiEnvAuthStrategy,
- AwsAuthStrategy,
- ]
-]
-
-
-class Field(BaseModel):
- key: str
- label: str
- # "text" accepts any text
- format: Literal["json", "jsonpath", "text"]
-
-
-class TextField(Field):
- type: Literal["text"] = "text"
-
-
-class MultilineTextField(Field):
- type: Literal["text-multiline"] = "text-multiline"
-
-
-class IntegerField(BaseModel):
- type: Literal["integer"] = "integer"
- key: str
- label: str
-
-
-Field = Union[TextField, MultilineTextField, IntegerField]
-
-
-class BaseProvider(BaseModel):
- # pydantic v2 model config
- # upstream docs: https://docs.pydantic.dev/latest/api/config/#pydantic.config.ConfigDict.extra
- model_config = ConfigDict(extra="allow")
-
- #
- # class attrs
- #
- id: ClassVar[str] = ...
- """ID for this provider class."""
-
- name: ClassVar[str] = ...
- """User-facing name of this provider."""
-
- models: ClassVar[list[str]] = ...
- """List of supported models by their IDs. For registry providers, this will
- be just ["*"]."""
-
- help: ClassVar[Optional[str]] = None
- """Text to display in lieu of a model list for a registry provider that does
- not provide a list of models."""
-
- model_id_key: ClassVar[Optional[str]] = None
- """
- Optional field which specifies the key under which `model_id` is passed to
- the parent LangChain class.
-
- If unset, this defaults to "model_id".
- """
-
- model_id_label: ClassVar[Optional[str]] = None
- """
- Optional field which sets the label shown in the UI allowing users to
- select/type a model ID.
-
- If unset, the label shown in the UI defaults to "Model ID".
- """
-
- pypi_package_deps: ClassVar[list[str]] = []
- """List of PyPi package dependencies."""
-
- auth_strategy: ClassVar[AuthStrategy] = None
- """Authentication/authorization strategy. Declares what credentials are
- required to use this model provider. Generally should not be `None`."""
-
- registry: ClassVar[bool] = False
- """Whether this provider is a registry provider."""
-
- fields: ClassVar[list[Field]] = []
- """User inputs expected by this provider when initializing it. Each `Field` `f`
- should be passed in the constructor as a keyword argument, keyed by `f.key`."""
-
- manages_history: ClassVar[bool] = False
- """Whether this provider manages its own conversation history upstream. If
- set to `True`, Jupyter AI will not pass the chat history to this provider
- when invoked."""
-
- unsupported_slash_commands: ClassVar[set] = set()
- """
- A set of slash commands unsupported by this provider. Unsupported slash
- commands are not shown in the help message, and cannot be used while this
- provider is selected.
- """
-
- server_settings: ClassVar[Optional[MappingProxyType]] = None
- """Settings passed on from jupyter-ai package.
-
- The same server settings are shared between all providers.
- Providers are not allowed to mutate this dictionary.
- """
-
- @classmethod
- def chat_models(self):
- """Models which are suitable for chat."""
- return self.models
-
- @classmethod
- def completion_models(self):
- """Models which are suitable for completions."""
- return self.models
-
- #
- # instance attrs
- #
- model_id: str
- prompt_templates: dict[str, PromptTemplate]
- """Prompt templates for each output type. Can be overridden with
- `update_prompt_template`. The function `prompt_template`, in the base class,
- refers to this."""
-
- def __init__(self, *args, **kwargs):
- try:
- assert kwargs["model_id"]
- except:
- raise AssertionError(
- "model_id was not specified. Please specify it as a keyword argument."
- )
-
- model_kwargs = {}
- if self.__class__.model_id_key != "model_id":
- model_kwargs[self.__class__.model_id_key] = kwargs["model_id"]
-
- model_kwargs["prompt_templates"] = {
- "code": PromptTemplate.from_template(
- "{prompt}\n\nProduce output as source code only, "
- "with no text or explanation before or after it."
- ),
- "html": PromptTemplate.from_template(
- "{prompt}\n\nProduce output in HTML format only, "
- "with no markup before or afterward."
- ),
- "image": PromptTemplate.from_template(
- "{prompt}\n\nProduce output as an image only, "
- "with no text before or after it."
- ),
- "markdown": PromptTemplate.from_template(
- "{prompt}\n\nProduce output in markdown format only."
- ),
- "md": PromptTemplate.from_template(
- "{prompt}\n\nProduce output in markdown format only."
- ),
- "math": PromptTemplate.from_template(
- "{prompt}\n\nProduce output in LaTeX format only, "
- "with $$ at the beginning and end."
- ),
- "json": PromptTemplate.from_template(
- "{prompt}\n\nProduce output in JSON format only, "
- "with nothing before or after it."
- ),
- "text": PromptTemplate.from_template("{prompt}"), # No customization
- }
- super().__init__(*args, **kwargs, **model_kwargs)
-
- async def _call_in_executor(self, *args, **kwargs) -> Coroutine[Any, Any, str]:
- """
- Calls self._call() asynchronously in a separate thread for providers
- without an async implementation. Requires the event loop to be running.
- """
- executor = ThreadPoolExecutor(max_workers=1)
- loop = asyncio.get_running_loop()
- _call_with_args = functools.partial(self._call, *args, **kwargs)
- return await loop.run_in_executor(executor, _call_with_args)
-
- async def _generate_in_executor(
- self, *args, **kwargs
- ) -> Coroutine[Any, Any, LLMResult]:
- """
- Calls self._generate() asynchronously in a separate thread for providers
- without an async implementation. Requires the event loop to be running.
- """
- executor = ThreadPoolExecutor(max_workers=1)
- loop = asyncio.get_running_loop()
- _call_with_args = functools.partial(self._generate, *args, **kwargs)
- return await loop.run_in_executor(executor, _call_with_args)
-
- @classmethod
- def is_api_key_exc(cls, _: Exception):
- """
- Determine if the exception is an API key error. Can be implemented by subclasses.
- """
- return False
-
- def update_prompt_template(self, format: str, template: str):
- """
- Changes the class-level prompt template for a given format.
- """
- self.prompt_templates[format] = PromptTemplate.from_template(template)
-
- def get_prompt_template(self, format) -> PromptTemplate:
- """
- Produce a prompt template suitable for use with a particular model, to
- produce output in a desired format.
- """
-
- if format in self.prompt_templates:
- return self.prompt_templates[format]
- else:
- return self.prompt_templates["text"] # Default to plain format
-
- def get_chat_prompt_template(self) -> PromptTemplate:
- """
- Produce a prompt template optimised for chat conversation.
- The template should take two variables: history and input.
- """
- name = self.__class__.name
- if self.is_chat_provider:
- return ChatPromptTemplate.from_messages(
- [
- SystemMessagePromptTemplate.from_template(
- CHAT_SYSTEM_PROMPT
- ).format(provider_name=name, local_model_id=self.model_id),
- MessagesPlaceholder(variable_name="history"),
- HumanMessagePromptTemplate.from_template(
- HUMAN_MESSAGE_TEMPLATE,
- template_format="jinja2",
- ),
- ]
- )
- else:
- return PromptTemplate(
- input_variables=["history", "input", "context"],
- template=CHAT_SYSTEM_PROMPT.format(
- provider_name=name, local_model_id=self.model_id
- )
- + "\n\n"
- + CHAT_DEFAULT_TEMPLATE,
- template_format="jinja2",
- )
-
- def get_completion_prompt_template(self) -> PromptTemplate:
- """
- Produce a prompt template optimised for inline code or text completion.
- The template should take variables: prefix, suffix, language, filename.
- """
- if self.is_chat_provider:
- return ChatPromptTemplate.from_messages(
- [
- SystemMessagePromptTemplate.from_template(COMPLETION_SYSTEM_PROMPT),
- HumanMessagePromptTemplate.from_template(
- COMPLETION_DEFAULT_TEMPLATE, template_format="jinja2"
- ),
- ]
- )
- else:
- return PromptTemplate(
- input_variables=["prefix", "suffix", "language", "filename"],
- template=COMPLETION_SYSTEM_PROMPT
- + "\n\n"
- + COMPLETION_DEFAULT_TEMPLATE,
- template_format="jinja2",
- )
-
- @property
- def is_chat_provider(self):
- return isinstance(self, BaseChatModel)
-
- @property
- def allows_concurrency(self):
- return True
-
- @property
- def _supports_sync_streaming(self):
- if self.is_chat_provider:
- return not (self.__class__._stream is BaseChatModel._stream)
- else:
- return not (self.__class__._stream is BaseLLM._stream)
-
- @property
- def _supports_async_streaming(self):
- if self.is_chat_provider:
- return not (self.__class__._astream is BaseChatModel._astream)
- else:
- return not (self.__class__._astream is BaseLLM._astream)
-
- @property
- def supports_streaming(self):
- return self._supports_sync_streaming or self._supports_async_streaming
-
- async def generate_inline_completions(
- self, request: InlineCompletionRequest
- ) -> InlineCompletionReply:
- chain = self._create_completion_chain()
- model_arguments = completion.template_inputs_from_request(request)
- suggestion = await chain.ainvoke(input=model_arguments)
- suggestion = completion.post_process_suggestion(suggestion, request)
- return InlineCompletionReply(
- list=InlineCompletionList(items=[{"insertText": suggestion}]),
- reply_to=request.number,
- )
-
- async def stream_inline_completions(
- self, request: InlineCompletionRequest
- ) -> AsyncIterator[InlineCompletionStreamChunk]:
- chain = self._create_completion_chain()
- token = completion.token_from_request(request, 0)
- model_arguments = completion.template_inputs_from_request(request)
- suggestion = processed_suggestion = ""
-
- # send an incomplete `InlineCompletionReply`, indicating to the
- # client that LLM output is about to streamed across this connection.
- yield InlineCompletionReply(
- list=InlineCompletionList(
- items=[
- {
- # insert text starts empty as we do not pre-generate any part
- "insertText": "",
- "isIncomplete": True,
- "token": token,
- }
- ]
- ),
- reply_to=request.number,
- )
-
- async for fragment in chain.astream(input=model_arguments):
- suggestion += fragment
- processed_suggestion = completion.post_process_suggestion(
- suggestion, request
- )
- yield InlineCompletionStreamChunk(
- type="stream",
- response={"insertText": processed_suggestion, "token": token},
- reply_to=request.number,
- done=False,
- )
-
- # finally, send a message confirming that we are done
- yield InlineCompletionStreamChunk(
- type="stream",
- response={"insertText": processed_suggestion, "token": token},
- reply_to=request.number,
- done=True,
- )
-
- def _create_completion_chain(self) -> Runnable:
- prompt_template = self.get_completion_prompt_template()
- return prompt_template | self | StrOutputParser()
diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/embedding_providers.py b/packages/jupyter-ai-magics/jupyter_ai_magics/embedding_providers.py
deleted file mode 100644
index 38053635f..000000000
--- a/packages/jupyter-ai-magics/jupyter_ai_magics/embedding_providers.py
+++ /dev/null
@@ -1,119 +0,0 @@
-from typing import ClassVar, Optional
-
-from jupyter_ai_magics.base_provider import (
- AuthStrategy,
- EnvAuthStrategy,
- Field,
- MultiEnvAuthStrategy,
-)
-from langchain_community.embeddings import (
- GPT4AllEmbeddings,
- HuggingFaceHubEmbeddings,
- QianfanEmbeddingsEndpoint,
-)
-from pydantic import BaseModel, ConfigDict
-
-
-class BaseEmbeddingsProvider(BaseModel):
- """Base class for embedding providers"""
-
- # pydantic v2 model config
- # upstream docs: https://docs.pydantic.dev/latest/api/config/#pydantic.config.ConfigDict.extra
- model_config = ConfigDict(extra="allow")
-
- id: ClassVar[str] = ...
- """ID for this provider class."""
-
- name: ClassVar[str] = ...
- """User-facing name of this provider."""
-
- models: ClassVar[list[str]] = ...
- """List of supported models by their IDs. For registry providers, this will
- be just ["*"]."""
-
- help: ClassVar[Optional[str]] = None
- """Text to display in lieu of a model list for a registry provider that does
- not provide a list of models."""
-
- model_id_key: ClassVar[str] = ...
- """Kwarg expected by the upstream LangChain provider."""
-
- pypi_package_deps: ClassVar[list[str]] = []
- """List of PyPi package dependencies."""
-
- auth_strategy: ClassVar[AuthStrategy] = None
- """Authentication/authorization strategy. Declares what credentials are
- required to use this model provider. Generally should not be `None`."""
-
- model_id: str
-
- registry: ClassVar[bool] = False
- """Whether this provider is a registry provider."""
-
- fields: ClassVar[list[Field]] = []
- """Fields expected by this provider in its constructor. Each `Field` `f`
- should be passed as a keyword argument, keyed by `f.key`."""
-
- def __init__(self, *args, **kwargs):
- try:
- assert kwargs["model_id"]
- except:
- raise AssertionError(
- "model_id was not specified. Please specify it as a keyword argument."
- )
-
- model_kwargs = {}
- if self.__class__.model_id_key != "model_id":
- model_kwargs[self.__class__.model_id_key] = kwargs["model_id"]
-
- super().__init__(*args, **kwargs, **model_kwargs)
-
-
-class HfHubEmbeddingsProvider(BaseEmbeddingsProvider, HuggingFaceHubEmbeddings):
- id = "huggingface_hub"
- name = "Hugging Face Hub"
- models = ["*"]
- model_id_key = "repo_id"
- help = (
- "See [https://huggingface.co/docs/chat-ui/en/configuration/embeddings](https://huggingface.co/docs/chat-ui/en/configuration/embeddings) for reference. "
- "Pass an embedding model's name; for example, `sentence-transformers/all-MiniLM-L6-v2`."
- )
- # ipywidgets needed to suppress tqdm warning
- # https://stackoverflow.com/questions/67998191
- # tqdm is a dependency of huggingface_hub
- pypi_package_deps = ["huggingface_hub", "ipywidgets"]
- auth_strategy = EnvAuthStrategy(name="HUGGINGFACEHUB_API_TOKEN")
- registry = True
-
-
-class GPT4AllEmbeddingsProvider(BaseEmbeddingsProvider, GPT4AllEmbeddings):
- def __init__(self, **kwargs):
- from gpt4all import GPT4All
-
- model_name = kwargs.get("model_id").split(":")[-1]
-
- # GPT4AllEmbeddings doesn't allow any kwargs at the moment
- # This will cause the class to start downloading the model
- # if the model file is not present. Calling retrieve_model
- # here will throw an exception if the file is not present.
- GPT4All.retrieve_model(model_name=model_name, allow_download=False)
-
- kwargs["allow_download"] = False
- super().__init__(**kwargs)
-
- id = "gpt4all"
- name = "GPT4All Embeddings"
- models = ["all-MiniLM-L6-v2-f16"]
- model_id_key = "model_id"
- pypi_package_deps = ["gpt4all"]
-
-
-class QianfanEmbeddingsEndpointProvider(
- BaseEmbeddingsProvider, QianfanEmbeddingsEndpoint
-):
- id = "qianfan"
- name = "ERNIE-Bot"
- models = ["ERNIE-Bot", "ERNIE-Bot-4"]
- model_id_key = "model"
- pypi_package_deps = ["qianfan"]
- auth_strategy = MultiEnvAuthStrategy(names=["QIANFAN_AK", "QIANFAN_SK"])
diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py b/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py
index 5f94c2ecf..88ba32feb 100644
--- a/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py
+++ b/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py
@@ -1,28 +1,22 @@
import base64
import json
-import keyword
-import os
import re
import sys
import warnings
-from typing import Optional
+from typing import Any, Optional
import click
+import litellm
import traitlets
from IPython.core.magic import Magics, line_cell_magic, magics_class
from IPython.display import HTML, JSON, Markdown, Math
-from jupyter_ai_magics.aliases import MODEL_ID_ALIASES
-from jupyter_ai_magics.utils import decompose_model_id, get_lm_providers
-from langchain.chains import LLMChain
-from langchain.schema import HumanMessage
-from langchain_core.messages import AIMessage
+from jupyter_ai.model_providers.model_list import CHAT_MODELS
from ._version import __version__
-from .base_provider import BaseProvider
from .parsers import (
CellArgs,
DeleteArgs,
- ErrorArgs,
+ FixArgs,
HelpArgs,
ListArgs,
RegisterArgs,
@@ -90,7 +84,7 @@ def _repr_mimebundle_(self, include=None, exclude=None):
To see a list of models you can use, run `%ai list`"""
-AI_COMMANDS = {"delete", "error", "help", "list", "register", "update"}
+AI_COMMANDS = {"dealias", "fix", "help", "list", "alias", "update"}
# Strings for listing providers and models
# Avoid composing strings, to make localization easier in the future
@@ -125,8 +119,10 @@ class CellMagicError(BaseException):
@magics_class
class AiMagics(Magics):
- aliases = traitlets.Dict(
- default_value=MODEL_ID_ALIASES,
+ # This traitlet only sets the initial ("starting") set of aliases;
+ # aliases added at runtime are stored on `self.aliases`.
+ initial_aliases = traitlets.Dict(
+ default_value={},
value_trait=traitlets.Unicode(),
key_trait=traitlets.Unicode(),
help="""Aliases for model identifiers.
@@ -137,7 +133,7 @@ class AiMagics(Magics):
config=True,
)
- default_language_model = traitlets.Unicode(
+ initial_language_model = traitlets.Unicode(
default_value=None,
allow_none=True,
help="""Default language model to use, as string in the format
@@ -155,17 +151,20 @@ class AiMagics(Magics):
config=True,
)
+ transcript: list[dict[str, str]]
+ """
+ The conversation history as a list of messages. Each message is a simple
+ dictionary with the following structure:
+
+ - `"role"`: `"user"`, `"assistant"`, or `"system"`
+ - `"content"`: the content of the message
+ """
+
def __init__(self, shell):
super().__init__(shell)
self.transcript = []
- # suppress warning when using old Anthropic provider
- warnings.filterwarnings(
- "ignore",
- message="This Anthropic LLM is deprecated. Please use "
- "`from langchain.chat_models import ChatAnthropic` instead",
- )
-
+ # TODO: check if this is necessary
# suppress warning about our exception handler
warnings.filterwarnings(
"ignore",
@@ -175,304 +174,176 @@ def __init__(self, shell):
"show full tracebacks.",
)
- self.providers = get_lm_providers()
-
+ # TODO: use LiteLLM aliases to provide this
+ # https://docs.litellm.ai/docs/completion/model_alias
# initialize a registry of custom model/chain names
- self.custom_model_registry = self.aliases
-
- def _ai_bulleted_list_models_for_provider(self, provider_id, Provider):
- output = ""
- if len(Provider.models) == 1 and Provider.models[0] == "*":
- if Provider.help is None:
- output += f"* {PROVIDER_NO_MODELS}\n"
- else:
- output += f"* {Provider.help}\n"
- else:
- for model_id in Provider.models:
- output += f"* {provider_id}:{model_id}\n"
- output += "\n" # End of bulleted list
-
- return output
+ self.aliases = self.initial_aliases.copy()
- def _ai_inline_list_models_for_provider(self, provider_id, Provider):
- output = "
"
+ @line_cell_magic
+ def ai(self, line: str, cell: Optional[str] = None) -> Any:
+ """
+ Defines how `%ai` and `%%ai` magic commands are handled. This is called
+ first whenever either `%ai` or `%%ai` is run, so it should be considered
+ the main method of the `AiMagics` class.
+
+ - `%ai` is a "line magic command" that only accepts a single line of
+ input. This is used to provide access to sub-commands like `%ai
+ alias`.
+
+ - `%%ai` is a "cell magic command" that accepts an entire cell of input
+ (i.e. multiple lines). This is used to invoke a language model.
+
+ This method is called when either `%ai` or `%%ai` is run. Whether a line
+ or cell magic was run can be determined by the arguments given to this
+ method; `%%ai` was run if and only if `cell is not None`.
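+
+ For example (model names are illustrative):
+
+ - `%ai alias sonnet anthropic/claude-3-5-sonnet-latest` runs the line magic.
+ - A cell beginning with `%%ai sonnet`, with the prompt in the cell body,
+ runs the cell magic and sends the prompt to that model.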
+ """
+ raw_args = line.split(" ")
+ default_map = {"model_id": self.initial_language_model}
- if len(Provider.models) == 1 and Provider.models[0] == "*":
- if Provider.help is None:
- return PROVIDER_NO_MODELS
+ # parse arguments
+ args = None
+ try:
+ if cell:
+ args = cell_magic_parser(
+ raw_args,
+ prog_name=r"%%ai",
+ standalone_mode=False,
+ default_map={"cell_magic_parser": default_map},
+ )
else:
- return Provider.help
-
- for model_id in Provider.models:
- output += f"
`{provider_id}:{model_id}`
"
-
- return output + "
"
-
- # Is the required environment variable set?
- def _ai_env_status_for_provider_markdown(self, provider_id):
- na_message = "Not applicable. | " + NA_MESSAGE
-
- if (
- provider_id not in self.providers
- or self.providers[provider_id].auth_strategy == None
- ):
- return na_message # No emoji
-
- not_set_title = ENV_NOT_SET
- set_title = ENV_SET
- env_status_ok = False
-
- auth_strategy = self.providers[provider_id].auth_strategy
- if auth_strategy.type == "env":
- var_name = auth_strategy.name
- env_var_display = f"`{var_name}`"
- env_status_ok = var_name in os.environ
- elif auth_strategy.type == "multienv":
- # Check multiple environment variables
- var_names = self.providers[provider_id].auth_strategy.names
- formatted_names = [f"`{name}`" for name in var_names]
- env_var_display = ", ".join(formatted_names)
- env_status_ok = all(var_name in os.environ for var_name in var_names)
- not_set_title = MULTIENV_NOT_SET
- set_title = MULTIENV_SET
- else: # No environment variables
- return na_message
-
- output = f"{env_var_display} | "
- if env_status_ok:
- output += f'<abbr title="{set_title}">✅</abbr>'
- else:
- output += f'<abbr title="{not_set_title}">❌</abbr>'
-
- return output
-
- def _ai_env_status_for_provider_text(self, provider_id):
- # only handle providers with "env" or "multienv" auth strategy
- auth_strategy = getattr(self.providers[provider_id], "auth_strategy", None)
- if not auth_strategy or (
- auth_strategy.type != "env" and auth_strategy.type != "multienv"
- ):
- return ""
-
- prefix = ENV_REQUIRES if auth_strategy.type == "env" else MULTIENV_REQUIRES
- envvars = (
- [auth_strategy.name]
- if auth_strategy.type == "env"
- else auth_strategy.names[:]
- )
-
- for i in range(len(envvars)):
- envvars[i] += " (set)" if envvars[i] in os.environ else " (not set)"
-
- return prefix + " " + ", ".join(envvars) + "\n"
-
- # Is this a name of a Python variable that can be called as a LangChain chain?
- def _is_langchain_chain(self, name):
- # Reserved word in Python?
- if keyword.iskeyword(name):
- return False
+ args = line_magic_parser(
+ raw_args,
+ prog_name=r"%ai",
+ standalone_mode=False,
+ default_map={"fix": default_map},
+ )
+ except Exception as e:
+ if "model_id" in str(e) and "string_type" in str(e):
+ error_msg = "No Model ID entered, please enter it in the following format: `%%ai `"
+ print(error_msg, file=sys.stderr)
+ return
+ if not args:
+ print("No valid %ai magics arguments given, run `%ai help` for all options.", file=sys.stderr)
+ return
+ raise e
+
+ if args == 0 and self.initial_language_model is None:
+ # this happens when `--help` is called on the root command, in which
+ # case we want to exit early.
+ return
- acceptable_name = re.compile("^[a-zA-Z0-9_]+$")
- if not acceptable_name.match(name):
- return False
+ # If a value error occurs, don't print the full stacktrace
+ try:
+ if args.type == "fix":
+ return self.handle_fix(args)
+ if args.type == "help":
+ return self.handle_help(args)
+ if args.type == "list":
+ return self.handle_list(args)
+ if args.type == "alias":
+ return self.handle_alias(args)
+ if args.type == "dealias":
+ return self.handle_dealias(args)
+ if args.type == "version":
+ return self.handle_version(args)
+ if args.type == "reset":
+ return self.handle_reset(args)
+ except ValueError as e:
+ print(e, file=sys.stderr)
+ return
- ipython = self.shell
- return name in ipython.user_ns and isinstance(ipython.user_ns[name], LLMChain)
+ # hint to the IDE that this object must be of type `CellArgs`
+ args: CellArgs = args
- # Is this an acceptable name for an alias?
- def _validate_name(self, register_name):
- # A registry name contains ASCII letters, numbers, hyphens, underscores,
- # and periods. No other characters, including a colon, are permitted
- acceptable_name = re.compile("^[a-zA-Z0-9._-]+$")
- if not acceptable_name.match(register_name):
- raise ValueError(
- "A registry name may contain ASCII letters, numbers, hyphens, underscores, "
- + "and periods. No other characters, including a colon, are permitted"
+ if not cell:
+ raise CellMagicError(
+ """To invoke a language model, you must use the `%%ai`
+ cell magic. The `%ai` line magic is only for use with
+ subcommands."""
)
- # Initially set or update an alias to a target
- def _safely_set_target(self, register_name, target):
- # If target is a string, treat this as an alias to another model.
- if self._is_langchain_chain(target):
- ip = self.shell
- self.custom_model_registry[register_name] = ip.user_ns[target]
- else:
- # Ensure that the destination is properly formatted
- if ":" not in target:
- raise ValueError(
- "Target model must be an LLMChain object or a model name in PROVIDER_ID:MODEL_NAME format"
- )
+ prompt = cell.strip()
- self.custom_model_registry[register_name] = target
+ return self.run_ai_cell(args, prompt)
- def handle_delete(self, args: DeleteArgs):
- if args.name in AI_COMMANDS:
- raise ValueError(
- f"Reserved command names, including {args.name}, cannot be deleted"
- )
+ def run_ai_cell(self, args: CellArgs, prompt: str):
+ """
+ Handles the `%%ai` cell magic. This is the main method that invokes the
+ language model.
+ """
+ # Interpolate local variables into prompt.
+ # For example, if a user runs `a = "hello"` and then runs `%%ai {a}`, it
+ # should be equivalent to running `%%ai hello`.
+ ip = self.shell
+ prompt = prompt.format_map(FormatDict(ip.user_ns))
- if args.name not in self.custom_model_registry:
- raise ValueError(f"There is no alias called {args.name}")
+ # Prepare messages for the model
+ messages = []
- del self.custom_model_registry[args.name]
- output = f"Deleted alias `{args.name}`"
- return TextOrMarkdown(output, output)
+ # Add conversation history if available
+ if self.transcript:
+ messages.extend(self.transcript[-2 * self.max_history :])
- def handle_register(self, args: RegisterArgs):
- # Existing command names are not allowed
- if args.name in AI_COMMANDS:
- raise ValueError(f"The name {args.name} is reserved for a command")
+ # Add current prompt
+ messages.append({"role": "user", "content": prompt})
- # Existing registered names are not allowed
- if args.name in self.custom_model_registry:
- raise ValueError(
- f"The name {args.name} is already associated with a custom model; "
- + "use %ai update to change its target"
+ # Resolve model_id: check if it's in CHAT_MODELS or an alias
+ model_id = args.model_id
+ if model_id not in CHAT_MODELS:
+ # Check if it's an alias
+ if model_id in self.aliases:
+ model_id = self.aliases[model_id]
+ else:
+ error_msg = f"Model ID '{model_id}' is not a known model or alias. Run '%ai list' to see available models and aliases."
+ print(error_msg, file=sys.stderr) # Log to stderr
+ return
+ try:
+ # Call LiteLLM's unified completion API; `messages` uses the OpenAI chat format
+ response = litellm.completion(
+ model=model_id, messages=messages, stream=False
)
- # Does the new name match expected format?
- self._validate_name(args.name)
+ # Extract output text from response
+ output = response.choices[0].message.content
- self._safely_set_target(args.name, args.target)
- output = f"Registered new alias `{args.name}`"
- return TextOrMarkdown(output, output)
+ # Append exchange to transcript
+ self._append_exchange(prompt, output)
- def handle_update(self, args: UpdateArgs):
- if args.name in AI_COMMANDS:
- raise ValueError(
- f"Reserved command names, including {args.name}, cannot be updated"
- )
+ # Set model ID in metadata
+ metadata = {"jupyter_ai_v3": {"model_id": args.model_id}}
- if args.name not in self.custom_model_registry:
- raise ValueError(f"There is no alias called {args.name}")
+ # Return output given the format
+ return self.display_output(output, args.format, metadata)
- self._safely_set_target(args.name, args.target)
- output = f"Updated target of alias `{args.name}`"
- return TextOrMarkdown(output, output)
+ except Exception as e:
+ error_msg = f"Error calling language model: {str(e)}"
+ print(error_msg, file=sys.stderr)
+ return error_msg
- def _ai_list_command_markdown(self, single_provider=None):
- output = (
- "| Provider | Environment variable | Set? | Models |\n"
- + "|----------|----------------------|------|--------|\n"
- )
- if single_provider is not None and single_provider not in self.providers:
- return f"There is no model provider with ID `{single_provider}`."
-
- for provider_id, Provider in self.providers.items():
- if single_provider is not None and provider_id != single_provider:
- continue
-
- output += (
- f"| `{provider_id}` | "
- + self._ai_env_status_for_provider_markdown(provider_id)
- + " | "
- + self._ai_inline_list_models_for_provider(provider_id, Provider)
- + " |\n"
- )
+ def display_output(self, output, display_format, metadata: dict[str, Any]) -> Any:
+ """
+ Returns an IPython 'display object' that determines how an output is
+ rendered. This is complex, so here are some notes:
- # Also list aliases.
- if single_provider is None and len(self.custom_model_registry) > 0:
- output += (
- "\nAliases and custom commands:\n\n"
- + "| Name | Target |\n"
- + "|------|--------|\n"
- )
- for key, value in self.custom_model_registry.items():
- output += f"| `{key}` | "
- if isinstance(value, str):
- output += f"`{value}`"
- else:
- output += "*custom chain*"
-
- output += " |\n"
-
- return output
-
- def _ai_list_command_text(self, single_provider=None):
- output = ""
- if single_provider is not None and single_provider not in self.providers:
- return f"There is no model provider with ID '{single_provider}'."
-
- for provider_id, Provider in self.providers.items():
- if single_provider is not None and provider_id != single_provider:
- continue
-
- output += (
- f"{provider_id}\n"
- + self._ai_env_status_for_provider_text(
- provider_id
- ) # includes \n if nonblank
- + self._ai_bulleted_list_models_for_provider(provider_id, Provider)
- )
+ - The display object returned is controlled by the `display_format`
+ argument. See `DISPLAYS_BY_FORMAT` for the list of valid formats.
- # Also list aliases.
- if single_provider is None and len(self.custom_model_registry) > 0:
- output += "\nAliases and custom commands:\n"
- for key, value in self.custom_model_registry.items():
- output += f"{key} - "
- if isinstance(value, str):
- output += value
- else:
- output += "custom chain"
+ - In most use-cases, this method returns a `TextOrMarkdown` object. The
+ reason this exists is because IPython may be run from a terminal shell
+ (via the `ipython` command) or from a web browser in a Jupyter Notebook.
- output += "\n"
+ - `TextOrMarkdown` shows text when viewed from a command line, and rendered
+ Markdown when viewed from a web browser.
- return output
+ - See `DISPLAYS_BY_FORMAT` for the list of display objects that can be
+ returned by `jupyter_ai_magics`.
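+
+ - For example (illustrative): `display_output("**hi**", "markdown", {})`
+ returns an IPython `Markdown` display object.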
- def handle_error(self, args: ErrorArgs):
- no_errors = "There have been no errors since the kernel started."
+ TODO: Use a string enum to store the list of valid formats.
- # Find the most recent error.
- ip = self.shell
- if "Err" not in ip.user_ns:
- return TextOrMarkdown(no_errors, no_errors)
-
- err = ip.user_ns["Err"]
- # Start from the previous execution count
- excount = ip.execution_count - 1
- last_error = None
- while excount >= 0 and last_error is None:
- if excount in err:
- last_error = err[excount]
- else:
- excount = excount - 1
-
- if last_error is None:
- return TextOrMarkdown(no_errors, no_errors)
-
- prompt = f"Explain the following error:\n\n{last_error}"
- # Set CellArgs based on ErrorArgs
- values = args.model_dump()
- values["type"] = "root"
- cell_args = CellArgs(**values)
-
- return self.run_ai_cell(cell_args, prompt)
-
- def _append_exchange(self, prompt: str, output: str):
- """Appends a conversational exchange between user and an OpenAI Chat
- model to a transcript that will be included in future exchanges."""
- self.transcript.append(HumanMessage(prompt))
- self.transcript.append(AIMessage(output))
-
- def _decompose_model_id(self, model_id: str):
- """Breaks down a model ID into a two-tuple (provider_id, local_model_id). Returns (None, None) if indeterminate."""
- # custom_model_registry maps keys to either a model name (a string) or an LLMChain.
- # If this is an alias to another model, expand the full name of the model.
- if model_id in self.custom_model_registry and isinstance(
- self.custom_model_registry[model_id], str
- ):
- model_id = self.custom_model_registry[model_id]
-
- return decompose_model_id(model_id, self.providers)
-
- def _get_provider(self, provider_id: Optional[str]) -> BaseProvider:
- """Returns the model provider ID and class for a model ID. Returns None if indeterminate."""
- if provider_id is None or provider_id not in self.providers:
- return None
-
- return self.providers[provider_id]
-
- def display_output(self, output, display_format, md):
+ TODO: What is the shared type that all display objects implement? We
+ implement `_repr_mimebundle_()`, but that doesn't seem to be implemented on
+ all display objects. So the return type is `Any` for now.
+ """
# build output display
DisplayClass = DISPLAYS_BY_FORMAT[display_format]
@@ -484,7 +355,7 @@ def display_output(self, output, display_format, md):
output = re.sub(r"\n```$", "", output)
self.shell.set_next_input(output, replace=False)
return HTML(
- "AI generated code inserted below ⬇️", metadata=md
+ "AI generated code inserted below ⬇️", metadata=metadata
)
if DisplayClass is None:
@@ -492,194 +363,176 @@ def display_output(self, output, display_format, md):
if display_format == "json":
# JSON display expects a dict, not a JSON string
output = json.loads(output)
- output_display = DisplayClass(output, metadata=md)
+ output_display = DisplayClass(output, metadata=metadata)
# finally, display output display
return output_display
- def handle_help(self, _: HelpArgs):
+ def _append_exchange(self, prompt: str, output: str):
+ """
+ Appends an exchange between a user and a language model to
+ `self.transcript`. This transcript will be included in future `%ai`
+ calls to preserve conversation history.
+ """
+ self.transcript.append({"role": "user", "content": prompt})
+ self.transcript.append({"role": "assistant", "content": output})
+ # Keep only the most recent `self.max_history * 2` messages
+ max_len = self.max_history * 2
+ if len(self.transcript) > max_len:
+ self.transcript = self.transcript[-max_len:]
+
+ def handle_help(self, _: HelpArgs) -> None:
+ """
+ Handles `%ai help`. Prints a help message via `click.echo()`.
+ """
# The line parser's help function prints both cell and line help
- with click.Context(line_magic_parser, info_name="%ai") as ctx:
+ with click.Context(line_magic_parser, info_name=r"%ai") as ctx:
click.echo(line_magic_parser.get_help(ctx))
- def handle_list(self, args: ListArgs):
- return TextOrMarkdown(
- self._ai_list_command_text(args.provider_id),
- self._ai_list_command_markdown(args.provider_id),
- )
-
- def handle_version(self, args: VersionArgs):
- return __version__
-
- def handle_reset(self, args: ResetArgs):
- self.transcript = []
+ def handle_dealias(self, args: DeleteArgs) -> TextOrMarkdown:
+ """
+ Handles `%ai dealias`. Deletes a model alias.
+ """
- def run_ai_cell(self, args: CellArgs, prompt: str):
- provider_id, local_model_id = self._decompose_model_id(args.model_id)
-
- # If this is a custom chain, send the message to the custom chain.
- if args.model_id in self.custom_model_registry and isinstance(
- self.custom_model_registry[args.model_id], LLMChain
- ):
- # Get the output, either as raw text or as the contents of the 'text' key of a dict
- invoke_output = self.custom_model_registry[args.model_id].invoke(prompt)
- if isinstance(invoke_output, dict):
- invoke_output = invoke_output.get("text")
-
- return self.display_output(
- invoke_output,
- args.format,
- {"jupyter_ai": {"custom_chain_id": args.model_id}},
+ if args.name in AI_COMMANDS:
+ raise ValueError(
+ f"Reserved command names, including {args.name}, cannot be deleted"
)
- Provider = self._get_provider(provider_id)
- if Provider is None:
- return TextOrMarkdown(
- CANNOT_DETERMINE_MODEL_TEXT.format(args.model_id)
- + "\n\n"
- + "If you were trying to run a command, run '%ai help' to see a list of commands.",
- CANNOT_DETERMINE_MODEL_MARKDOWN.format(args.model_id)
- + "\n\n"
- + "If you were trying to run a command, run `%ai help` to see a list of commands.",
- )
+ if args.name not in self.aliases:
+ raise ValueError(f"There is no alias called {args.name}")
- # validate presence of authn credentials
- auth_strategy = self.providers[provider_id].auth_strategy
- if auth_strategy:
- if auth_strategy.type == "env" and auth_strategy.name not in os.environ:
- raise OSError(
- f"Authentication environment variable {auth_strategy.name} is not set.\n"
- f"An authentication token is required to use models from the {Provider.name} provider.\n"
- f"Please specify it via `%env {auth_strategy.name}=token`. "
- ) from None
- if auth_strategy.type == "multienv":
- # Multiple environment variables must be set
- missing_vars = [
- var for var in auth_strategy.names if var not in os.environ
- ]
- raise OSError(
- f"Authentication environment variables {missing_vars} are not set.\n"
- f"Multiple authentication tokens are required to use models from the {Provider.name} provider.\n"
- f"Please specify them all via `%env` commands. "
- ) from None
-
- # configure and instantiate provider
- provider_params = {"model_id": local_model_id}
- # for SageMaker, validate that required params are specified
- if provider_id == "sagemaker-endpoint":
- if (
- args.region_name is None
- or args.request_schema is None
- or args.response_path is None
- ):
- raise ValueError(
- "When using the sagemaker-endpoint provider, you must specify all of "
- + "the --region-name, --request-schema, and --response-path options."
- )
- provider_params["region_name"] = args.region_name
- provider_params["request_schema"] = args.request_schema
- provider_params["response_path"] = args.response_path
+ del self.aliases[args.name]
+ output = f"Deleted alias `{args.name}`"
+ return TextOrMarkdown(output, output)
- model_parameters = json.loads(args.model_parameters)
+ def handle_reset(self, args: ResetArgs) -> None:
+ """
+ Handles `%ai reset`. Clears the history.
+ """
+ self.transcript = []
- provider = Provider(**provider_params, **model_parameters)
+ def handle_fix(self, args: FixArgs) -> Any:
+ """
+ Handles `%ai fix`. Meant to provide fixes for any exceptions raised in
+ the kernel while running cells.
- # Apply a prompt template.
- prompt = provider.get_prompt_template(args.format).format(prompt=prompt)
+ TODO: annotate a valid return type when we find a type that is shared by
+ all display objects.
+ """
+ no_errors_message = "There have been no errors since the kernel started."
- # interpolate user namespace into prompt
+ # Find the most recent error.
ip = self.shell
- prompt = prompt.format_map(FormatDict(ip.user_ns))
+ if "Err" not in ip.user_ns:
+ return TextOrMarkdown(no_errors_message, no_errors_message)
- context = self.transcript[-2 * self.max_history :] if self.max_history else []
- if provider.is_chat_provider:
- result = provider.generate([[*context, HumanMessage(content=prompt)]])
- else:
- # generate output from model via provider
- if context:
- inputs = "\n\n".join(
- [
- (
- f"AI: {message.content}"
- if message.type == "ai"
- else f"{message.type.title()}: {message.content}"
- )
- for message in context + [HumanMessage(content=prompt)]
- ]
- )
+ err = ip.user_ns["Err"]
+ # Start from the previous execution count
+ excount = ip.execution_count - 1
+ last_error = None
+ while excount >= 0 and last_error is None:
+ if excount in err:
+ last_error = err[excount]
else:
- inputs = prompt
- result = provider.generate([inputs])
-
- output = result.generations[0][0].text
+ excount = excount - 1
- # append exchange to transcript
- self._append_exchange(prompt, output)
+ if last_error is None:
+ return TextOrMarkdown(no_errors_message, no_errors_message)
- md = {"jupyter_ai": {"provider_id": provider_id, "model_id": local_model_id}}
+ prompt = f"Explain the following error and propose a fix:\n\n{last_error}"
+ # Set CellArgs based on FixArgs
+ values = args.model_dump()
+ values["type"] = "root"
+ cell_args = CellArgs(**values)
+ print("I will attempt to explain and fix the error. ")
- return self.display_output(output, args.format, md)
+ return self.run_ai_cell(cell_args, prompt)
- @line_cell_magic
- def ai(self, line, cell=None):
- raw_args = line.split(" ")
- default_map = {"model_id": self.default_language_model}
- if cell:
- args = cell_magic_parser(
- raw_args,
- prog_name="%%ai",
- standalone_mode=False,
- default_map={"cell_magic_parser": default_map},
- )
- else:
- args = line_magic_parser(
- raw_args,
- prog_name="%ai",
- standalone_mode=False,
- default_map={"error": default_map},
- )
+ def handle_alias(self, args: RegisterArgs) -> TextOrMarkdown:
+ """
+ Handles `%ai alias`. Adds an alias for a model ID for future calls.
+ """
+ # Existing command names are not allowed
+ if args.name in AI_COMMANDS:
+ raise ValueError(f"The name {args.name} is reserved for a command")
- if args == 0 and self.default_language_model is None:
- # this happens when `--help` is called on the root command, in which
- # case we want to exit early.
- return
+ # Store the alias
+ self.aliases[args.name] = args.target
- # If a value error occurs, don't print the full stacktrace
- try:
- if args.type == "error":
- return self.handle_error(args)
- if args.type == "help":
- return self.handle_help(args)
- if args.type == "list":
- return self.handle_list(args)
- if args.type == "register":
- return self.handle_register(args)
- if args.type == "delete":
- return self.handle_delete(args)
- if args.type == "update":
- return self.handle_update(args)
- if args.type == "version":
- return self.handle_version(args)
- if args.type == "reset":
- return self.handle_reset(args)
- except ValueError as e:
- print(e, file=sys.stderr)
- return
+ output = f"Registered new alias `{args.name}`"
+ return TextOrMarkdown(output, output)
- # hint to the IDE that this object must be of type `RootArgs`
- args: CellArgs = args
+ def handle_version(self, args: VersionArgs) -> str:
+ """
+ Handles `%ai version`. Returns the current version of
+ `jupyter_ai_magics`.
+ """
+ return __version__
- if not cell:
- raise CellMagicError(
- """[0.8+]: To invoke a language model, you must use the `%%ai`
- cell magic. The `%ai` line magic is only for use with
- subcommands."""
- )
+ def handle_list(self, args: ListArgs):
+ """
+ Handles `%ai list`.
+ - `%ai list` shows all providers by default and asks the user to run `%ai list <provider-id>`.
+ - `%ai list <provider-id>` shows all models available from one provider. The output notes that the list is not comprehensive and references the upstream LiteLLM docs.
+ - `%ai list all` lists all models.
+ """
+ # Get list of available models from litellm
+ models = CHAT_MODELS
+
+ # If provider_id is None, only return provider IDs
+ if getattr(args, 'provider_id', None) is None:
+ # Extract unique provider IDs from model IDs
+ provider_ids = set()
+ for model in models:
+ if '/' in model:
+ provider_ids.add(model.split('/')[0])
+
+ # Format output for both text and markdown
+ text_output = "Available providers\n\n (Run `%ai list ` to see models for a specific provider)\n\n"
+ markdown_output = "## Available providers\n\n (Run `%ai list ` to see models for a specific provider)\n\n"
+
+ for provider_id in sorted(provider_ids):
+ text_output += f"* {provider_id}\n"
+ markdown_output += f"* `{provider_id}`\n"
+
+ return TextOrMarkdown(text_output, markdown_output)
+
+        elif getattr(args, "provider_id", None) == "all":
+            # Show every model, plus any custom aliases
+            text_output = "All available models\n\n (The list is not comprehensive; a full list of models is available at https://docs.litellm.ai/docs/providers)\n\n"
+            markdown_output = "## All available models\n\n (The list is not comprehensive; a full list of models is available at https://docs.litellm.ai/docs/providers)\n\n"
+
+ for model in models:
+ text_output += f"* {model}\n"
+ markdown_output += f"* `{model}`\n"
+
+ # Also list any custom aliases
+ if len(self.aliases) > 0:
+ text_output += "\nAliases:\n"
+ markdown_output += "\n### Aliases\n\n"
+ for alias, target in self.aliases.items():
+ text_output += f"* {alias} -> {target}\n"
+ markdown_output += f"* `{alias}` -> `{target}`\n"
+
+ return TextOrMarkdown(text_output, markdown_output)
+
+ else:
+ # If a specific provider_id is given, filter models by that provider
+ provider_id = args.provider_id
+ filtered_models = [m for m in models if m.startswith(provider_id + "/")]
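+            # Matching is by "<provider_id>/" prefix; bare model IDs without a prefix will not match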
+
+ if not filtered_models:
+ return TextOrMarkdown(
+ f"No models found for provider '{provider_id}'.",
+ f"No models found for provider `{provider_id}`.",
+ )
- prompt = cell.strip()
+            text_output = f"Available models for provider '{provider_id}'\n\n (The list is not comprehensive; a full list of models is available at https://docs.litellm.ai/docs/providers/{provider_id})\n\n"
+            markdown_output = f"## Available models for provider `{provider_id}`\n\n (The list is not comprehensive; a full list of models is available at https://docs.litellm.ai/docs/providers/{provider_id})\n\n"
- # interpolate user namespace into prompt
- ip = self.shell
- prompt = prompt.format_map(FormatDict(ip.user_ns))
+ for model in filtered_models:
+ text_output += f"* {model}\n"
+ markdown_output += f"* `{model}`\n"
- return self.run_ai_cell(args, prompt)
+ return TextOrMarkdown(text_output, markdown_output)
\ No newline at end of file
diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/old_help_strings.txt b/packages/jupyter-ai-magics/jupyter_ai_magics/old_help_strings.txt
new file mode 100644
index 000000000..795a49a67
--- /dev/null
+++ b/packages/jupyter-ai-magics/jupyter_ai_magics/old_help_strings.txt
@@ -0,0 +1,20 @@
+Bedrock:
+
+"- For Cross-Region Inference use the appropriate `Inference profile ID` (Model ID with a region prefix, e.g., `us.meta.llama3-2-11b-instruct-v1:0`). See the [inference profiles documentation](https://docs.aws.amazon.com/bedrock/latest/userguide/inference-profiles-support.html). \n"
+"- For custom/provisioned models, specify the model ARN (Amazon Resource Name) as the model ID. For more information, see the [Amazon Bedrock model IDs documentation](https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html).\n\n"
+"The model provider must also be specified below. This is the provider of your foundation model *in lowercase*, e.g., `amazon`, `anthropic`, `cohere`, `meta`, or `mistral`."
+
+SageMaker Endpoints:
+
+"Specify an endpoint name as the model ID. "
+"In addition, you must specify a region name, request schema, and response path. "
+"For more information, see the documentation about [SageMaker endpoints deployment](https://docs.aws.amazon.com/sagemaker/latest/dg/realtime-endpoints-deploy-models.html) "
+"and about [using magic commands with SageMaker endpoints](https://jupyter-ai.readthedocs.io/en/latest/users/index.html#using-magic-commands-with-sagemaker-endpoints)."
+
+Vertex AI:
+
+"To use Vertex AI Generative AI you must have the langchain-google-vertexai Python package installed and either:\n\n"
+"- Have credentials configured for your environment (gcloud, workload identity, etc...)\n"
+"- Store the path to a service account JSON file as the GOOGLE_APPLICATION_CREDENTIALS environment variable\n\n"
+"This codebase uses the google.auth library which first looks for the application credentials variable mentioned above, and then looks for system-level auth. "
+"For more information, see the [Vertex AI authentication documentation](https://python.langchain.com/docs/integrations/llms/google_vertex_ai_palm/)."
diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/old_prompt_templates.txt b/packages/jupyter-ai-magics/jupyter_ai_magics/old_prompt_templates.txt
new file mode 100644
index 000000000..8ed599766
--- /dev/null
+++ b/packages/jupyter-ai-magics/jupyter_ai_magics/old_prompt_templates.txt
@@ -0,0 +1,31 @@
+# for magics
+
+model_kwargs["prompt_templates"] = {
+ "code": PromptTemplate.from_template(
+ "{prompt}\n\nProduce output as source code only, "
+ "with no text or explanation before or after it."
+ ),
+ "html": PromptTemplate.from_template(
+ "{prompt}\n\nProduce output in HTML format only, "
+ "with no markup before or afterward."
+ ),
+ "image": PromptTemplate.from_template(
+ "{prompt}\n\nProduce output as an image only, "
+ "with no text before or after it."
+ ),
+ "markdown": PromptTemplate.from_template(
+ "{prompt}\n\nProduce output in markdown format only."
+ ),
+ "md": PromptTemplate.from_template(
+ "{prompt}\n\nProduce output in markdown format only."
+ ),
+ "math": PromptTemplate.from_template(
+ "{prompt}\n\nProduce output in LaTeX format only, "
+ "with $$ at the beginning and end."
+ ),
+ "json": PromptTemplate.from_template(
+ "{prompt}\n\nProduce output in JSON format only, "
+ "with nothing before or after it."
+ ),
+ "text": PromptTemplate.from_template("{prompt}"), # No customization
+}
diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/parsers.py b/packages/jupyter-ai-magics/jupyter_ai_magics/parsers.py
index de99fc8bd..0c0aeb15f 100644
--- a/packages/jupyter-ai-magics/jupyter_ai_magics/parsers.py
+++ b/packages/jupyter-ai-magics/jupyter_ai_magics/parsers.py
@@ -54,8 +54,8 @@ class CellArgs(BaseModel):
# Should match CellArgs
-class ErrorArgs(BaseModel):
- type: Literal["error"] = "error"
+class FixArgs(BaseModel):
+ type: Literal["fix"] = "fix"
model_id: str
format: FORMAT_CHOICES_TYPE
model_parameters: Optional[str] = None
@@ -79,13 +79,13 @@ class ListArgs(BaseModel):
class RegisterArgs(BaseModel):
- type: Literal["register"] = "register"
+ type: Literal["alias"] = "alias"
name: str
target: str
class DeleteArgs(BaseModel):
- type: Literal["delete"] = "delete"
+ type: Literal["dealias"] = "dealias"
name: str
@@ -182,7 +182,7 @@ def line_magic_parser():
"""
-@line_magic_parser.command(name="error")
+@line_magic_parser.command(name="fix")
@click.argument("model_id", required=False)
@click.option(
"-f",
@@ -219,14 +219,14 @@ def line_magic_parser():
default="{}",
)
@click.pass_context
-def error_subparser(context: click.Context, **kwargs):
+def fix_subparser(context: click.Context, **kwargs):
"""
- Explains the most recent error. Takes the same options (except -r) as
+ Explains and fixes the most recent error. Takes the same options (except -r) as
the basic `%%ai` command.
"""
if not kwargs["model_id"] and context.default_map:
- kwargs["model_id"] = context.default_map["error_subparser"]["model_id"]
- return ErrorArgs(**kwargs)
+ kwargs["model_id"] = context.default_map["fix_subparser"]["model_id"]
+ return FixArgs(**kwargs)
@line_magic_parser.command(name="version")
@@ -248,13 +248,16 @@ def help_subparser():
)
@click.argument("provider_id", required=False)
def list_subparser(**kwargs):
- """List language models, optionally scoped to PROVIDER_ID."""
+    """List language models, optionally scoped to PROVIDER_ID.
+
+    If no PROVIDER_ID is given, all providers are listed.
+    If PROVIDER_ID is given, only models from that provider are listed.
+    If PROVIDER_ID is 'all', models from all providers are listed."""
return ListArgs(**kwargs)
@line_magic_parser.command(
- name="register",
- short_help="Register a new alias. See `%ai register --help` for options.",
+ name="alias",
+ short_help="Register a new alias. See `%ai alias --help` for options.",
)
@click.argument("name")
@click.argument("target")
@@ -264,7 +267,7 @@ def register_subparser(**kwargs):
@line_magic_parser.command(
- name="delete", short_help="Delete an alias. See `%ai delete --help` for options."
+ name="dealias", short_help="Delete an alias. See `%ai dealias --help` for options."
)
@click.argument("name")
def register_subparser(**kwargs):
@@ -272,10 +275,6 @@ def register_subparser(**kwargs):
return DeleteArgs(**kwargs)
-@line_magic_parser.command(
- name="update",
- short_help="Update the target of an alias. See `%ai update --help` for options.",
-)
@click.argument("name")
@click.argument("target")
def register_subparser(**kwargs):
diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/anthropic.py b/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/anthropic.py
deleted file mode 100644
index ad0f18c13..000000000
--- a/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/anthropic.py
+++ /dev/null
@@ -1,38 +0,0 @@
-from langchain_anthropic import ChatAnthropic
-
-from ..base_provider import BaseProvider, EnvAuthStrategy
-
-
-class ChatAnthropicProvider(
- BaseProvider, ChatAnthropic
-): # https://docs.anthropic.com/en/docs/about-claude/models
- id = "anthropic-chat"
- name = "Anthropic"
- models = [
- "claude-2.0",
- "claude-2.1",
- "claude-3-opus-20240229",
- "claude-3-sonnet-20240229",
- "claude-3-haiku-20240307",
- "claude-3-5-haiku-20241022",
- "claude-3-5-sonnet-20240620",
- "claude-3-5-sonnet-20241022",
- ]
- model_id_key = "model"
- pypi_package_deps = ["anthropic"]
- auth_strategy = EnvAuthStrategy(name="ANTHROPIC_API_KEY")
-
- @property
- def allows_concurrency(self):
- return False
-
- @classmethod
- def is_api_key_exc(cls, e: Exception):
- """
- Determine if the exception is an Anthropic API key error.
- """
- import anthropic
-
- if isinstance(e, anthropic.AuthenticationError):
- return e.status_code == 401 and "Invalid API Key" in str(e)
- return False
diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/aws.py b/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/aws.py
deleted file mode 100644
index cc6552365..000000000
--- a/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/aws.py
+++ /dev/null
@@ -1,218 +0,0 @@
-import copy
-import json
-from collections.abc import Coroutine
-from typing import Any
-
-from jsonpath_ng import parse
-from langchain_aws import BedrockEmbeddings, BedrockLLM, ChatBedrock, SagemakerEndpoint
-from langchain_aws.llms.sagemaker_endpoint import LLMContentHandler
-from langchain_core.outputs import LLMResult
-
-from ..base_provider import AwsAuthStrategy, BaseProvider, MultilineTextField, TextField
-from ..embedding_providers import BaseEmbeddingsProvider
-
-
-# See model ID list here: https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html
-class BedrockProvider(BaseProvider, BedrockLLM):
- id = "bedrock"
- name = "Amazon Bedrock"
- models = [
- "amazon.titan-text-express-v1",
- "amazon.titan-text-lite-v1",
- "amazon.titan-text-premier-v1:0",
- "ai21.j2-ultra-v1",
- "ai21.j2-mid-v1",
- "ai21.jamba-instruct-v1:0",
- "cohere.command-light-text-v14",
- "cohere.command-text-v14",
- "cohere.command-r-v1:0",
- "cohere.command-r-plus-v1:0",
- "meta.llama2-13b-chat-v1",
- "meta.llama2-70b-chat-v1",
- "meta.llama3-8b-instruct-v1:0",
- "meta.llama3-70b-instruct-v1:0",
- "meta.llama3-1-8b-instruct-v1:0",
- "meta.llama3-1-70b-instruct-v1:0",
- "meta.llama3-1-405b-instruct-v1:0",
- "mistral.mistral-7b-instruct-v0:2",
- "mistral.mixtral-8x7b-instruct-v0:1",
- "mistral.mistral-large-2402-v1:0",
- "mistral.mistral-large-2407-v1:0",
- ]
- model_id_key = "model_id"
- pypi_package_deps = ["langchain-aws"]
- auth_strategy = AwsAuthStrategy()
- fields = [
- TextField(
- key="credentials_profile_name",
- label="AWS profile (optional)",
- format="text",
- ),
- TextField(key="region_name", label="Region name (optional)", format="text"),
- ]
-
- async def _acall(self, *args, **kwargs) -> Coroutine[Any, Any, str]:
- return await self._call_in_executor(*args, **kwargs)
-
-
-# See model ID list here: https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html
-class BedrockChatProvider(BaseProvider, ChatBedrock):
- id = "bedrock-chat"
- name = "Amazon Bedrock Chat"
- models = [
- "amazon.titan-text-express-v1",
- "amazon.titan-text-lite-v1",
- "amazon.titan-text-premier-v1:0",
- "anthropic.claude-v2",
- "anthropic.claude-v2:1",
- "anthropic.claude-instant-v1",
- "anthropic.claude-3-sonnet-20240229-v1:0",
- "anthropic.claude-3-haiku-20240307-v1:0",
- "anthropic.claude-3-opus-20240229-v1:0",
- "anthropic.claude-3-5-haiku-20241022-v1:0",
- "anthropic.claude-3-5-sonnet-20240620-v1:0",
- "anthropic.claude-3-5-sonnet-20241022-v2:0",
- "meta.llama2-13b-chat-v1",
- "meta.llama2-70b-chat-v1",
- "meta.llama3-8b-instruct-v1:0",
- "meta.llama3-70b-instruct-v1:0",
- "meta.llama3-1-8b-instruct-v1:0",
- "meta.llama3-1-70b-instruct-v1:0",
- "meta.llama3-1-405b-instruct-v1:0",
- "mistral.mistral-7b-instruct-v0:2",
- "mistral.mixtral-8x7b-instruct-v0:1",
- "mistral.mistral-large-2402-v1:0",
- "mistral.mistral-large-2407-v1:0",
- ]
- model_id_key = "model_id"
- pypi_package_deps = ["langchain-aws"]
- auth_strategy = AwsAuthStrategy()
- fields = [
- TextField(
- key="credentials_profile_name",
- label="AWS profile (optional)",
- format="text",
- ),
- TextField(key="region_name", label="Region name (optional)", format="text"),
- ]
-
- async def _acall(self, *args, **kwargs) -> Coroutine[Any, Any, str]:
- return await self._call_in_executor(*args, **kwargs)
-
- async def _agenerate(self, *args, **kwargs) -> Coroutine[Any, Any, LLMResult]:
- return await self._generate_in_executor(*args, **kwargs)
-
- @property
- def allows_concurrency(self):
- return not "anthropic" in self.model_id
-
-
-class BedrockCustomProvider(BaseProvider, ChatBedrock):
- id = "bedrock-custom"
- name = "Amazon Bedrock (custom/provisioned)"
- models = ["*"]
- model_id_key = "model_id"
- model_id_label = "Model ID"
- pypi_package_deps = ["langchain-aws"]
- auth_strategy = AwsAuthStrategy()
- fields = [
- TextField(key="provider", label="Provider (required)", format="text"),
- TextField(key="region_name", label="Region name (optional)", format="text"),
- TextField(
- key="credentials_profile_name",
- label="AWS profile (optional)",
- format="text",
- ),
- ]
- help = (
- "- For Cross-Region Inference use the appropriate `Inference profile ID` (Model ID with a region prefix, e.g., `us.meta.llama3-2-11b-instruct-v1:0`). See the [inference profiles documentation](https://docs.aws.amazon.com/bedrock/latest/userguide/inference-profiles-support.html). \n"
- "- For custom/provisioned models, specify the model ARN (Amazon Resource Name) as the model ID. For more information, see the [Amazon Bedrock model IDs documentation](https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html).\n\n"
- "The model provider must also be specified below. This is the provider of your foundation model *in lowercase*, e.g., `amazon`, `anthropic`, `cohere`, `meta`, or `mistral`."
- )
- registry = True
-
-
-# See model ID list here: https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html
-class BedrockEmbeddingsProvider(BaseEmbeddingsProvider, BedrockEmbeddings):
- id = "bedrock"
- name = "Bedrock"
- models = [
- "amazon.titan-embed-text-v1",
- "amazon.titan-embed-text-v2:0",
- "cohere.embed-english-v3",
- "cohere.embed-multilingual-v3",
- ]
- model_id_key = "model_id"
- pypi_package_deps = ["langchain-aws"]
- auth_strategy = AwsAuthStrategy()
-
-
-class JsonContentHandler(LLMContentHandler):
- content_type = "application/json"
- accepts = "application/json"
-
- def __init__(self, request_schema, response_path):
- self.request_schema = json.loads(request_schema)
- self.response_path = response_path
- self.response_parser = parse(response_path)
-
- def replace_values(self, old_val, new_val, d: dict[str, Any]):
- """Replaces values of a dictionary recursively."""
- for key, val in d.items():
- if val == old_val:
- d[key] = new_val
- if isinstance(val, dict):
- self.replace_values(old_val, new_val, val)
-
- return d
-
- def transform_input(self, prompt: str, model_kwargs: dict) -> bytes:
- request_obj = copy.deepcopy(self.request_schema)
- self.replace_values("", prompt, request_obj)
- request = json.dumps(request_obj).encode("utf-8")
- return request
-
- def transform_output(self, output: bytes) -> str:
- response_json = json.loads(output.read().decode("utf-8"))
- matches = self.response_parser.find(response_json)
- return matches[0].value
-
-
-class SmEndpointProvider(BaseProvider, SagemakerEndpoint):
- id = "sagemaker-endpoint"
- name = "SageMaker endpoint"
- models = ["*"]
- model_id_key = "endpoint_name"
- model_id_label = "Endpoint name"
- # This all needs to be on one line of markdown, for use in a table
- help = (
- "Specify an endpoint name as the model ID. "
- "In addition, you must specify a region name, request schema, and response path. "
- "For more information, see the documentation about [SageMaker endpoints deployment](https://docs.aws.amazon.com/sagemaker/latest/dg/realtime-endpoints-deploy-models.html) "
- "and about [using magic commands with SageMaker endpoints](https://jupyter-ai.readthedocs.io/en/latest/users/index.html#using-magic-commands-with-sagemaker-endpoints)."
- )
-
- pypi_package_deps = ["langchain-aws"]
- auth_strategy = AwsAuthStrategy()
- registry = True
- fields = [
- TextField(key="region_name", label="Region name (required)", format="text"),
- MultilineTextField(
- key="request_schema", label="Request schema (required)", format="json"
- ),
- TextField(
- key="response_path", label="Response path (required)", format="jsonpath"
- ),
- ]
-
- def __init__(self, *args, **kwargs):
- request_schema = kwargs.pop("request_schema")
- response_path = kwargs.pop("response_path")
- content_handler = JsonContentHandler(
- request_schema=request_schema, response_path=response_path
- )
-
- super().__init__(*args, **kwargs, content_handler=content_handler)
-
- async def _acall(self, *args, **kwargs) -> Coroutine[Any, Any, str]:
- return await self._call_in_executor(*args, **kwargs)
diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/cohere.py b/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/cohere.py
deleted file mode 100644
index 9ebc287a2..000000000
--- a/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/cohere.py
+++ /dev/null
@@ -1,40 +0,0 @@
-from langchain_cohere import ChatCohere, CohereEmbeddings
-
-from ..base_provider import BaseProvider, EnvAuthStrategy
-from ..embedding_providers import BaseEmbeddingsProvider
-
-
-class CohereProvider(BaseProvider, ChatCohere):
- id = "cohere"
- name = "Cohere"
- # https://docs.cohere.com/docs/models
- # note: This provider uses the Chat endpoint instead of the Generate
- # endpoint, which is now deprecated.
- models = [
- "command",
- "command-nightly",
- "command-light",
- "command-light-nightly",
- "command-r-plus",
- "command-r",
- ]
- model_id_key = "model"
- pypi_package_deps = ["langchain_cohere"]
- auth_strategy = EnvAuthStrategy(name="COHERE_API_KEY")
-
-
-class CohereEmbeddingsProvider(BaseEmbeddingsProvider, CohereEmbeddings):
- id = "cohere"
- name = "Cohere"
- models = [
- "embed-english-v2.0",
- "embed-english-light-v2.0",
- "embed-multilingual-v2.0",
- "embed-english-v3.0",
- "embed-english-light-v3.0",
- "embed-multilingual-v3.0",
- "embed-multilingual-light-v3.0",
- ]
- model_id_key = "model"
- pypi_package_deps = ["langchain_cohere"]
- auth_strategy = EnvAuthStrategy(name="COHERE_API_KEY")
diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/gemini.py b/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/gemini.py
deleted file mode 100644
index 0a5a99139..000000000
--- a/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/gemini.py
+++ /dev/null
@@ -1,18 +0,0 @@
-from jupyter_ai_magics.base_provider import BaseProvider, EnvAuthStrategy
-from langchain_google_genai import GoogleGenerativeAI
-
-
-# See list of model ids here: https://ai.google.dev/gemini-api/docs/models/gemini
-class GeminiProvider(BaseProvider, GoogleGenerativeAI):
- id = "gemini"
- name = "Gemini"
- models = [
- "gemini-2.5-pro",
- "gemini-2.5-flash",
- "gemini-2.0-flash-lite",
- "gemini-1.5-pro",
- "gemini-1.5-flash",
- ]
- model_id_key = "model"
- auth_strategy = EnvAuthStrategy(name="GOOGLE_API_KEY")
- pypi_package_deps = ["langchain-google-genai"]
diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/mistralai.py b/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/mistralai.py
deleted file mode 100644
index cfc316477..000000000
--- a/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/mistralai.py
+++ /dev/null
@@ -1,32 +0,0 @@
-from jupyter_ai_magics.base_provider import BaseProvider, EnvAuthStrategy
-from langchain_mistralai import ChatMistralAI, MistralAIEmbeddings
-
-from ..embedding_providers import BaseEmbeddingsProvider
-
-
-class MistralAIProvider(BaseProvider, ChatMistralAI):
- id = "mistralai"
- name = "MistralAI"
- models = [
- "open-mistral-7b",
- "open-mixtral-8x7b",
- "open-mixtral-8x22b",
- "mistral-small-latest",
- "mistral-medium-latest",
- "mistral-large-latest",
- "codestral-latest",
- ]
- model_id_key = "model"
- auth_strategy = EnvAuthStrategy(name="MISTRAL_API_KEY")
- pypi_package_deps = ["langchain-mistralai"]
-
-
-class MistralAIEmbeddingsProvider(BaseEmbeddingsProvider, MistralAIEmbeddings):
- id = "mistralai"
- name = "MistralAI"
- models = [
- "mistral-embed",
- ]
- model_id_key = "model"
- pypi_package_deps = ["langchain-mistralai"]
- auth_strategy = EnvAuthStrategy(name="MISTRAL_API_KEY")
diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/nvidia.py b/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/nvidia.py
deleted file mode 100644
index 46e92fa06..000000000
--- a/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/nvidia.py
+++ /dev/null
@@ -1,23 +0,0 @@
-from jupyter_ai_magics.base_provider import BaseProvider, EnvAuthStrategy
-from langchain_nvidia_ai_endpoints import ChatNVIDIA
-
-
-class ChatNVIDIAProvider(BaseProvider, ChatNVIDIA):
- id = "nvidia-chat"
- name = "NVIDIA"
- models = [
- "playground_llama2_70b",
- "playground_nemotron_steerlm_8b",
- "playground_mistral_7b",
- "playground_nv_llama2_rlhf_70b",
- "playground_llama2_13b",
- "playground_steerlm_llama_70b",
- "playground_llama2_code_13b",
- "playground_yi_34b",
- "playground_mixtral_8x7b",
- "playground_neva_22b",
- "playground_llama2_code_34b",
- ]
- model_id_key = "model"
- auth_strategy = EnvAuthStrategy(name="NVIDIA_API_KEY")
- pypi_package_deps = ["langchain_nvidia_ai_endpoints"]
diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/ollama.py b/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/ollama.py
deleted file mode 100644
index 58edd9de3..000000000
--- a/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/ollama.py
+++ /dev/null
@@ -1,35 +0,0 @@
-from langchain_ollama import ChatOllama, OllamaEmbeddings
-
-from ..base_provider import BaseProvider, TextField
-from ..embedding_providers import BaseEmbeddingsProvider
-
-
-class OllamaProvider(BaseProvider, ChatOllama):
- id = "ollama"
- name = "Ollama"
- model_id_key = "model"
- help = (
- "See [https://www.ollama.com/library](https://www.ollama.com/library) for a list of models. "
- "Pass a model's name; for example, `deepseek-coder-v2`."
- )
- models = ["*"]
- registry = True
- fields = [
- TextField(key="base_url", label="Base API URL (optional)", format="text"),
- ]
-
-
-class OllamaEmbeddingsProvider(BaseEmbeddingsProvider, OllamaEmbeddings):
- id = "ollama"
- name = "Ollama"
- # source: https://ollama.com/library
- model_id_key = "model"
- help = (
- "See [https://ollama.com/search?c=embedding](https://ollama.com/search?c=embedding) for a list of models. "
- "Pass an embedding model's name; for example, `mxbai-embed-large`."
- )
- models = ["*"]
- registry = True
- fields = [
- TextField(key="base_url", label="Base API URL (optional)", format="text"),
- ]
diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/openai.py b/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/openai.py
deleted file mode 100644
index db0113c45..000000000
--- a/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/openai.py
+++ /dev/null
@@ -1,170 +0,0 @@
-from langchain_openai import (
- AzureChatOpenAI,
- AzureOpenAIEmbeddings,
- ChatOpenAI,
- OpenAI,
- OpenAIEmbeddings,
-)
-
-from ..base_provider import BaseProvider, EnvAuthStrategy, TextField
-from ..embedding_providers import BaseEmbeddingsProvider
-
-
-class OpenAIProvider(BaseProvider, OpenAI):
- id = "openai"
- name = "OpenAI"
- models = ["babbage-002", "davinci-002", "gpt-3.5-turbo-instruct"]
- model_id_key = "model_name"
- pypi_package_deps = ["langchain_openai"]
- auth_strategy = EnvAuthStrategy(name="OPENAI_API_KEY")
-
- @classmethod
- def is_api_key_exc(cls, e: Exception):
- """
- Determine if the exception is an OpenAI API key error.
- """
- import openai
-
- if isinstance(e, openai.AuthenticationError):
- error_details = e.json_body.get("error", {})
- return error_details.get("code") == "invalid_api_key"
- return False
-
-
-# https://platform.openai.com/docs/models/
-class ChatOpenAIProvider(BaseProvider, ChatOpenAI):
- id = "openai-chat"
- name = "OpenAI"
- models = [
- "gpt-3.5-turbo",
- "gpt-3.5-turbo-1106",
- "gpt-4",
- "gpt-4-turbo",
- "gpt-4-turbo-preview",
- "gpt-4-0613",
- "gpt-4-0125-preview",
- "gpt-4-1106-preview",
- "gpt-4o",
- "gpt-4o-2024-11-20",
- "gpt-4o-mini",
- "chatgpt-4o-latest",
- "gpt-4.1",
- "gpt-4.1-mini",
- "gpt-4.1-nano",
- "o1",
- "o3-mini",
- "o4-mini",
- ]
- model_id_key = "model_name"
- pypi_package_deps = ["langchain_openai"]
- auth_strategy = EnvAuthStrategy(name="OPENAI_API_KEY")
-
- fields = [
- TextField(
- key="openai_api_base", label="Base API URL (optional)", format="text"
- ),
- TextField(
- key="openai_organization", label="Organization (optional)", format="text"
- ),
- TextField(key="openai_proxy", label="Proxy (optional)", format="text"),
- ]
-
- @classmethod
- def is_api_key_exc(cls, e: Exception):
- """
- Determine if the exception is an OpenAI API key error.
- """
- import openai
-
- if isinstance(e, openai.AuthenticationError):
- error_details = e.json_body.get("error", {})
- return error_details.get("code") == "invalid_api_key"
- return False
-
-
-class ChatOpenAICustomProvider(BaseProvider, ChatOpenAI):
- id = "openai-chat-custom"
- name = "OpenAI (general interface)"
- models = ["*"]
- model_id_key = "model_name"
- model_id_label = "Model ID"
- pypi_package_deps = ["langchain_openai"]
- auth_strategy = EnvAuthStrategy(name="OPENAI_API_KEY")
- fields = [
- TextField(
- key="openai_api_base", label="Base API URL (optional)", format="text"
- ),
- TextField(
- key="openai_organization", label="Organization (optional)", format="text"
- ),
- TextField(key="openai_proxy", label="Proxy (optional)", format="text"),
- ]
- help = "Supports non-OpenAI models that use the OpenAI API interface. Replace the OpenAI API key with the API key for the chosen provider."
- registry = True
-
-
-class AzureChatOpenAIProvider(BaseProvider, AzureChatOpenAI):
- id = "azure-chat-openai"
- name = "Azure OpenAI"
- models = ["*"]
- model_id_key = "azure_deployment"
- model_id_label = "Deployment name"
- pypi_package_deps = ["langchain_openai"]
- # Confusingly, langchain uses both OPENAI_API_KEY and AZURE_OPENAI_API_KEY for azure
- # https://github.com/langchain-ai/langchain/blob/f2579096993ae460516a0aae1d3e09f3eb5c1772/libs/partners/openai/langchain_openai/llms/azure.py#L85
- auth_strategy = EnvAuthStrategy(
- name="AZURE_OPENAI_API_KEY", keyword_param="openai_api_key"
- )
- registry = True
-
- fields = [
- TextField(key="azure_endpoint", label="Base API URL (required)", format="text"),
- TextField(key="api_version", label="API version (required)", format="text"),
- ]
-
-
-class OpenAIEmbeddingsProvider(BaseEmbeddingsProvider, OpenAIEmbeddings):
- id = "openai"
- name = "OpenAI"
- models = [
- "text-embedding-ada-002",
- "text-embedding-3-small",
- "text-embedding-3-large",
- ]
- model_id_key = "model"
- pypi_package_deps = ["langchain_openai"]
- auth_strategy = EnvAuthStrategy(name="OPENAI_API_KEY")
-
-
-class OpenAIEmbeddingsCustomProvider(BaseEmbeddingsProvider, OpenAIEmbeddings):
- id = "openai-custom"
- name = "OpenAI (general interface)"
- models = ["*"]
- model_id_key = "model"
- pypi_package_deps = ["langchain_openai"]
- auth_strategy = EnvAuthStrategy(name="OPENAI_API_KEY")
- registry = True
- fields = [
- TextField(
- key="openai_api_base", label="Base API URL (optional)", format="text"
- ),
- ]
- help = "Supports non-OpenAI embedding models that use the OpenAI API interface. Replace the OpenAI API key with the API key for the chosen provider."
-
-
-class AzureOpenAIEmbeddingsProvider(BaseEmbeddingsProvider, AzureOpenAIEmbeddings):
- id = "azure"
- name = "Azure OpenAI"
- models = [
- "text-embedding-ada-002",
- "text-embedding-3-small",
- "text-embedding-3-large",
- ]
- model_id_key = "azure_deployment"
- pypi_package_deps = ["langchain_openai"]
- auth_strategy = EnvAuthStrategy(
- name="AZURE_OPENAI_API_KEY", keyword_param="openai_api_key"
- )
- fields = [
- TextField(key="azure_endpoint", label="Base API URL (optional)", format="text"),
- ]
diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/openrouter.py b/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/openrouter.py
deleted file mode 100644
index ac07ec129..000000000
--- a/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/openrouter.py
+++ /dev/null
@@ -1,51 +0,0 @@
-from jupyter_ai_magics import BaseProvider
-from jupyter_ai_magics.base_provider import EnvAuthStrategy, TextField
-from langchain_core.utils import get_from_dict_or_env
-from langchain_openai import ChatOpenAI
-
-
-class ChatOpenRouter(ChatOpenAI):
- @property
- def lc_secrets(self) -> dict[str, str]:
- return {"openai_api_key": "OPENROUTER_API_KEY"}
-
-
-class OpenRouterProvider(BaseProvider, ChatOpenRouter):
- id = "openrouter"
- name = "OpenRouter"
- models = [
- "*"
- ] # OpenRouter supports multiple models, so we use "*" to indicate it's a registry
- model_id_key = "model_name"
- pypi_package_deps = ["langchain_openai"]
- auth_strategy = EnvAuthStrategy(name="OPENROUTER_API_KEY")
- registry = True
-
- fields = [
- TextField(
- key="openai_api_base", label="API Base URL (optional)", format="text"
- ),
- ]
-
- def __init__(self, **kwargs):
- openrouter_api_key = get_from_dict_or_env(
- kwargs, key="openrouter_api_key", env_key="OPENROUTER_API_KEY", default=None
- )
- openrouter_api_base = kwargs.pop(
- "openai_api_base", "https://openrouter.ai/api/v1"
- )
- kwargs.pop("openrouter_api_key", None)
- super().__init__(
- openai_api_key=openrouter_api_key,
- openai_api_base=openrouter_api_base,
- **kwargs,
- )
-
- @classmethod
- def is_api_key_exc(cls, e: Exception):
- import openai
-
- if isinstance(e, openai.AuthenticationError):
- error_details = e.json_body.get("error", {})
- return error_details.get("code") == "invalid_api_key"
- return False
diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/vertexai.py b/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/vertexai.py
deleted file mode 100644
index d22210a68..000000000
--- a/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/vertexai.py
+++ /dev/null
@@ -1,39 +0,0 @@
-from jupyter_ai_magics.base_provider import BaseProvider
-from langchain_google_vertexai import VertexAI, VertexAIEmbeddings
-
-
-class VertexAIProvider(BaseProvider, VertexAI):
- id = "vertexai"
- name = "Vertex AI"
- models = [
- "gemini-2.5-pro",
- "gemini-2.5-flash",
- ]
- model_id_key = "model"
- auth_strategy = None
- pypi_package_deps = ["langchain-google-vertexai"]
- help = (
- "To use Vertex AI Generative AI you must have the langchain-google-vertexai Python package installed and either:\n\n"
- "- Have credentials configured for your environment (gcloud, workload identity, etc...)\n"
- "- Store the path to a service account JSON file as the GOOGLE_APPLICATION_CREDENTIALS environment variable\n\n"
- "This codebase uses the google.auth library which first looks for the application credentials variable mentioned above, and then looks for system-level auth. "
- "For more information, see the [Vertex AI authentication documentation](https://python.langchain.com/docs/integrations/llms/google_vertex_ai_palm/)."
- )
-
-
-class VertexAIEmbeddingsProvider(BaseProvider, VertexAIEmbeddings):
- id = "vertexai"
- name = "Vertex AI"
- models = [
- "text-embedding-004",
- ]
- model_id_key = "model"
- auth_strategy = None
- pypi_package_deps = ["langchain-google-vertexai"]
- help = (
- "To use Vertex AI Generative AI you must have the langchain-google-vertexai Python package installed and either:\n\n"
- "- Have credentials configured for your environment (gcloud, workload identity, etc...)\n"
- "- Store the path to a service account JSON file as the GOOGLE_APPLICATION_CREDENTIALS environment variable\n\n"
- "This codebase uses the google.auth library which first looks for the application credentials variable mentioned above, and then looks for system-level auth. "
- "For more information, see the [Vertex AI authentication documentation](https://python.langchain.com/docs/integrations/llms/google_vertex_ai_palm/)."
- )
diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
deleted file mode 100644
index cb03951fa..000000000
--- a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
+++ /dev/null
@@ -1,247 +0,0 @@
-import base64
-import io
-import json
-from collections.abc import Coroutine
-from typing import (
- Any,
- Optional,
-)
-
-from langchain.prompts import (
- PromptTemplate,
-)
-from langchain_community.chat_models import QianfanChatEndpoint
-from langchain_community.llms import AI21, GPT4All, HuggingFaceEndpoint, Together
-
-# for backward compatibility the import group below includes
-# imports of objects which used to be defined here but were
-# only used by BaseProvider that has now moved to `.base_provider`.
-from .base_provider import (
- CHAT_DEFAULT_TEMPLATE,
- CHAT_SYSTEM_PROMPT,
- COMPLETION_DEFAULT_TEMPLATE,
- COMPLETION_SYSTEM_PROMPT,
- HUMAN_MESSAGE_TEMPLATE,
- AuthStrategy,
- AwsAuthStrategy,
- BaseProvider,
- EnvAuthStrategy,
- Field,
- IntegerField,
- MultiEnvAuthStrategy,
- MultilineTextField,
- TextField,
-)
-
-
-class AI21Provider(BaseProvider, AI21):
- id = "ai21"
- name = "AI21"
- models = [
- "j1-large",
- "j1-grande",
- "j1-jumbo",
- "j1-grande-instruct",
- "j2-large",
- "j2-grande",
- "j2-jumbo",
- "j2-grande-instruct",
- "j2-jumbo-instruct",
- ]
- model_id_key = "model"
- pypi_package_deps = ["ai21"]
- auth_strategy = EnvAuthStrategy(name="AI21_API_KEY")
-
- async def _acall(self, *args, **kwargs) -> Coroutine[Any, Any, str]:
- return await self._call_in_executor(*args, **kwargs)
-
- @classmethod
- def is_api_key_exc(cls, e: Exception):
- """
- Determine if the exception is an AI21 API key error.
- """
- if isinstance(e, ValueError):
- return "status code 401" in str(e)
- return False
-
-
-class GPT4AllProvider(BaseProvider, GPT4All):
- def __init__(self, **kwargs):
- model = kwargs.get("model_id")
- if model == "ggml-gpt4all-l13b-snoozy":
- kwargs["backend"] = "llama"
- else:
- kwargs["backend"] = "gptj"
-
- kwargs["allow_download"] = False
- n_threads = kwargs.get("n_threads", None)
- if n_threads is not None:
- kwargs["n_threads"] = max(int(n_threads), 1)
- super().__init__(**kwargs)
-
- id = "gpt4all"
- name = "GPT4All"
- models = [
- "ggml-gpt4all-j-v1.2-jazzy",
- "ggml-gpt4all-j-v1.3-groovy",
- # this one needs llama backend and has licence restriction
- "ggml-gpt4all-l13b-snoozy",
- "mistral-7b-openorca.Q4_0",
- "mistral-7b-instruct-v0.1.Q4_0",
- "gpt4all-falcon-q4_0",
- "wizardlm-13b-v1.2.Q4_0",
- "nous-hermes-llama2-13b.Q4_0",
- "gpt4all-13b-snoozy-q4_0",
- "mpt-7b-chat-merges-q4_0",
- "orca-mini-3b-gguf2-q4_0",
- "starcoder-q4_0",
- "rift-coder-v0-7b-q4_0",
- "em_german_mistral_v01.Q4_0",
- ]
- model_id_key = "model"
- pypi_package_deps = ["gpt4all"]
- auth_strategy = None
- fields = [IntegerField(key="n_threads", label="CPU thread count (optional)")]
-
- async def _acall(self, *args, **kwargs) -> Coroutine[Any, Any, str]:
- return await self._call_in_executor(*args, **kwargs)
-
- @property
- def allows_concurrency(self):
- # At present, GPT4All providers fail with concurrent messages. See #481.
- return False
-
-
-# References for using HuggingFaceEndpoint and InferenceClient:
-# https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient
-# https://github.com/langchain-ai/langchain/blob/master/libs/community/langchain_community/llms/huggingface_endpoint.py
-class HfHubProvider(BaseProvider, HuggingFaceEndpoint):
- id = "huggingface_hub"
- name = "Hugging Face Hub"
- models = ["*"]
- model_id_key = "repo_id"
- help = (
- "See [https://huggingface.co/models](https://huggingface.co/models) for a list of models. "
- "Pass a model's repository ID as the model ID; for example, `huggingface_hub:ExampleOwner/example-model`."
- )
- # ipywidgets needed to suppress tqdm warning
- # https://stackoverflow.com/questions/67998191
- # tqdm is a dependency of huggingface_hub
- pypi_package_deps = ["huggingface_hub", "ipywidgets"]
- auth_strategy = EnvAuthStrategy(name="HUGGINGFACEHUB_API_TOKEN")
- registry = True
-
- # Handle text and image outputs
- def _call(
- self, prompt: str, stop: Optional[list[str]] = None, **kwargs: Any
- ) -> str:
- """Call out to Hugging Face Hub's inference endpoint.
-
- Args:
- prompt: The prompt to pass into the model.
- stop: Optional list of stop words to use when generating.
-
- Returns:
- The string or image generated by the model.
-
- Example:
- .. code-block:: python
-
- response = hf("Tell me a joke.")
- """
- invocation_params = self._invocation_params(stop, **kwargs)
- invocation_params["stop"] = invocation_params[
- "stop_sequences"
- ] # porting 'stop_sequences' into the 'stop' argument
- response = self.client.post(
- json={"inputs": prompt, "parameters": invocation_params},
- stream=False,
- task=self.task,
- )
-
- try:
- if "generated_text" in str(response):
- # text2 text or text-generation task
- response_text = json.loads(response.decode())[0]["generated_text"]
- # Maybe the generation has stopped at one of the stop sequences:
- # then we remove this stop sequence from the end of the generated text
- for stop_seq in invocation_params["stop_sequences"]:
- if response_text[-len(stop_seq) :] == stop_seq:
- response_text = response_text[: -len(stop_seq)]
- return response_text
- else:
- # text-to-image task
- # https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_to_image.example
- # Custom code for responding to image generation responses
- image = self.client.text_to_image(prompt)
- imageFormat = image.format # Presume it's a PIL ImageFile
- mimeType = ""
- if imageFormat == "JPEG":
- mimeType = "image/jpeg"
- elif imageFormat == "PNG":
- mimeType = "image/png"
- elif imageFormat == "GIF":
- mimeType = "image/gif"
- else:
- raise ValueError(f"Unrecognized image format {imageFormat}")
- buffer = io.BytesIO()
- image.save(buffer, format=imageFormat)
- # # Encode image data to Base64 bytes, then decode bytes to str
- return (
- mimeType + ";base64," + base64.b64encode(buffer.getvalue()).decode()
- )
- except:
- raise ValueError(
- "Task not supported, only text-generation and text-to-image tasks are valid."
- )
-
- async def _acall(self, *args, **kwargs) -> Coroutine[Any, Any, str]:
- return await self._call_in_executor(*args, **kwargs)
-
-
-class TogetherAIProvider(BaseProvider, Together):
- id = "togetherai"
- name = "Together AI"
- model_id_key = "model"
- models = [
- "Austism/chronos-hermes-13b",
- "DiscoResearch/DiscoLM-mixtral-8x7b-v2",
- "EleutherAI/llemma_7b",
- "Gryphe/MythoMax-L2-13b",
- "Meta-Llama/Llama-Guard-7b",
- "Nexusflow/NexusRaven-V2-13B",
- "NousResearch/Nous-Capybara-7B-V1p9",
- "NousResearch/Nous-Hermes-2-Yi-34B",
- "NousResearch/Nous-Hermes-Llama2-13b",
- "NousResearch/Nous-Hermes-Llama2-70b",
- ]
- pypi_package_deps = ["together"]
- auth_strategy = EnvAuthStrategy(name="TOGETHER_API_KEY")
-
- def __init__(self, **kwargs):
- model = kwargs.get("model_id")
-
- if model not in self.models:
- kwargs["responses"] = [
- "Model not supported! Please check model list with %ai list"
- ]
-
- super().__init__(**kwargs)
-
- def get_prompt_template(self, format) -> PromptTemplate:
- if format == "code":
- return PromptTemplate.from_template(
- "{prompt}\n\nProduce output as source code only, "
- "with no text or explanation before or after it."
- )
- return super().get_prompt_template(format)
-
-
-# Baidu QianfanChat provider. temporarily living as a separate class until
-class QianfanProvider(BaseProvider, QianfanChatEndpoint):
- id = "qianfan"
- name = "ERNIE-Bot"
- models = ["ERNIE-Bot", "ERNIE-Bot-4"]
- model_id_key = "model_name"
- pypi_package_deps = ["qianfan"]
- auth_strategy = MultiEnvAuthStrategy(names=["QIANFAN_AK", "QIANFAN_SK"])
diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/py.typed b/packages/jupyter-ai-magics/jupyter_ai_magics/py.typed
deleted file mode 100644
index e69de29bb..000000000
diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/tests/test_base_provider.py b/packages/jupyter-ai-magics/jupyter_ai_magics/tests/test_base_provider.py
deleted file mode 100644
index 4029cd5bd..000000000
--- a/packages/jupyter-ai-magics/jupyter_ai_magics/tests/test_base_provider.py
+++ /dev/null
@@ -1,27 +0,0 @@
-from typing import ClassVar, Optional
-
-from pydantic import BaseModel
-
-from ..providers import BaseProvider
-
-
-def test_provider_classvars():
- """
- Asserts that class attributes are not omitted due to parent classes defining
- an instance field of the same name. This was a bug present in Pydantic v1,
- which led to an issue documented in #558.
-
- This bug is fixed as of `pydantic==2.10.2`, but we will keep this test in
- case this behavior changes in future releases.
- """
-
- class Parent(BaseModel):
- test: Optional[str] = None
-
- class Base(BaseModel):
- test: ClassVar[str]
-
- class Child(Base, Parent):
- test: ClassVar[str] = "expected"
-
- assert Child.test == "expected"
diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/tests/test_imports.py b/packages/jupyter-ai-magics/jupyter_ai_magics/tests/test_imports.py
deleted file mode 100644
index b6ffb4bf5..000000000
--- a/packages/jupyter-ai-magics/jupyter_ai_magics/tests/test_imports.py
+++ /dev/null
@@ -1,37 +0,0 @@
-import ast
-import inspect
-import sys
-
-import jupyter_ai_magics
-
-
-def test_uses_lazy_imports():
- assert "jupyter_ai_magics.exception" not in sys.modules
- jupyter_ai_magics.exception
- assert "jupyter_ai_magics.exception" in sys.modules
-
-
-def test_all_includes_all_dynamic_imports():
- dynamic_imports = set(jupyter_ai_magics._modules_by_export.keys())
- assert dynamic_imports - set(jupyter_ai_magics.__all__) == set()
-
-
-def test_dir_returns_all():
- assert set(jupyter_ai_magics.__all__) == set(dir(jupyter_ai_magics))
-
-
-def test_all_type_checked():
- dynamic_imports = set(jupyter_ai_magics._modules_by_export.keys())
- tree = ast.parse(inspect.getsource(jupyter_ai_magics))
- imports_in_type_checking = {
- alias.asname if alias.asname else alias.name
- for node in ast.walk(tree)
- if isinstance(node, ast.If)
- and isinstance(node.test, ast.Name)
- and node.test.id == "TYPE_CHECKING"
- for stmt in node.body
- if isinstance(stmt, ast.ImportFrom)
- for alias in stmt.names
- }
-
- assert imports_in_type_checking == set(dynamic_imports)
diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/tests/test_magics.py b/packages/jupyter-ai-magics/jupyter_ai_magics/tests/test_magics.py
index eb087a376..a393029d6 100644
--- a/packages/jupyter-ai-magics/jupyter_ai_magics/tests/test_magics.py
+++ b/packages/jupyter-ai-magics/jupyter_ai_magics/tests/test_magics.py
@@ -1,11 +1,7 @@
-import os
-from unittest.mock import Mock, patch
+from unittest.mock import patch
-import pytest
from IPython import InteractiveShell
-from IPython.core.display import Markdown
from jupyter_ai_magics.magics import AiMagics
-from langchain_core.messages import AIMessage, HumanMessage
from pytest import fixture
from traitlets.config.loader import Config
@@ -18,14 +14,14 @@ def ip() -> InteractiveShell:
def test_aliases_config(ip):
- ip.config.AiMagics.aliases = {"my_custom_alias": "my_provider:my_model"}
+ ip.config.AiMagics.initial_aliases = {"my_custom_alias": "my_provider:my_model"}
ip.extension_manager.load_extension("jupyter_ai_magics")
providers_list = ip.run_line_magic("ai", "list").text
assert "my_custom_alias" in providers_list
def test_default_model_cell(ip):
- ip.config.AiMagics.default_language_model = "my-favourite-llm"
+ ip.config.AiMagics.initial_language_model = "my-favourite-llm"
ip.extension_manager.load_extension("jupyter_ai_magics")
with patch.object(AiMagics, "run_ai_cell", return_value=None) as mock_run:
ip.run_cell_magic("ai", "", cell="Write code for me please")
@@ -35,7 +31,7 @@ def test_default_model_cell(ip):
def test_non_default_model_cell(ip):
- ip.config.AiMagics.default_language_model = "my-favourite-llm"
+ ip.config.AiMagics.initial_language_model = "my-favourite-llm"
ip.extension_manager.load_extension("jupyter_ai_magics")
with patch.object(AiMagics, "run_ai_cell", return_value=None) as mock_run:
ip.run_cell_magic("ai", "some-different-llm", cell="Write code for me please")
@@ -45,79 +41,18 @@ def test_non_default_model_cell(ip):
def test_default_model_error_line(ip):
- ip.config.AiMagics.default_language_model = "my-favourite-llm"
+ ip.config.AiMagics.initial_language_model = "my-favourite-llm"
ip.extension_manager.load_extension("jupyter_ai_magics")
- with patch.object(AiMagics, "handle_error", return_value=None) as mock_run:
- ip.run_cell_magic("ai", "error", cell=None)
+ with patch.object(AiMagics, "handle_fix", return_value=None) as mock_run:
+ ip.run_cell_magic("ai", "fix", cell=None)
assert mock_run.called
cell_args = mock_run.call_args.args[0]
assert cell_args.model_id == "my-favourite-llm"
-PROMPT = HumanMessage(
- content=("Write code for me please\n\nProduce output in markdown format only.")
-)
-RESPONSE = AIMessage(content="Leet code")
-AI1 = AIMessage("ai1")
-H1 = HumanMessage("h1")
-AI2 = AIMessage("ai2")
-H2 = HumanMessage("h2")
-AI3 = AIMessage("ai3")
-
-
-@pytest.mark.parametrize(
- ["transcript", "max_history", "expected_context"],
- [
- ([], 3, [PROMPT]),
- ([AI1], 0, [PROMPT]),
- ([AI1], 1, [AI1, PROMPT]),
- ([H1, AI1], 0, [PROMPT]),
- ([H1, AI1], 1, [H1, AI1, PROMPT]),
- ([AI1, H1, AI2], 0, [PROMPT]),
- ([AI1, H1, AI2], 1, [H1, AI2, PROMPT]),
- ([AI1, H1, AI2], 2, [AI1, H1, AI2, PROMPT]),
- ([H1, AI1, H2, AI2], 0, [PROMPT]),
- ([H1, AI1, H2, AI2], 1, [H2, AI2, PROMPT]),
- ([H1, AI1, H2, AI2], 2, [H1, AI1, H2, AI2, PROMPT]),
- ([AI1, H1, AI2, H2, AI3], 0, [PROMPT]),
- ([AI1, H1, AI2, H2, AI3], 1, [H2, AI3, PROMPT]),
- ([AI1, H1, AI2, H2, AI3], 2, [H1, AI2, H2, AI3, PROMPT]),
- ([AI1, H1, AI2, H2, AI3], 3, [AI1, H1, AI2, H2, AI3, PROMPT]),
- ],
-)
-def test_max_history(ip, transcript, max_history, expected_context):
- ip.extension_manager.load_extension("jupyter_ai_magics")
- ai_magics = ip.magics_manager.registry["AiMagics"]
- ai_magics.transcript = transcript.copy()
- ai_magics.max_history = max_history
- provider = ai_magics._get_provider("openrouter")
- with (
- patch.object(provider, "generate") as generate,
- patch.dict(os.environ, OPENROUTER_API_KEY="123"),
- ):
- generate.return_value.generations = [[Mock(text="Leet code")]]
- result = ip.run_cell_magic(
- "ai",
- "openrouter:anthropic/claude-3.5-sonnet",
- cell="Write code for me please",
- )
- provider.generate.assert_called_once_with([expected_context])
- assert isinstance(result, Markdown)
- assert result.data == "Leet code"
- assert result.filename is None
- assert result.metadata == {
- "jupyter_ai": {
- "model_id": "anthropic/claude-3.5-sonnet",
- "provider_id": "openrouter",
- }
- }
- assert result.url is None
- assert ai_magics.transcript == [*transcript, PROMPT, RESPONSE]
-
-
def test_reset(ip):
ip.extension_manager.load_extension("jupyter_ai_magics")
ai_magics = ip.magics_manager.registry["AiMagics"]
- ai_magics.transcript = [AI1, H1, AI2, H2, AI3]
+ ai_magics.transcript = [{"role": "user", "content": "hello"}]
ip.run_line_magic("ai", "reset")
assert ai_magics.transcript == []
diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/tests/test_utils.py b/packages/jupyter-ai-magics/jupyter_ai_magics/tests/test_utils.py
deleted file mode 100644
index 860bb02bc..000000000
--- a/packages/jupyter-ai-magics/jupyter_ai_magics/tests/test_utils.py
+++ /dev/null
@@ -1,38 +0,0 @@
-# Copyright (c) Jupyter Development Team.
-# Distributed under the terms of the Modified BSD License.
-
-import pytest
-from jupyter_ai_magics.utils import get_lm_providers
-
-KNOWN_LM_A = "openai"
-KNOWN_LM_B = "huggingface_hub"
-
-
-@pytest.mark.parametrize(
- "restrictions",
- [
- {"allowed_providers": None, "blocked_providers": None},
- {"allowed_providers": None, "blocked_providers": []},
- {"allowed_providers": None, "blocked_providers": [KNOWN_LM_B]},
- {"allowed_providers": [KNOWN_LM_A], "blocked_providers": []},
- {"allowed_providers": [KNOWN_LM_A], "blocked_providers": None},
- ],
-)
-def test_get_lm_providers_not_restricted(restrictions):
- a_not_restricted = get_lm_providers(None, restrictions)
- assert KNOWN_LM_A in a_not_restricted
-
-
-@pytest.mark.parametrize(
- "restrictions",
- [
- {"allowed_providers": [], "blocked_providers": None},
- {"allowed_providers": [], "blocked_providers": [KNOWN_LM_A]},
- {"allowed_providers": None, "blocked_providers": [KNOWN_LM_A]},
- {"allowed_providers": [KNOWN_LM_B], "blocked_providers": []},
- {"allowed_providers": [KNOWN_LM_B], "blocked_providers": None},
- ],
-)
-def test_get_lm_providers_restricted(restrictions):
- a_not_restricted = get_lm_providers(None, restrictions)
- assert KNOWN_LM_A not in a_not_restricted
diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/utils.py b/packages/jupyter-ai-magics/jupyter_ai_magics/utils.py
deleted file mode 100644
index 50cd187bc..000000000
--- a/packages/jupyter-ai-magics/jupyter_ai_magics/utils.py
+++ /dev/null
@@ -1,134 +0,0 @@
-import logging
-from typing import Literal, Optional, Union
-
-from importlib_metadata import entry_points
-from jupyter_ai_magics.aliases import MODEL_ID_ALIASES
-from jupyter_ai_magics.base_provider import BaseProvider
-from jupyter_ai_magics.embedding_providers import BaseEmbeddingsProvider
-
-Logger = Union[logging.Logger, logging.LoggerAdapter]
-LmProvidersDict = dict[str, BaseProvider]
-EmProvidersDict = dict[str, BaseEmbeddingsProvider]
-AnyProvider = Union[BaseProvider, BaseEmbeddingsProvider]
-ProviderDict = dict[str, AnyProvider]
-ProviderRestrictions = dict[
- Literal["allowed_providers", "blocked_providers"], Optional[list[str]]
-]
-
-
-def get_lm_providers(
- log: Optional[Logger] = None, restrictions: Optional[ProviderRestrictions] = None
-) -> LmProvidersDict:
- if not log:
- log = logging.getLogger()
- log.addHandler(logging.NullHandler())
- if not restrictions:
- restrictions = {"allowed_providers": None, "blocked_providers": None}
- providers = {}
- eps = entry_points()
- provider_ep_group = eps.select(group="jupyter_ai.model_providers")
- for provider_ep in provider_ep_group:
- try:
- provider = provider_ep.load()
- except ImportError as e:
- log.warning(
- f"Unable to load model provider `{provider_ep.name}`. Please install the `{e.name}` package."
- )
- continue
- except Exception as e:
- log.warning(
- f"Unable to load model provider `{provider_ep.name}`", exc_info=e
- )
- continue
-
- if not is_provider_allowed(provider.id, restrictions):
- log.info(f"Skipping blocked provider `{provider.id}`.")
- continue
- providers[provider.id] = provider
- log.info(f"Registered model provider `{provider.id}`.")
-
- return providers
-
-
-def get_em_providers(
- log: Optional[Logger] = None, restrictions: Optional[ProviderRestrictions] = None
-) -> EmProvidersDict:
- if not log:
- log = logging.getLogger()
- log.addHandler(logging.NullHandler())
- if not restrictions:
- restrictions = {"allowed_providers": None, "blocked_providers": None}
- providers = {}
- eps = entry_points()
- model_provider_eps = eps.select(group="jupyter_ai.embeddings_model_providers")
- for model_provider_ep in model_provider_eps:
- try:
- provider = model_provider_ep.load()
- except Exception as e:
- log.warning(
- f"Unable to load embeddings model provider class from entry point `{model_provider_ep.name}`: %s.",
- e,
- )
- continue
- if not is_provider_allowed(provider.id, restrictions):
- log.info(f"Skipping blocked provider `{provider.id}`.")
- continue
- providers[provider.id] = provider
- log.info(f"Registered embeddings model provider `{provider.id}`.")
-
- return providers
-
-
-def decompose_model_id(
- model_id: str, providers: dict[str, BaseProvider]
-) -> tuple[str, str]:
- """Breaks down a model ID into a two-tuple (provider_id, local_model_id). Returns (None, None) if indeterminate."""
- if model_id in MODEL_ID_ALIASES:
- model_id = MODEL_ID_ALIASES[model_id]
-
- if ":" not in model_id:
- # case: model ID was not provided with a prefix indicating the provider
- # ID. try to infer the provider ID before returning (None, None).
-
- # naively search through the dictionary and return the first provider
- # that provides a model of the same ID.
- for provider_id, provider in providers.items():
- if model_id in provider.models:
- return (provider_id, model_id)
-
- return (None, None)
-
- provider_id, local_model_id = model_id.split(":", 1)
- return (provider_id, local_model_id)
-
-
-def get_lm_provider(
- model_id: str, lm_providers: LmProvidersDict
-) -> tuple[str, type[BaseProvider]]:
- """Gets a two-tuple (, ) specified by a
- global model ID."""
- return _get_provider(model_id, lm_providers)
-
-
-def get_em_provider(
- model_id: str, em_providers: EmProvidersDict
-) -> tuple[str, type[BaseEmbeddingsProvider]]:
- """Gets a two-tuple (, ) specified by a
- global model ID."""
- return _get_provider(model_id, em_providers)
-
-
-def is_provider_allowed(provider_id: str, restrictions: ProviderRestrictions) -> bool:
- allowed = restrictions["allowed_providers"]
- blocked = restrictions["blocked_providers"]
- if blocked is not None and provider_id in blocked:
- return False
- if allowed is not None and provider_id not in allowed:
- return False
- return True
-
-
-def _get_provider(model_id: str, providers: ProviderDict):
- provider_id, local_model_id = decompose_model_id(model_id, providers)
- provider = providers.get(provider_id, None)
- return local_model_id, provider
diff --git a/packages/jupyter-ai-magics/pyproject.toml b/packages/jupyter-ai-magics/pyproject.toml
index 7e54f8ee0..b1d7b6c9e 100644
--- a/packages/jupyter-ai-magics/pyproject.toml
+++ b/packages/jupyter-ai-magics/pyproject.toml
@@ -22,15 +22,10 @@ dynamic = ["version", "description", "authors", "urls", "keywords"]
dependencies = [
"ipython",
- "importlib_metadata>=5.2.0",
- "langchain>=0.3.0,<0.4.0",
- "langchain_community>=0.3.0,<0.4.0",
# pydantic <2.10.0 raises a "protected namespaces" error in JAI
# - See: https://docs.pydantic.dev/latest/api/config/#pydantic.config.ConfigDict.protected_namespaces
"pydantic>=2.10.0,<3",
"click>=8.1.0,<9",
- "jsonpath-ng>=1.5.3,<2",
- "langchain-google-vertexai",
]
[project.optional-dependencies]
@@ -45,64 +40,10 @@ dev = [
test = ["coverage", "pytest", "pytest-asyncio", "pytest-cov"]
all = [
- "ai21",
- "gpt4all",
- "huggingface_hub",
- "ipywidgets",
- "langchain_anthropic",
- "langchain_aws",
- "langchain_cohere",
- # Pin cohere to <5.16 to prevent langchain_cohere from breaking as it uses ChatResponse removed in cohere 5.16.0
- # TODO: remove cohere pin when langchain_cohere is updated to work with cohere >=5.16
- "cohere<5.16",
- "langchain_google_genai",
- "langchain_mistralai",
- "langchain_nvidia_ai_endpoints",
- "langchain_openai",
- "langchain_ollama",
- "pillow",
+ # Required for using Amazon Bedrock
"boto3",
- "qianfan",
- "together",
- "langchain-google-vertexai",
]
-[project.entry-points."jupyter_ai.model_providers"]
-ai21 = "jupyter_ai_magics:AI21Provider"
-anthropic-chat = "jupyter_ai_magics.partner_providers.anthropic:ChatAnthropicProvider"
-cohere = "jupyter_ai_magics.partner_providers.cohere:CohereProvider"
-gpt4all = "jupyter_ai_magics:GPT4AllProvider"
-huggingface_hub = "jupyter_ai_magics:HfHubProvider"
-ollama = "jupyter_ai_magics.partner_providers.ollama:OllamaProvider"
-openai = "jupyter_ai_magics.partner_providers.openai:OpenAIProvider"
-openai-chat = "jupyter_ai_magics.partner_providers.openai:ChatOpenAIProvider"
-openai-chat-custom = "jupyter_ai_magics.partner_providers.openai:ChatOpenAICustomProvider"
-azure-chat-openai = "jupyter_ai_magics.partner_providers.openai:AzureChatOpenAIProvider"
-sagemaker-endpoint = "jupyter_ai_magics.partner_providers.aws:SmEndpointProvider"
-amazon-bedrock = "jupyter_ai_magics.partner_providers.aws:BedrockProvider"
-amazon-bedrock-chat = "jupyter_ai_magics.partner_providers.aws:BedrockChatProvider"
-amazon-bedrock-custom = "jupyter_ai_magics.partner_providers.aws:BedrockCustomProvider"
-qianfan = "jupyter_ai_magics:QianfanProvider"
-nvidia-chat = "jupyter_ai_magics.partner_providers.nvidia:ChatNVIDIAProvider"
-together-ai = "jupyter_ai_magics:TogetherAIProvider"
-gemini = "jupyter_ai_magics.partner_providers.gemini:GeminiProvider"
-mistralai = "jupyter_ai_magics.partner_providers.mistralai:MistralAIProvider"
-openrouter = "jupyter_ai_magics.partner_providers.openrouter:OpenRouterProvider"
-vertexai = "jupyter_ai_magics.partner_providers.vertexai:VertexAIProvider"
-
-[project.entry-points."jupyter_ai.embeddings_model_providers"]
-azure = "jupyter_ai_magics.partner_providers.openai:AzureOpenAIEmbeddingsProvider"
-bedrock = "jupyter_ai_magics.partner_providers.aws:BedrockEmbeddingsProvider"
-cohere = "jupyter_ai_magics.partner_providers.cohere:CohereEmbeddingsProvider"
-mistralai = "jupyter_ai_magics.partner_providers.mistralai:MistralAIEmbeddingsProvider"
-gpt4all = "jupyter_ai_magics:GPT4AllEmbeddingsProvider"
-huggingface_hub = "jupyter_ai_magics:HfHubEmbeddingsProvider"
-ollama = "jupyter_ai_magics.partner_providers.ollama:OllamaEmbeddingsProvider"
-openai = "jupyter_ai_magics.partner_providers.openai:OpenAIEmbeddingsProvider"
-openai-custom = "jupyter_ai_magics.partner_providers.openai:OpenAIEmbeddingsCustomProvider"
-qianfan = "jupyter_ai_magics:QianfanEmbeddingsEndpointProvider"
-vertexai = "jupyter_ai_magics.partner_providers.vertexai:VertexAIEmbeddingsProvider"
-
[tool.hatch.version]
source = "nodejs"
diff --git a/packages/jupyter-ai-module-cookiecutter/{{cookiecutter.root_dir_name}}/pyproject.toml b/packages/jupyter-ai-module-cookiecutter/{{cookiecutter.root_dir_name}}/pyproject.toml
index d01a93dda..7480cee63 100644
--- a/packages/jupyter-ai-module-cookiecutter/{{cookiecutter.root_dir_name}}/pyproject.toml
+++ b/packages/jupyter-ai-module-cookiecutter/{{cookiecutter.root_dir_name}}/pyproject.toml
@@ -29,9 +29,6 @@ dependencies = ["jupyter_ai"]
[project.optional-dependencies]
test = ["coverage", "pytest", "pytest-asyncio", "pytest-cov"]
-[project.entry-points."jupyter_ai.model_providers"]
-test-provider = "{{ cookiecutter.python_name }}.provider:TestProvider"
-
[tool.hatch.build.hooks.version]
path = "{{ cookiecutter.python_name }}/_version.py"
diff --git a/packages/jupyter-ai-module-cookiecutter/{{cookiecutter.root_dir_name}}/{{cookiecutter.python_name}}/llm.py b/packages/jupyter-ai-module-cookiecutter/{{cookiecutter.root_dir_name}}/{{cookiecutter.python_name}}/llm.py
deleted file mode 100644
index fc5134b35..000000000
--- a/packages/jupyter-ai-module-cookiecutter/{{cookiecutter.root_dir_name}}/{{cookiecutter.python_name}}/llm.py
+++ /dev/null
@@ -1,21 +0,0 @@
-from typing import Any, Optional
-
-from langchain_core.callbacks.manager import CallbackManagerForLLMRun
-from langchain_core.language_models.llms import LLM
-
-
-class TestLLM(LLM):
- model_id: str
-
- @property
- def _llm_type(self) -> str:
- return "custom"
-
- def _call(
- self,
- prompt: str,
- stop: Optional[list[str]] = None,
- run_manager: Optional[CallbackManagerForLLMRun] = None,
- **kwargs: Any,
- ) -> str:
- return f"Hello! I am a model for testing only. Model ID: {self.model_id}"
diff --git a/packages/jupyter-ai-module-cookiecutter/{{cookiecutter.root_dir_name}}/{{cookiecutter.python_name}}/provider.py b/packages/jupyter-ai-module-cookiecutter/{{cookiecutter.root_dir_name}}/{{cookiecutter.python_name}}/provider.py
deleted file mode 100644
index 694049f6b..000000000
--- a/packages/jupyter-ai-module-cookiecutter/{{cookiecutter.root_dir_name}}/{{cookiecutter.python_name}}/provider.py
+++ /dev/null
@@ -1,70 +0,0 @@
-from typing import ClassVar
-
-from jupyter_ai import AuthStrategy, BaseProvider, Field
-
-from .llm import TestLLM
-
-
-class TestProvider(BaseProvider, TestLLM):
- """
- A test model provider implementation for developers to build from. A model
- provider inherits from 2 classes: 1) the `BaseProvider` class from
- `jupyter_ai`, and 2) an LLM class from `langchain`, i.e. a class inheriting
- from `LLM` or `BaseChatModel`.
-
- Any custom model first requires a `langchain` LLM class implementation.
- Please import one from `langchain`, or refer to the `langchain` docs for
- instructions on how to write your own. We offer an example in `./llm.py` for
- testing.
-
- To create a custom model provider from an existing `langchain`
- implementation, developers should edit this class' declaration to
-
- ```
-    class TestModelProvider(BaseProvider, <langchain-llm-class>):
- ...
- ```
-
- Developers should fill in each of the below required class attributes.
- As the implementation is provided by the inherited LLM class, developers
- generally don't need to implement any methods. See the built-in
- implementations in `jupyter_ai_magics.providers.py` for further reference.
-
- The provider is made available to Jupyter AI by the entry point declared in
-    `pyproject.toml`. If this class or parent module is renamed, make sure to
-    update the entry point there as well.
- """
-
- id: ClassVar[str] = "test-provider"
- """ID for this provider class."""
-
- name: ClassVar[str] = "Test Provider"
- """User-facing name of this provider."""
-
- models: ClassVar[list[str]] = ["test-model-1"]
- """List of supported models by their IDs. For registry providers, this will
- be just ["*"]."""
-
- help: ClassVar[str] = None
- """Text to display in lieu of a model list for a registry provider that does
- not provide a list of models."""
-
- model_id_key: ClassVar[str] = "model_id"
- """Kwarg expected by the upstream LangChain provider."""
-
- model_id_label: ClassVar[str] = "Model ID"
- """Human-readable label of the model ID."""
-
- pypi_package_deps: ClassVar[list[str]] = []
- """List of PyPi package dependencies."""
-
- auth_strategy: ClassVar[AuthStrategy] = None
- """Authentication/authorization strategy. Declares what credentials are
- required to use this model provider. Generally should not be `None`."""
-
- registry: ClassVar[bool] = False
- """Whether this provider is a registry provider."""
-
- fields: ClassVar[list[Field]] = []
- """User inputs expected by this provider when initializing it. Each `Field` `f`
- should be passed in the constructor as a keyword argument, keyed by `f.key`."""
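The docstring above describes the provider pattern this diff removes: subclass `BaseProvider` together with a LangChain LLM class, then register the result under the `jupyter_ai.model_providers` entry-point group. A minimal sketch under the old API; this assumes a pre-3.0 `jupyter_ai` release that still exports `BaseProvider`, and uses LangChain's `FakeListLLM` purely for illustration:

```python
# Minimal provider under the *removed* API; only works against a jupyter_ai
# release that still exports BaseProvider. FakeListLLM is illustrative only.
from typing import ClassVar

from jupyter_ai import BaseProvider
from langchain_community.llms import FakeListLLM


class EchoProvider(BaseProvider, FakeListLLM):
    id: ClassVar[str] = "echo-provider"
    name: ClassVar[str] = "Echo Provider"
    models: ClassVar[list[str]] = ["echo-1"]
    model_id_key: ClassVar[str] = "model_id"
```

The class would then be declared under `[project.entry-points."jupyter_ai.model_providers"]` in `pyproject.toml`, which is exactly the wiring the hunks above delete.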
diff --git a/packages/jupyter-ai-test/jupyter_ai_test/test_llms.py b/packages/jupyter-ai-test/jupyter_ai_test/test_llms.py
deleted file mode 100644
index b5ec80dc7..000000000
--- a/packages/jupyter-ai-test/jupyter_ai_test/test_llms.py
+++ /dev/null
@@ -1,59 +0,0 @@
-import time
-from collections.abc import Iterator
-from typing import Any, Optional
-
-from langchain_core.callbacks.manager import CallbackManagerForLLMRun
-from langchain_core.language_models.llms import LLM
-from langchain_core.outputs.generation import GenerationChunk
-
-
-class TestLLM(LLM):
- model_id: str = "test"
-
- @property
- def _llm_type(self) -> str:
- return "custom"
-
- def _call(
- self,
- prompt: str,
- stop: Optional[list[str]] = None,
- run_manager: Optional[CallbackManagerForLLMRun] = None,
- **kwargs: Any,
- ) -> str:
- time.sleep(3)
- return f"Hello! This is a dummy response from a test LLM."
-
-
-class TestLLMWithStreaming(LLM):
- model_id: str = "test"
-
- @property
- def _llm_type(self) -> str:
- return "custom"
-
- def _call(
- self,
- prompt: str,
- stop: Optional[list[str]] = None,
- run_manager: Optional[CallbackManagerForLLMRun] = None,
- **kwargs: Any,
- ) -> str:
- time.sleep(3)
- return f"Hello! This is a dummy response from a test LLM."
-
- def _stream(
- self,
- prompt: str,
- stop: Optional[list[str]] = None,
- run_manager: Optional[CallbackManagerForLLMRun] = None,
- **kwargs: Any,
- ) -> Iterator[GenerationChunk]:
- time.sleep(1)
- yield GenerationChunk(
- text="Hello! This is a dummy response from a test LLM. I will now count from 1 to 5.\n\n",
- generation_info={"test_metadata_field": "foobar"},
- )
- for i in range(1, 6):
- time.sleep(0.2)
- yield GenerationChunk(text=f"{i}, ")
diff --git a/packages/jupyter-ai-test/jupyter_ai_test/test_providers.py b/packages/jupyter-ai-test/jupyter_ai_test/test_providers.py
deleted file mode 100644
index 4a15dd31e..000000000
--- a/packages/jupyter-ai-test/jupyter_ai_test/test_providers.py
+++ /dev/null
@@ -1,115 +0,0 @@
-from typing import ClassVar
-
-from jupyter_ai import AuthStrategy, BaseProvider, Field
-
-from .test_llms import TestLLM, TestLLMWithStreaming
-
-
-class TestProvider(BaseProvider, TestLLM):
- id: ClassVar[str] = "test-provider"
- """ID for this provider class."""
-
- name: ClassVar[str] = "Test Provider"
- """User-facing name of this provider."""
-
- models: ClassVar[list[str]] = ["test"]
- """List of supported models by their IDs. For registry providers, this will
- be just ["*"]."""
-
- help: ClassVar[str] = None
- """Text to display in lieu of a model list for a registry provider that does
- not provide a list of models."""
-
- model_id_key: ClassVar[str] = "model_id"
- """Kwarg expected by the upstream LangChain provider."""
-
- model_id_label: ClassVar[str] = "Model ID"
- """Human-readable label of the model ID."""
-
- pypi_package_deps: ClassVar[list[str]] = []
- """List of PyPi package dependencies."""
-
- auth_strategy: ClassVar[AuthStrategy] = None
- """Authentication/authorization strategy. Declares what credentials are
- required to use this model provider. Generally should not be `None`."""
-
- registry: ClassVar[bool] = False
- """Whether this provider is a registry provider."""
-
- fields: ClassVar[list[Field]] = []
- """User inputs expected by this provider when initializing it. Each `Field` `f`
- should be passed in the constructor as a keyword argument, keyed by `f.key`."""
-
-
-class TestProviderWithStreaming(BaseProvider, TestLLMWithStreaming):
- id: ClassVar[str] = "test-provider-with-streaming"
- """ID for this provider class."""
-
- name: ClassVar[str] = "Test Provider (streaming)"
- """User-facing name of this provider."""
-
- models: ClassVar[list[str]] = ["test"]
- """List of supported models by their IDs. For registry providers, this will
- be just ["*"]."""
-
- help: ClassVar[str] = None
- """Text to display in lieu of a model list for a registry provider that does
- not provide a list of models."""
-
- model_id_key: ClassVar[str] = "model_id"
- """Kwarg expected by the upstream LangChain provider."""
-
- model_id_label: ClassVar[str] = "Model ID"
- """Human-readable label of the model ID."""
-
- pypi_package_deps: ClassVar[list[str]] = []
- """List of PyPi package dependencies."""
-
- auth_strategy: ClassVar[AuthStrategy] = None
- """Authentication/authorization strategy. Declares what credentials are
- required to use this model provider. Generally should not be `None`."""
-
- registry: ClassVar[bool] = False
- """Whether this provider is a registry provider."""
-
- fields: ClassVar[list[Field]] = []
- """User inputs expected by this provider when initializing it. Each `Field` `f`
- should be passed in the constructor as a keyword argument, keyed by `f.key`."""
-
-
-class TestProviderAskLearnUnsupported(BaseProvider, TestLLMWithStreaming):
- id: ClassVar[str] = "test-provider-ask-learn-unsupported"
- """ID for this provider class."""
-
- name: ClassVar[str] = "Test Provider (/learn and /ask unsupported)"
- """User-facing name of this provider."""
-
- models: ClassVar[list[str]] = ["test"]
- """List of supported models by their IDs. For registry providers, this will
- be just ["*"]."""
-
- help: ClassVar[str] = None
- """Text to display in lieu of a model list for a registry provider that does
- not provide a list of models."""
-
- model_id_key: ClassVar[str] = "model_id"
- """Kwarg expected by the upstream LangChain provider."""
-
- model_id_label: ClassVar[str] = "Model ID"
- """Human-readable label of the model ID."""
-
- pypi_package_deps: ClassVar[list[str]] = []
- """List of PyPi package dependencies."""
-
- auth_strategy: ClassVar[AuthStrategy] = None
- """Authentication/authorization strategy. Declares what credentials are
- required to use this model provider. Generally should not be `None`."""
-
- registry: ClassVar[bool] = False
- """Whether this provider is a registry provider."""
-
- fields: ClassVar[list[Field]] = []
- """User inputs expected by this provider when initializing it. Each `Field` `f`
- should be passed in the constructor as a keyword argument, keyed by `f.key`."""
-
- unsupported_slash_commands = {"/learn", "/ask"}
diff --git a/packages/jupyter-ai-test/pyproject.toml b/packages/jupyter-ai-test/pyproject.toml
index 0757aa7d9..992022662 100644
--- a/packages/jupyter-ai-test/pyproject.toml
+++ b/packages/jupyter-ai-test/pyproject.toml
@@ -27,11 +27,6 @@ dependencies = ["jupyter_ai"]
[project.optional-dependencies]
test = ["coverage", "pytest", "pytest-asyncio", "pytest-cov"]
-[project.entry-points."jupyter_ai.model_providers"]
-test-provider = "jupyter_ai_test.test_providers:TestProvider"
-test-provider-with-streaming = "jupyter_ai_test.test_providers:TestProviderWithStreaming"
-test-provider-ask-learn-unsupported = "jupyter_ai_test.test_providers:TestProviderAskLearnUnsupported"
-
[project.entry-points."jupyter_ai.personas"]
debug-persona = "jupyter_ai_test.debug_persona:DebugPersona"
diff --git a/packages/jupyter-ai/jupyter_ai/__init__.py b/packages/jupyter-ai/jupyter_ai/__init__.py
index d1d6e4d37..830a04627 100644
--- a/packages/jupyter-ai/jupyter_ai/__init__.py
+++ b/packages/jupyter-ai/jupyter_ai/__init__.py
@@ -5,11 +5,10 @@
# expose jupyter_ai_magics ipython extension
# DO NOT REMOVE.
-from jupyter_ai_magics import load_ipython_extension, unload_ipython_extension
-
-# expose jupyter_ai_magics providers
-# DO NOT REMOVE.
-from jupyter_ai_magics.providers import *
+from jupyter_ai_magics import ( # type: ignore[import-untyped]
+ load_ipython_extension,
+ unload_ipython_extension,
+)
from ._version import __version__
from .extension import AiExtension
diff --git a/packages/jupyter-ai/jupyter_ai/completions/completion_prompts.py b/packages/jupyter-ai/jupyter_ai/completions/completion_prompts.py
new file mode 100644
index 000000000..10953a4e8
--- /dev/null
+++ b/packages/jupyter-ai/jupyter_ai/completions/completion_prompts.py
@@ -0,0 +1,24 @@
+COMPLETION_SYSTEM_PROMPT = """
+You are an application built to provide helpful code completion suggestions.
+You should only produce code. Keep comments to a minimum, using the
+programming language's comment syntax. Produce clean code.
+The code is written in JupyterLab, a data analysis and code development
+environment which can execute code extended with additional syntax for
+interactive features, such as magics.
+""".strip()
+
+# Only include the suffix section when present, to save input tokens/computation time
+COMPLETION_DEFAULT_TEMPLATE = """
+The document is called `{{filename}}` and written in {{language}}.
+{% if suffix %}
+The code after the completion request is:
+
+```
+{{suffix}}
+```
+{% endif %}
+
+Complete the following code:
+
+```
+{{prefix}}"""
diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/models/completion.py b/packages/jupyter-ai/jupyter_ai/completions/completion_types.py
similarity index 100%
rename from packages/jupyter-ai-magics/jupyter_ai_magics/models/completion.py
rename to packages/jupyter-ai/jupyter_ai/completions/completion_types.py
diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/completion_utils.py b/packages/jupyter-ai/jupyter_ai/completions/completion_utils.py
similarity index 97%
rename from packages/jupyter-ai-magics/jupyter_ai_magics/completion_utils.py
rename to packages/jupyter-ai/jupyter_ai/completions/completion_utils.py
index 1330e8992..150ac40c0 100644
--- a/packages/jupyter-ai-magics/jupyter_ai_magics/completion_utils.py
+++ b/packages/jupyter-ai/jupyter_ai/completions/completion_utils.py
@@ -1,4 +1,4 @@
-from .models.completion import InlineCompletionRequest
+from .completion_types import InlineCompletionRequest
def token_from_request(request: InlineCompletionRequest, suggestion: int):
diff --git a/packages/jupyter-ai/jupyter_ai/completions/handlers/base.py b/packages/jupyter-ai/jupyter_ai/completions/handlers/base.py
index cf07550bc..fd83f43d8 100644
--- a/packages/jupyter-ai/jupyter_ai/completions/handlers/base.py
+++ b/packages/jupyter-ai/jupyter_ai/completions/handlers/base.py
@@ -5,21 +5,19 @@
from typing import Union
import tornado
-from jupyter_ai.completions.handlers.model_mixin import CompletionsModelMixin
-from jupyter_ai.completions.models import (
+from jupyter_server.base.handlers import JupyterHandler
+from pydantic import ValidationError
+
+from ..completion_types import (
CompletionError,
InlineCompletionList,
InlineCompletionReply,
InlineCompletionRequest,
InlineCompletionStreamChunk,
)
-from jupyter_server.base.handlers import JupyterHandler
-from pydantic import ValidationError
-class BaseInlineCompletionHandler(
- CompletionsModelMixin, JupyterHandler, tornado.websocket.WebSocketHandler
-):
+class BaseInlineCompletionHandler(JupyterHandler, tornado.websocket.WebSocketHandler):
"""A Tornado WebSocket handler that receives inline completion requests and
fulfills them accordingly. This class is instantiated once per WebSocket
connection."""
diff --git a/packages/jupyter-ai/jupyter_ai/completions/handlers/default.py b/packages/jupyter-ai/jupyter_ai/completions/handlers/default.py
index 38676b998..93c496a8e 100644
--- a/packages/jupyter-ai/jupyter_ai/completions/handlers/default.py
+++ b/packages/jupyter-ai/jupyter_ai/completions/handlers/default.py
@@ -1,4 +1,4 @@
-from ..models import InlineCompletionRequest
+from ..completion_types import InlineCompletionRequest
from .base import BaseInlineCompletionHandler
@@ -8,17 +8,79 @@ def __init__(self, *args, **kwargs):
async def handle_request(self, request: InlineCompletionRequest):
"""Handles an inline completion request without streaming."""
- llm = self.get_llm()
- if not llm:
- raise ValueError("Please select a model for inline completion.")
+ # TODO: migrate this to use LiteLLM
+ # llm = self.get_llm()
+ # if not llm:
+ # raise ValueError("Please select a model for inline completion.")
- reply = await llm.generate_inline_completions(request)
- self.reply(reply)
+ # reply = await llm.generate_inline_completions(request)
+ # self.reply(reply)
async def handle_stream_request(self, request: InlineCompletionRequest):
- llm = self.get_llm()
- if not llm:
- raise ValueError("Please select a model for inline completion.")
+ # TODO: migrate this to use LiteLLM
+ # llm = self.get_llm()
+ # if not llm:
+ # raise ValueError("Please select a model for inline completion.")
- async for reply in llm.stream_inline_completions(request):
- self.reply(reply)
+ # async for reply in llm.stream_inline_completions(request):
+ # self.reply(reply)
+ pass
+
+
+# old methods on BaseProvider, for reference when migrating this to LiteLLM
+#
+# async def generate_inline_completions(
+# self, request: InlineCompletionRequest
+# ) -> InlineCompletionReply:
+# chain = self._create_completion_chain()
+# model_arguments = completion.template_inputs_from_request(request)
+# suggestion = await chain.ainvoke(input=model_arguments)
+# suggestion = completion.post_process_suggestion(suggestion, request)
+# return InlineCompletionReply(
+# list=InlineCompletionList(items=[{"insertText": suggestion}]),
+# reply_to=request.number,
+# )
+
+# async def stream_inline_completions(
+# self, request: InlineCompletionRequest
+# ) -> AsyncIterator[InlineCompletionStreamChunk]:
+# chain = self._create_completion_chain()
+# token = completion.token_from_request(request, 0)
+# model_arguments = completion.template_inputs_from_request(request)
+# suggestion = processed_suggestion = ""
+
+# # send an incomplete `InlineCompletionReply`, indicating to the
+# # client that LLM output is about to streamed across this connection.
+# yield InlineCompletionReply(
+# list=InlineCompletionList(
+# items=[
+# {
+# # insert text starts empty as we do not pre-generate any part
+# "insertText": "",
+# "isIncomplete": True,
+# "token": token,
+# }
+# ]
+# ),
+# reply_to=request.number,
+# )
+
+# async for fragment in chain.astream(input=model_arguments):
+# suggestion += fragment
+# processed_suggestion = completion.post_process_suggestion(
+# suggestion, request
+# )
+# yield InlineCompletionStreamChunk(
+# type="stream",
+# response={"insertText": processed_suggestion, "token": token},
+# reply_to=request.number,
+# done=False,
+# )
+
+# # finally, send a message confirming that we are done
+# yield InlineCompletionStreamChunk(
+# type="stream",
+# response={"insertText": processed_suggestion, "token": token},
+# reply_to=request.number,
+# done=True,
+# )
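For reference, one possible shape of the LiteLLM migration flagged by the TODOs above, mirroring the incremental-delta protocol of the commented-out LangChain version. This is a hypothetical sketch, not the project's final implementation:

```python
# Hypothetical LiteLLM-based replacement for the commented-out streaming
# path above; chunk handling follows the same pattern used in base_persona.
from litellm import acompletion


async def stream_inline_completions(model_id: str, prompt: str):
    response = await acompletion(
        model=model_id,
        messages=[{"role": "user", "content": prompt}],
        stream=True,
    )
    async for chunk in response:
        delta = chunk.choices[0].delta.content
        # LiteLLM streams terminate with an empty chunk; skip those.
        if delta:
            yield delta
```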
diff --git a/packages/jupyter-ai/jupyter_ai/completions/handlers/model_mixin.py b/packages/jupyter-ai/jupyter_ai/completions/handlers/model_mixin.py
deleted file mode 100644
index 6534a037a..000000000
--- a/packages/jupyter-ai/jupyter_ai/completions/handlers/model_mixin.py
+++ /dev/null
@@ -1,78 +0,0 @@
-from logging import Logger
-from typing import Any, Optional
-
-from jupyter_ai.config_manager import ConfigManager
-from jupyter_ai_magics.providers import BaseProvider
-
-
-class CompletionsModelMixin:
- """Mixin class containing methods and attributes used by completions LLM handler."""
-
- handler_kind: str
- settings: dict
- log: Logger
-
- @property
- def jai_config_manager(self) -> ConfigManager:
- return self.settings["jai_config_manager"]
-
- @property
- def model_parameters(self) -> dict[str, dict[str, Any]]:
- return self.settings["model_parameters"]
-
- def __init__(self, *args, **kwargs) -> None:
- super().__init__(*args, **kwargs)
- self._llm: Optional[BaseProvider] = None
- self._llm_params = None
-
- def get_llm(self) -> Optional[BaseProvider]:
- lm_provider = self.jai_config_manager.completions_lm_provider
- lm_provider_params = self.jai_config_manager.completions_lm_provider_params
-
- if not lm_provider or not lm_provider_params:
- return None
-
- curr_lm_id = (
- f'{self._llm.id}:{lm_provider_params["model_id"]}' if self._llm else None
- )
- next_lm_id = (
- f'{lm_provider.id}:{lm_provider_params["model_id"]}'
- if lm_provider
- else None
- )
-
- should_recreate_llm = False
- if curr_lm_id != next_lm_id:
- self.log.info(
- f"Switching {self.handler_kind} language model from {curr_lm_id} to {next_lm_id}."
- )
- should_recreate_llm = True
- elif self._llm_params != lm_provider_params:
- self.log.info(
- f"{self.handler_kind} model params changed, updating the llm chain."
- )
- should_recreate_llm = True
-
- if should_recreate_llm:
- self._llm = self.create_llm(lm_provider, lm_provider_params)
- self._llm_params = lm_provider_params
-
- return self._llm
-
- def get_model_parameters(
- self, provider: type[BaseProvider], provider_params: dict[str, str]
- ):
- return self.model_parameters.get(
- f"{provider.id}:{provider_params['model_id']}", {}
- )
-
- def create_llm(
- self, provider: type[BaseProvider], provider_params: dict[str, str]
- ) -> BaseProvider:
- unified_parameters = {
- **provider_params,
- **(self.get_model_parameters(provider, provider_params)),
- }
- llm = provider(**unified_parameters)
-
- return llm
diff --git a/packages/jupyter-ai/jupyter_ai/completions/models.py b/packages/jupyter-ai/jupyter_ai/completions/models.py
deleted file mode 100644
index e9679379e..000000000
--- a/packages/jupyter-ai/jupyter_ai/completions/models.py
+++ /dev/null
@@ -1,17 +0,0 @@
-from jupyter_ai_magics.models.completion import (
- CompletionError,
- InlineCompletionItem,
- InlineCompletionList,
- InlineCompletionReply,
- InlineCompletionRequest,
- InlineCompletionStreamChunk,
-)
-
-__all__ = [
- "InlineCompletionRequest",
- "InlineCompletionItem",
- "CompletionError",
- "InlineCompletionList",
- "InlineCompletionReply",
- "InlineCompletionStreamChunk",
-]
diff --git a/packages/jupyter-ai/jupyter_ai/config_manager.py b/packages/jupyter-ai/jupyter_ai/config_manager.py
index db19ea29d..05bb1efe0 100644
--- a/packages/jupyter-ai/jupyter_ai/config_manager.py
+++ b/packages/jupyter-ai/jupyter_ai/config_manager.py
@@ -3,16 +3,9 @@
import os
import time
from copy import deepcopy
-from typing import Optional, Union
+from typing import Any, Optional, Union
from deepmerge import always_merger
-from jupyter_ai_magics.utils import (
- AnyProvider,
- EmProvidersDict,
- LmProvidersDict,
- get_em_provider,
- get_lm_provider,
-)
from jupyter_core.paths import jupyter_data_dir
from traitlets import Integer, Unicode
from traitlets.config import Configurable
@@ -48,17 +41,6 @@ class BlockedModelError(Exception):
pass
-def _validate_provider_authn(config: JaiConfig, provider: type[AnyProvider]):
- # TODO: handle non-env auth strategies
- if not provider.auth_strategy or provider.auth_strategy.type != "env":
- return
-
- if provider.auth_strategy.name not in config.api_keys:
- raise AuthError(
- f"Missing API key for '{provider.auth_strategy.name}' in the config."
- )
-
-
def remove_none_entries(d: dict):
"""
Returns a deep copy of the given dictionary that excludes all top-level
@@ -110,8 +92,6 @@ class ConfigManager(Configurable):
def __init__(
self,
log: Logger,
- lm_providers: LmProvidersDict,
- em_providers: EmProvidersDict,
defaults: dict,
allowed_providers: Optional[list[str]] = None,
blocked_providers: Optional[list[str]] = None,
@@ -123,16 +103,14 @@ def __init__(
super().__init__(*args, **kwargs)
self.log = log
- self._lm_providers = lm_providers
- """List of LM providers."""
- self._em_providers = em_providers
- """List of EM providers."""
-
self._allowed_providers = allowed_providers
self._blocked_providers = blocked_providers
self._allowed_models = allowed_models
self._blocked_models = blocked_models
+ self._lm_providers: dict[str, Any] = (
+ {}
+ ) # Placeholder: should be set to actual language model providers
self._defaults = remove_none_entries(defaults)
self._last_read: Optional[int] = None
@@ -175,69 +153,10 @@ def _process_existing_config(self):
with open(self.config_path, encoding="utf-8") as f:
existing_config = json.loads(f.read())
config = JaiConfig(**existing_config)
- validated_config = self._validate_model_ids(config)
# re-write to the file to validate the config and apply any
# updates to the config file immediately
- self._write_config(validated_config)
-
- def _validate_model_ids(self, config):
- lm_provider_keys = ["model_provider_id", "completions_model_provider_id"]
- em_provider_keys = ["embeddings_provider_id"]
- clm_provider_keys = ["completions_model_provider_id"]
-
- # if the currently selected language or embedding model are
- # forbidden, set them to `None` and log a warning.
- for lm_key in lm_provider_keys:
- lm_id = getattr(config, lm_key)
- if lm_id is not None and not self._validate_model(lm_id, raise_exc=False):
- self.log.warning(
- f"Language model {lm_id} is forbidden by current allow/blocklists. Setting to None."
- )
- setattr(config, lm_key, None)
- for em_key in em_provider_keys:
- em_id = getattr(config, em_key)
- if em_id is not None and not self._validate_model(em_id, raise_exc=False):
- self.log.warning(
- f"Embedding model {em_id} is forbidden by current allow/blocklists. Setting to None."
- )
- setattr(config, em_key, None)
- for clm_key in clm_provider_keys:
- clm_id = getattr(config, clm_key)
- if clm_id is not None and not self._validate_model(clm_id, raise_exc=False):
- self.log.warning(
- f"Completion model {clm_id} is forbidden by current allow/blocklists. Setting to None."
- )
- setattr(config, clm_key, None)
-
- # if the currently selected language or embedding model ids are
- # not associated with models, set them to `None` and log a warning.
- for lm_key in lm_provider_keys:
- lm_id = getattr(config, lm_key)
- if lm_id is not None and not get_lm_provider(lm_id, self._lm_providers)[1]:
- self.log.warning(
- f"No language model is associated with '{lm_id}'. Setting to None."
- )
- setattr(config, lm_key, None)
- for em_key in em_provider_keys:
- em_id = getattr(config, em_key)
- if em_id is not None and not get_em_provider(em_id, self._em_providers)[1]:
- self.log.warning(
- f"No embedding model is associated with '{em_id}'. Setting to None."
- )
- setattr(config, em_key, None)
- for clm_key in clm_provider_keys:
- clm_id = getattr(config, clm_key)
- if (
- clm_id is not None
- and not get_lm_provider(clm_id, self._lm_providers)[1]
- ):
- self.log.warning(
- f"No completion model is associated with '{clm_id}'. Setting to None."
- )
- setattr(config, clm_key, None)
-
- return config
+ self._write_config(config)
def _read_config(self) -> JaiConfig:
"""
@@ -268,78 +187,80 @@ def _validate_config(self, config: JaiConfig):
user has specified authentication for all configured models that require
it.
"""
+ # TODO: re-implement this w/ liteLLM
# validate language model config
- if config.model_provider_id:
- _, lm_provider = get_lm_provider(
- config.model_provider_id, self._lm_providers
- )
+ # if config.model_provider_id:
+ # _, lm_provider = get_lm_provider(
+ # config.model_provider_id, self._lm_providers
+ # )
- # verify model is declared by some provider
- if not lm_provider:
- raise ValueError(
- f"No language model is associated with '{config.model_provider_id}'."
- )
+ # # verify model is declared by some provider
+ # if not lm_provider:
+ # raise ValueError(
+ # f"No language model is associated with '{config.model_provider_id}'."
+ # )
- # verify model is not blocked
- self._validate_model(config.model_provider_id)
+ # # verify model is not blocked
+ # self._validate_model(config.model_provider_id)
- # verify model is authenticated
- _validate_provider_authn(config, lm_provider)
+ # # verify model is authenticated
+ # _validate_provider_authn(config, lm_provider)
- # verify fields exist for this model if needed
- if lm_provider.fields and config.model_provider_id not in config.fields:
- config.fields[config.model_provider_id] = {}
+ # # verify fields exist for this model if needed
+ # if lm_provider.fields and config.model_provider_id not in config.fields:
+ # config.fields[config.model_provider_id] = {}
# validate completions model config
- if config.completions_model_provider_id:
- _, completions_provider = get_lm_provider(
- config.completions_model_provider_id, self._lm_providers
- )
-
- # verify model is declared by some provider
- if not completions_provider:
- raise ValueError(
- f"No language model is associated with '{config.completions_model_provider_id}'."
- )
-
- # verify model is not blocked
- self._validate_model(config.completions_model_provider_id)
-
- # verify model is authenticated
- _validate_provider_authn(config, completions_provider)
-
- # verify completions fields exist for this model if needed
- if (
- completions_provider.fields
- and config.completions_model_provider_id
- not in config.completions_fields
- ):
- config.completions_fields[config.completions_model_provider_id] = {}
-
- # validate embedding model config
- if config.embeddings_provider_id:
- _, em_provider = get_em_provider(
- config.embeddings_provider_id, self._em_providers
- )
-
- # verify model is declared by some provider
- if not em_provider:
- raise ValueError(
- f"No embedding model is associated with '{config.embeddings_provider_id}'."
- )
-
- # verify model is not blocked
- self._validate_model(config.embeddings_provider_id)
-
- # verify model is authenticated
- _validate_provider_authn(config, em_provider)
-
- # verify embedding fields exist for this model if needed
- if (
- em_provider.fields
- and config.embeddings_provider_id not in config.embeddings_fields
- ):
- config.embeddings_fields[config.embeddings_provider_id] = {}
+ # if config.completions_model_provider_id:
+ # _, completions_provider = get_lm_provider(
+ # config.completions_model_provider_id, self._lm_providers
+ # )
+
+ # # verify model is declared by some provider
+ # if not completions_provider:
+ # raise ValueError(
+ # f"No language model is associated with '{config.completions_model_provider_id}'."
+ # )
+
+ # # verify model is not blocked
+ # self._validate_model(config.completions_model_provider_id)
+
+ # # verify model is authenticated
+ # _validate_provider_authn(config, completions_provider)
+
+ # # verify completions fields exist for this model if needed
+ # if (
+ # completions_provider.fields
+ # and config.completions_model_provider_id
+ # not in config.completions_fields
+ # ):
+ # config.completions_fields[config.completions_model_provider_id] = {}
+
+ # # validate embedding model config
+ # if config.embeddings_provider_id:
+ # _, em_provider = get_em_provider(
+ # config.embeddings_provider_id, self._em_providers
+ # )
+
+ # # verify model is declared by some provider
+ # if not em_provider:
+ # raise ValueError(
+ # f"No embedding model is associated with '{config.embeddings_provider_id}'."
+ # )
+
+ # # verify model is not blocked
+ # self._validate_model(config.embeddings_provider_id)
+
+ # # verify model is authenticated
+ # _validate_provider_authn(config, em_provider)
+
+ # # verify embedding fields exist for this model if needed
+ # if (
+ # em_provider.fields
+ # and config.embeddings_provider_id not in config.embeddings_fields
+ # ):
+ # config.embeddings_fields[config.embeddings_provider_id] = {}
+ return
def _validate_model(self, model_id: str, raise_exc=True):
"""
@@ -350,7 +271,7 @@ def _validate_model(self, model_id: str, raise_exc=True):
"""
assert model_id is not None
- components = model_id.split(":", 1)
+ components = model_id.split("/", 1)
assert len(components) == 2
provider_id, _ = components
@@ -399,29 +320,6 @@ def _write_config(self, new_config: JaiConfig):
with open(self.config_path, "w") as f:
json.dump(new_config.model_dump(), f, indent=self.indentation_depth)
- def delete_api_key(self, key_name: str):
- config_dict = self._read_config().model_dump()
- required_keys = []
- for provider in [
- self.lm_provider,
- self.em_provider,
- self.completions_lm_provider,
- ]:
- if (
- provider
- and provider.auth_strategy
- and provider.auth_strategy.type == "env"
- ):
- required_keys.append(provider.auth_strategy.name)
-
- if key_name in required_keys:
- raise KeyInUseError(
- "This API key is currently in use by the language or embedding model. Please change the model before deleting the corresponding API key."
- )
-
- config_dict["api_keys"].pop(key_name, None)
- self._write_config(JaiConfig(**config_dict))
-
def update_config(self, config_update: UpdateConfigRequest): # type:ignore
last_write = os.stat(self.config_path).st_mtime_ns
if config_update.last_read and config_update.last_read < last_write:
@@ -449,94 +347,56 @@ def get_config(self):
)
@property
- def lm_gid(self):
+ def chat_model(self) -> str | None:
+ """
+ Returns the model ID of the chat model from AI settings, if any.
+ """
config = self._read_config()
return config.model_provider_id
@property
- def em_gid(self):
- config = self._read_config()
- return config.embeddings_provider_id
-
- @property
- def lm_provider(self):
- return self._get_provider("model_provider_id", self._lm_providers)
-
- @property
- def em_provider(self):
- return self._get_provider("embeddings_provider_id", self._em_providers)
-
- @property
- def completions_lm_provider(self):
- return self._get_provider("completions_model_provider_id", self._lm_providers)
+ def chat_model_params(self) -> dict[str, Any]:
+ return self._provider_params("model_provider_id", self._lm_providers)
- def _get_provider(self, key, listing):
+ def _provider_params(
+ self, provider_id_attr: str, providers: dict
+ ) -> dict[str, Any]:
+ """
+ Returns the parameters for the provider specified by the given attribute.
+ """
config = self._read_config()
- gid = getattr(config, key)
- if gid is None:
- return None
-
- _, Provider = get_lm_provider(gid, listing)
- return Provider
-
- @property
- def lm_provider_params(self):
+ provider_id = getattr(config, provider_id_attr, None)
+ if not provider_id or provider_id not in providers:
+ return {}
+ return providers[provider_id].get("params", {})
-        return self._provider_params("model_provider_id", self._lm_providers)
@property
- def em_provider_params(self):
- return self._provider_params("embeddings_provider_id", self._em_providers)
+ def embedding_model(self) -> str | None:
+ """
+ Returns the model ID of the embedding model from AI settings, if any.
+ """
+ config = self._read_config()
+ return config.embeddings_provider_id
@property
- def completions_lm_provider_params(self):
- return self._provider_params(
- "completions_model_provider_id", self._lm_providers, completions=True
- )
+ def embedding_model_params(self) -> dict[str, Any]:
+ # TODO
+ return {}
- def _provider_params(self, key, listing, completions: bool = False):
- # read config
+ @property
+ def completion_model(self) -> str | None:
+ """
+ Returns the model ID of the completion model from AI settings, if any.
+ """
config = self._read_config()
+ return config.completions_model_provider_id
- # get model ID (without provider ID component) from model universal ID
- # (with provider component).
- model_uid = getattr(config, key)
- if not model_uid:
- return None
- model_id = model_uid.split(":", 1)[1]
-
- # get config fields (e.g. base API URL, etc.)
- if completions:
- fields = config.completions_fields.get(model_uid, {})
- elif key == "embeddings_provider_id":
- fields = config.embeddings_fields.get(model_uid, {})
- else:
- fields = config.fields.get(model_uid, {})
-
- # exclude empty fields
- # TODO: modify the config manager to never save empty fields in the
- # first place.
- fields = {
- k: None if isinstance(v, str) and not len(v) else v
- for k, v in fields.items()
- }
-
- # get authn fields
- _, Provider = (
- get_em_provider(model_uid, listing)
- if key == "embeddings_provider_id"
- else get_lm_provider(model_uid, listing)
- )
- authn_fields = {}
- if Provider.auth_strategy and Provider.auth_strategy.type == "env":
- keyword_param = (
- Provider.auth_strategy.keyword_param
- or Provider.auth_strategy.name.lower()
- )
- key_name = Provider.auth_strategy.name
- authn_fields[keyword_param] = config.api_keys[key_name]
+ @property
+ def completion_model_params(self):
+ # TODO
+ return {}
- return {
- "model_id": model_id,
- **fields,
- **authn_fields,
- }
+ def delete_api_key(self, key_name: str):
+ # TODO: store in .env files
+ pass
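After this change, consumers read the selected models through the new `chat_model`, `embedding_model`, and `completion_model` properties instead of resolving provider classes. A rough usage sketch; the logger and `defaults` arguments here are placeholders, not the values the extension actually passes:

```python
# Rough usage of the slimmed-down ConfigManager; arguments are placeholders.
import logging

from jupyter_ai.config_manager import ConfigManager

cm = ConfigManager(log=logging.getLogger("jupyter_ai"), defaults={})

print(cm.chat_model)         # model ID from the AI settings file, or None
print(cm.chat_model_params)  # {} until the provider registry is populated
print(cm.completion_model)   # completions model ID, or None
```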
diff --git a/packages/jupyter-ai/jupyter_ai/extension.py b/packages/jupyter-ai/jupyter_ai/extension.py
index 07446746a..29afdd192 100644
--- a/packages/jupyter-ai/jupyter_ai/extension.py
+++ b/packages/jupyter-ai/jupyter_ai/extension.py
@@ -1,13 +1,10 @@
import os
import time
-import types
from asyncio import get_event_loop_policy
from functools import partial
from typing import TYPE_CHECKING, Optional
import traitlets
-from jupyter_ai_magics import BaseProvider
-from jupyter_ai_magics.utils import get_em_providers, get_lm_providers
from jupyter_events import EventLogger
from jupyter_server.extension.application import ExtensionApp
from jupyter_server.serverapp import ServerApp
@@ -24,13 +21,12 @@
from .completions.handlers import DefaultInlineCompletionHandler
from .config_manager import ConfigManager
from .handlers import (
- ApiKeysHandler,
- EmbeddingsModelProviderHandler,
GlobalConfigHandler,
InterruptStreamingHandler,
- ModelProviderHandler,
)
from .personas import PersonaManager
+from .secrets.secrets_manager import EnvSecretsManager
+from .secrets.secrets_rest_api import SecretsRestAPI
if TYPE_CHECKING:
from asyncio import AbstractEventLoop
@@ -55,16 +51,17 @@
JUPYTER_COLLABORATION_EVENTS_URI,
)
+from .model_providers.model_handlers import ChatModelEndpoint
+
class AiExtension(ExtensionApp):
name = "jupyter_ai"
handlers = [ # type:ignore[assignment]
-        (r"api/ai/api_keys/(?P<api_key_name>\w+)/?", ApiKeysHandler),
(r"api/ai/config/?", GlobalConfigHandler),
(r"api/ai/chats/stop_streaming/?", InterruptStreamingHandler),
- (r"api/ai/providers/?", ModelProviderHandler),
- (r"api/ai/providers/embeddings/?", EmbeddingsModelProviderHandler),
(r"api/ai/completion/inline/?", DefaultInlineCompletionHandler),
+ (r"api/ai/models/chat/?", ChatModelEndpoint),
+ (r"api/ai/secrets/?", SecretsRestAPI),
(
r"api/ai/static/jupyternaut.svg()/?",
StaticFileHandler,
@@ -143,7 +140,17 @@ class AiExtension(ExtensionApp):
config=True,
)
- default_language_model = Unicode(
+ initial_chat_model = Unicode(
+ default_value=None,
+ allow_none=True,
+ help="""
+        Default language model to use, as a string in the format
+        <provider-id>:<model-id>, defaults to None.
+ """,
+ config=True,
+ )
+
+ initial_language_model = Unicode(
default_value=None,
allow_none=True,
help="""
@@ -297,22 +304,14 @@ def on_change(
def initialize_settings(self):
start = time.time()
- # Read from allowlist and blocklist
- restrictions = {
- "allowed_providers": self.allowed_providers,
- "blocked_providers": self.blocked_providers,
- }
- self.settings["allowed_models"] = self.allowed_models
- self.settings["blocked_models"] = self.blocked_models
+ # Log traitlets configuration
self.log.info(f"Configured provider allowlist: {self.allowed_providers}")
self.log.info(f"Configured provider blocklist: {self.blocked_providers}")
self.log.info(f"Configured model allowlist: {self.allowed_models}")
self.log.info(f"Configured model blocklist: {self.blocked_models}")
- self.settings["model_parameters"] = self.model_parameters
self.log.info(f"Configured model parameters: {self.model_parameters}")
-
defaults = {
- "model_provider_id": self.default_language_model,
+ "model_provider_id": self.initial_language_model,
"embeddings_provider_id": self.default_embeddings_model,
"completions_model_provider_id": self.default_completions_model,
"api_keys": self.default_api_keys,
@@ -321,20 +320,10 @@ def initialize_settings(self):
"completions_fields": self.model_parameters,
}
- # Fetch LM & EM providers
- self.settings["lm_providers"] = get_lm_providers(
- log=self.log, restrictions=restrictions
- )
- self.settings["em_providers"] = get_em_providers(
- log=self.log, restrictions=restrictions
- )
-
+ # Initialize ConfigManager
self.settings["jai_config_manager"] = ConfigManager(
- # traitlets configuration, not JAI configuration.
config=self.config,
log=self.log,
- lm_providers=self.settings["lm_providers"],
- em_providers=self.settings["em_providers"],
allowed_providers=self.allowed_providers,
blocked_providers=self.blocked_providers,
allowed_models=self.allowed_models,
@@ -342,23 +331,21 @@ def initialize_settings(self):
defaults=defaults,
)
- # Expose a subset of settings as read-only to the providers
- BaseProvider.server_settings = types.MappingProxyType(
- self.serverapp.web_app.settings
- )
-
- self.log.info("Registered providers.")
-
- self.log.info(f"Registered {self.name} server extension")
+ # Initialize SecretsManager
+ self.settings["jai_secrets_manager"] = EnvSecretsManager(parent=self)
+ # Bind event loop to settings dictionary
self.settings["jai_event_loop"] = self.event_loop
- # Create empty dictionary for events communicating that
- # message generation/streaming got interrupted.
+ # Bind dictionary of interrupts to settings dictionary.
+ # Each key is a message ID, each value is an asyncio.Event.
+ # When a message's interrupt event is set, the response is halted.
self.settings["jai_message_interrupted"] = {}
- latency_ms = round((time.time() - start) * 1000)
- self.log.info(f"Initialized Jupyter AI server extension in {latency_ms} ms.")
+ # Log server extension startup time
+ self.log.info(f"Registered {self.name} server extension")
+ startup_time = round((time.time() - start) * 1000)
+ self.log.info(f"Initialized Jupyter AI server extension in {startup_time} ms.")
async def stop_extension(self):
"""
@@ -378,7 +365,10 @@ async def _stop_extension(self):
Private method that defines the cleanup code to run when the server is
stopping.
"""
- # TODO: explore if cleanup is necessary
+ secrets_manager = self.settings.get("jai_secrets_manager", None)
+
+ if secrets_manager:
+ secrets_manager.stop()
def _init_persona_manager(
self, room_id: str, ychat: YChat
@@ -447,7 +437,6 @@ def _link_jupyter_server_extension(self, server_app: ServerApp):
".git", # Git version control directory
".venv", # Python virtual environment directory
"venv", # Python virtual environment directory
- ".env", # Environment variable files
"node_modules", # Node.js dependencies directory
".pytest_cache", # PyTest cache directory
".mypy_cache", # MyPy type checker cache directory
diff --git a/packages/jupyter-ai/jupyter_ai/handlers.py b/packages/jupyter-ai/jupyter_ai/handlers.py
index fefba68c5..db81d6b65 100644
--- a/packages/jupyter-ai/jupyter_ai/handlers.py
+++ b/packages/jupyter-ai/jupyter_ai/handlers.py
@@ -1,5 +1,3 @@
-from typing import TYPE_CHECKING, Optional, cast
-
from jupyter_ai.config_manager import ConfigManager, KeyEmptyError, WriteConflictError
from jupyter_server.base.handlers import APIHandler as BaseAPIHandler
from pydantic import ValidationError
@@ -7,125 +5,6 @@
from tornado.web import HTTPError
from .config import UpdateConfigRequest
-from .models import (
- ListProvidersEntry,
- ListProvidersResponse,
-)
-
-if TYPE_CHECKING:
- from jupyter_ai_magics.embedding_providers import BaseEmbeddingsProvider
- from jupyter_ai_magics.providers import BaseProvider
-
-
-class ProviderHandler(BaseAPIHandler):
- """
- Helper base class used for HTTP handlers hosting endpoints relating to
- providers. Wrapper around BaseAPIHandler.
- """
-
- @property
- def lm_providers(self) -> dict[str, "BaseProvider"]:
- return self.settings["lm_providers"]
-
- @property
- def em_providers(self) -> dict[str, "BaseEmbeddingsProvider"]:
- return self.settings["em_providers"]
-
- @property
- def allowed_models(self) -> Optional[list[str]]:
- return self.settings["allowed_models"]
-
- @property
- def blocked_models(self) -> Optional[list[str]]:
- return self.settings["blocked_models"]
-
- def _filter_blocked_models(self, providers: list[ListProvidersEntry]):
- """
- Satisfy the model-level allow/blocklist by filtering models accordingly.
- The provider-level allow/blocklist is already handled in
- `AiExtension.initialize_settings()`.
- """
- if self.blocked_models is None and self.allowed_models is None:
- return providers
-
- def filter_predicate(local_model_id: str):
- model_id = provider.id + ":" + local_model_id
- if self.blocked_models:
- return model_id not in self.blocked_models
- else:
- return model_id in cast(list, self.allowed_models)
-
- # filter out every model w/ model ID according to allow/blocklist
- for provider in providers:
- provider.models = list(filter(filter_predicate, provider.models or []))
- provider.chat_models = list(
- filter(filter_predicate, provider.chat_models or [])
- )
- provider.completion_models = list(
- filter(filter_predicate, provider.completion_models or [])
- )
-
- # filter out every provider with no models which satisfy the allow/blocklist, then return
- return filter((lambda p: len(p.models) > 0), providers)
-
-
-class ModelProviderHandler(ProviderHandler):
- @web.authenticated
- def get(self):
- providers = []
-
- # Step 1: gather providers
- for provider in self.lm_providers.values():
- optionals = {}
- if provider.model_id_label:
- optionals["model_id_label"] = provider.model_id_label
-
- providers.append(
- ListProvidersEntry(
- id=provider.id,
- name=provider.name,
- models=provider.models,
- chat_models=provider.chat_models(),
- completion_models=provider.completion_models(),
- help=provider.help,
- auth_strategy=provider.auth_strategy,
- registry=provider.registry,
- fields=provider.fields,
- **optionals,
- )
- )
-
- # Step 2: sort & filter providers
- providers = self._filter_blocked_models(providers)
- providers = sorted(providers, key=lambda p: p.name)
-
- # Finally, yield response.
- response = ListProvidersResponse(providers=providers)
- self.finish(response.model_dump_json())
-
-
-class EmbeddingsModelProviderHandler(ProviderHandler):
- @web.authenticated
- def get(self):
- providers = []
- for provider in self.em_providers.values():
- providers.append(
- ListProvidersEntry(
- id=provider.id,
- name=provider.name,
- models=provider.models,
- help=provider.help,
- auth_strategy=provider.auth_strategy,
- registry=provider.registry,
- fields=provider.fields,
- )
- )
-
- providers = self._filter_blocked_models(providers)
- providers = sorted(providers, key=lambda p: p.name)
-
- response = ListProvidersResponse(providers=providers)
- self.finish(response.model_dump_json())
class GlobalConfigHandler(BaseAPIHandler):
@@ -165,19 +44,6 @@ def post(self):
) from e
-class ApiKeysHandler(BaseAPIHandler):
- @property
- def config_manager(self) -> ConfigManager: # type:ignore[override]
- return self.settings["jai_config_manager"]
-
- @web.authenticated
- def delete(self, api_key_name: str):
- try:
- self.config_manager.delete_api_key(api_key_name)
- except Exception as e:
- raise HTTPError(500, str(e))
-
-
class InterruptStreamingHandler(BaseAPIHandler):
"""Interrupt a current message streaming"""
diff --git a/packages/jupyter-ai/jupyter_ai/history.py b/packages/jupyter-ai/jupyter_ai/history.py
deleted file mode 100644
index f2d6d1007..000000000
--- a/packages/jupyter-ai/jupyter_ai/history.py
+++ /dev/null
@@ -1,61 +0,0 @@
-from typing import Optional
-
-from jupyterlab_chat.models import Message as JChatMessage
-from jupyterlab_chat.ychat import YChat
-from langchain_core.chat_history import BaseChatMessageHistory
-from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
-
-
-class YChatHistory(BaseChatMessageHistory):
- """
- An implementation of `BaseChatMessageHistory` that returns the preceding `k`
- exchanges (`k * 2` messages) from the given YChat model.
-
- If `k` is set to `None`, then this class returns all preceding messages.
-
- TODO: Consider just defining `k` as the number of messages and default to 4.
- """
-
- def __init__(self, ychat: YChat, k: Optional[int] = None):
- self.ychat = ychat
- self.k = k
-
- @property
- def messages(self) -> list[BaseMessage]: # type:ignore[override]
- """
- Returns the last `2 * k` messages preceding the latest message. If
- `k` is set to `None`, return all preceding messages.
- """
- # TODO: consider bounding history based on message size (e.g. total
- # char/token count) instead of message count.
- all_messages = self.ychat.get_messages()
-
- # gather last k * 2 messages and return
- # we exclude the last message since that is the human message just
- # submitted by a user.
- start_idx = 0 if self.k is None else -2 * self.k - 1
- recent_messages: list[JChatMessage] = all_messages[start_idx:-1]
-
- return self._convert_to_langchain_messages(recent_messages)
-
- def _convert_to_langchain_messages(self, jchat_messages: list[JChatMessage]):
- """
- Accepts a list of Jupyter Chat messages, and returns them as a list of
- LangChain messages.
- """
- messages: list[BaseMessage] = []
- for jchat_message in jchat_messages:
- if jchat_message.sender.startswith("jupyter-ai-personas::"):
- messages.append(AIMessage(content=jchat_message.body))
- else:
- messages.append(HumanMessage(content=jchat_message.body))
-
- return messages
-
- def add_message(self, message: BaseMessage) -> None:
- # do nothing when other LangChain objects call this method, since
- # message history is maintained by the `YChat` shared document.
- return
-
- def clear(self):
- raise NotImplementedError()
diff --git a/packages/jupyter-ai/jupyter_ai/model_providers/model_handlers.py b/packages/jupyter-ai/jupyter_ai/model_providers/model_handlers.py
new file mode 100644
index 000000000..da3f8f382
--- /dev/null
+++ b/packages/jupyter-ai/jupyter_ai/model_providers/model_handlers.py
@@ -0,0 +1,28 @@
+from jupyter_server.base.handlers import APIHandler as BaseAPIHandler
+from pydantic import BaseModel
+from tornado import web
+
+from .model_list import CHAT_MODELS
+
+
+class ChatModelEndpoint(BaseAPIHandler):
+ """
+ A handler class that defines the `/api/ai/models/chat` endpoint.
+
+ - `GET /api/ai/models/chat`: returns list of all chat models.
+
+    - `GET /api/ai/models/chat?id=<model_id>`: returns info on that model (TODO)
+ """
+
+ @web.authenticated
+ def get(self):
+ response = ListChatModelsResponse(chat_models=CHAT_MODELS)
+ self.finish(response.model_dump_json())
+
+
+class ListChatModelsResponse(BaseModel):
+ chat_models: list[str]
+
+
+class ListEmbeddingModelsResponse(BaseModel):
+ embedding_models: list[str]
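The response schema is simply `{"chat_models": [...]}`, per `ListChatModelsResponse` above. Querying the new endpoint from outside the server might look like this; the URL and token are placeholders for your local setup:

```python
# Fetch the chat model list from a running server; URL/token are placeholders.
import json
import urllib.request

url = "http://localhost:8888/api/ai/models/chat?token=YOUR_TOKEN"
with urllib.request.urlopen(url) as resp:
    payload = json.load(resp)

print(payload["chat_models"][:5])  # first few LiteLLM-style model IDs
```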
diff --git a/packages/jupyter-ai/jupyter_ai/model_providers/model_list.py b/packages/jupyter-ai/jupyter_ai/model_providers/model_list.py
new file mode 100644
index 000000000..acd899bcb
--- /dev/null
+++ b/packages/jupyter-ai/jupyter_ai/model_providers/model_list.py
@@ -0,0 +1,36 @@
+from litellm import all_embedding_models, models_by_provider
+
+chat_model_ids = []
+embedding_model_ids = []
+embedding_model_set = set(all_embedding_models)
+
+for provider_name in models_by_provider:
+ for model_name in models_by_provider[provider_name]:
+ model_name: str = model_name
+
+ if model_name.startswith(f"{provider_name}/"):
+ model_id = model_name
+ else:
+ model_id = f"{provider_name}/{model_name}"
+
+ is_embedding = (
+ model_name in embedding_model_set
+ or model_id in embedding_model_set
+ or "embed" in model_id
+ )
+
+ if is_embedding:
+ embedding_model_ids.append(model_id)
+ else:
+ chat_model_ids.append(model_id)
+
+
+CHAT_MODELS = sorted(chat_model_ids)
+"""
+List of chat model IDs, following the `litellm` syntax.
+"""
+
+EMBEDDING_MODELS = sorted(embedding_model_ids)
+"""
+List of embedding model IDs, following the `litellm` syntax.
+"""
diff --git a/packages/jupyter-ai/jupyter_ai/models.py b/packages/jupyter-ai/jupyter_ai/models.py
index 1c54e0e0f..ff9dae735 100644
--- a/packages/jupyter-ai/jupyter_ai/models.py
+++ b/packages/jupyter-ai/jupyter_ai/models.py
@@ -1,12 +1,13 @@
from typing import Optional
-from jupyter_ai_magics.providers import AuthStrategy, Field
from pydantic import BaseModel
DEFAULT_CHUNK_SIZE = 2000
DEFAULT_CHUNK_OVERLAP = 100
+# TODO: Delete this once the new Models API can return these properties.
+# This is just being kept as a reference.
class ListProvidersEntry(BaseModel):
"""Model provider with supported models
and provider's authentication strategy
@@ -17,22 +18,8 @@ class ListProvidersEntry(BaseModel):
model_id_label: Optional[str] = None
models: list[str]
help: Optional[str] = None
- auth_strategy: AuthStrategy
+ # auth_strategy: AuthStrategy
registry: bool
- fields: list[Field]
+ # fields: list[Field]
chat_models: Optional[list[str]] = None
completion_models: Optional[list[str]] = None
-
-
-class ListProvidersResponse(BaseModel):
- providers: list[ListProvidersEntry]
-
-
-class IndexedDir(BaseModel):
- path: str
- chunk_size: int = DEFAULT_CHUNK_SIZE
- chunk_overlap: int = DEFAULT_CHUNK_OVERLAP
-
-
-class IndexMetadata(BaseModel):
- dirs: list[IndexedDir]
diff --git a/packages/jupyter-ai/jupyter_ai/personas/base_persona.py b/packages/jupyter-ai/jupyter_ai/personas/base_persona.py
index 985e40078..310901b1c 100644
--- a/packages/jupyter-ai/jupyter_ai/personas/base_persona.py
+++ b/packages/jupyter-ai/jupyter_ai/personas/base_persona.py
@@ -20,6 +20,8 @@
if TYPE_CHECKING:
from collections.abc import AsyncIterator
+ from litellm import ModelResponseStream
+
from .persona_manager import PersonaManager
@@ -233,10 +235,13 @@ def as_user_dict(self) -> dict[str, Any]:
user = self.as_user()
return asdict(user)
- async def stream_message(self, reply_stream: "AsyncIterator") -> None:
+ async def stream_message(
+ self, reply_stream: "AsyncIterator[ModelResponseStream | str]"
+ ) -> None:
"""
Takes an async iterator, dubbed the 'reply stream', and streams it to a
- new message by this persona in the YChat.
+ new message by this persona in the YChat. The async iterator may yield
+ either strings or `litellm.ModelResponseStream` objects. Details:
- Creates a new message upon receiving the first chunk from the reply
stream, then continuously updates it until the stream is closed.
@@ -248,6 +253,15 @@ async def stream_message(self, reply_stream: "AsyncIterator") -> None:
try:
self.awareness.set_local_state_field("isWriting", True)
async for chunk in reply_stream:
+ # Coerce LiteLLM stream chunk to a string delta
+ if not isinstance(chunk, str):
+ chunk = chunk.choices[0].delta.content
+
+ # LiteLLM streams always terminate with an empty chunk, so we
+ # ignore and continue when this occurs.
+ if not chunk:
+ continue
+
if (
stream_id
and stream_id in self.message_interrupted.keys()
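Since `stream_message` now accepts plain strings as well as LiteLLM chunks, a persona can stream arbitrary text without constructing `ModelResponseStream` objects. A tiny sketch:

```python
# Plain-string reply stream: each yielded string passes straight through
# the isinstance(chunk, str) branch added above.
async def canned_reply():
    for word in ("Hello", " from ", "a test persona."):
        yield word

# Inside a persona's process_message():
#     await self.stream_message(canned_reply())
```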
diff --git a/packages/jupyter-ai/jupyter_ai/personas/jupyternaut/jupyternaut.py b/packages/jupyter-ai/jupyter_ai/personas/jupyternaut/jupyternaut.py
index 495dd5c7b..633ca085e 100644
--- a/packages/jupyter-ai/jupyter_ai/personas/jupyternaut/jupyternaut.py
+++ b/packages/jupyter-ai/jupyter_ai/personas/jupyternaut/jupyternaut.py
@@ -1,12 +1,14 @@
-from typing import Any
+from typing import Any, Optional
from jupyterlab_chat.models import Message
-from langchain_core.output_parsers import StrOutputParser
-from langchain_core.runnables.history import RunnableWithMessageHistory
+from litellm import acompletion
-from ...history import YChatHistory
from ..base_persona import BasePersona, PersonaDefaults
-from .prompt_template import JUPYTERNAUT_PROMPT_TEMPLATE, JupyternautVariables
+from ..persona_manager import SYSTEM_USERNAME
+from .prompt_template import (
+ JUPYTERNAUT_SYSTEM_PROMPT_TEMPLATE,
+ JupyternautSystemPromptArgs,
+)
class JupyternautPersona(BasePersona):
@@ -27,40 +29,72 @@ def defaults(self):
)
async def process_message(self, message: Message) -> None:
- if not self.config_manager.lm_provider:
+ if not self.config_manager.chat_model:
self.send_message(
- "No language model provider configured. Please set one in the Jupyter AI settings."
+ "No chat model is configured.\n\n"
+ "You must set one first in the Jupyter AI settings, found in 'Settings > AI Settings' from the menu bar."
)
return
- provider_name = self.config_manager.lm_provider.name
- model_id = self.config_manager.lm_provider_params["model_id"]
+ model_id = self.config_manager.chat_model
+ context_as_messages = self.get_context_as_messages(model_id, message)
+ response_aiter = await acompletion(
+ model=model_id,
+ messages=[
+ *context_as_messages,
+ {
+ "role": "user",
+ "content": message.body,
+ },
+ ],
+ stream=True,
+ )
- # Process file attachments and include their content in the context
- context = self.process_attachments(message)
+ await self.stream_message(response_aiter)
- runnable = self.build_runnable()
- variables = JupyternautVariables(
- input=message.body,
+ def get_context_as_messages(
+ self, model_id: str, message: Message
+ ) -> list[dict[str, Any]]:
+ """
+ Returns the current context, including attachments and recent messages,
+ as a list of messages accepted by `litellm.acompletion()`.
+ """
+ system_msg_args = JupyternautSystemPromptArgs(
model_id=model_id,
- provider_name=provider_name,
persona_name=self.name,
- context=context,
- )
- variables_dict = variables.model_dump()
- reply_stream = runnable.astream(variables_dict)
- await self.stream_message(reply_stream)
-
- def build_runnable(self) -> Any:
- # TODO: support model parameters. maybe we just add it to lm_provider_params in both 2.x and 3.x
- llm = self.config_manager.lm_provider(**self.config_manager.lm_provider_params)
- runnable = JUPYTERNAUT_PROMPT_TEMPLATE | llm | StrOutputParser()
-
- runnable = RunnableWithMessageHistory(
- runnable=runnable, # type:ignore[arg-type]
- get_session_history=lambda: YChatHistory(ychat=self.ychat, k=2),
- input_messages_key="input",
- history_messages_key="history",
- )
+ context=self.process_attachments(message),
+ ).model_dump()
+
+ system_msg = {
+ "role": "system",
+ "content": JUPYTERNAUT_SYSTEM_PROMPT_TEMPLATE.render(**system_msg_args),
+ }
+
+ context_as_messages = [system_msg, *self._get_history_as_messages()]
+ return context_as_messages
+
+ def _get_history_as_messages(self, k: Optional[int] = 2) -> list[dict[str, Any]]:
+ """
+ Returns the current history as a list of messages accepted by
+ `litellm.acompletion()`.
+ """
+ # TODO: consider bounding history based on message size (e.g. total
+ # char/token count) instead of message count.
+ all_messages = self.ychat.get_messages()
+
+        # Gather the last `2 * k` messages, excluding the final message since
+        # that is the human message just submitted by the user.
+ start_idx = 0 if k is None else -2 * k - 1
+ recent_messages: list[Message] = all_messages[start_idx:-1]
+
+ history: list[dict[str, Any]] = []
+ for msg in recent_messages:
+ role = (
+ "assistant"
+ if msg.sender.startswith("jupyter-ai-personas::")
+ else "system" if msg.sender == SYSTEM_USERNAME else "user"
+ )
+ history.append({"role": role, "content": msg.body})
- return runnable
+ return history
diff --git a/packages/jupyter-ai/jupyter_ai/personas/jupyternaut/prompt_template.py b/packages/jupyter-ai/jupyter_ai/personas/jupyternaut/prompt_template.py
index cd5a6ce9b..05cb7b956 100644
--- a/packages/jupyter-ai/jupyter_ai/personas/jupyternaut/prompt_template.py
+++ b/packages/jupyter-ai/jupyter_ai/personas/jupyternaut/prompt_template.py
@@ -1,11 +1,6 @@
from typing import Optional
-from langchain.prompts import (
- ChatPromptTemplate,
- HumanMessagePromptTemplate,
- MessagesPlaceholder,
- SystemMessagePromptTemplate,
-)
+from jinja2 import Template
from pydantic import BaseModel
_JUPYTERNAUT_SYSTEM_PROMPT_FORMAT = """
@@ -17,7 +12,7 @@
When installed, Jupyter AI adds a chat experience in JupyterLab that allows multiple users to collaborate with one or more agents like yourself.
-You are not a language model, but rather an AI agent powered by a foundation model `{{model_id}}`, provided by '{{provider_name}}'.
+You are not a language model, but rather an AI agent powered by a foundation model `{{model_id}}`.
You are receiving a request from a user in JupyterLab. Your goal is to fulfill this request to the best of your ability.
@@ -48,28 +43,13 @@
""".strip()
-JUPYTERNAUT_PROMPT_TEMPLATE = ChatPromptTemplate.from_messages(
- [
- SystemMessagePromptTemplate.from_template(
- _JUPYTERNAUT_SYSTEM_PROMPT_FORMAT, template_format="jinja2"
- ),
- MessagesPlaceholder(variable_name="history"),
- HumanMessagePromptTemplate.from_template("{input}"),
- ]
-)
-
-class JupyternautVariables(BaseModel):
- """
- Variables expected by `JUPYTERNAUT_PROMPT_TEMPLATE`, defined as a Pydantic
- data model for developer convenience.
+JUPYTERNAUT_SYSTEM_PROMPT_TEMPLATE: Template = Template(
+ _JUPYTERNAUT_SYSTEM_PROMPT_FORMAT
+)
- Call the `.model_dump()` method on an instance to convert it to a Python
- dictionary.
- """
- input: str
+class JupyternautSystemPromptArgs(BaseModel):
persona_name: str
- provider_name: str
model_id: str
context: Optional[str] = None
diff --git a/packages/jupyter-ai/jupyter_ai/secrets/secrets_manager.py b/packages/jupyter-ai/jupyter_ai/secrets/secrets_manager.py
new file mode 100644
index 000000000..782dd4533
--- /dev/null
+++ b/packages/jupyter-ai/jupyter_ai/secrets/secrets_manager.py
@@ -0,0 +1,379 @@
+from __future__ import annotations
+
+import asyncio
+import os
+from datetime import datetime
+from io import StringIO
+from typing import TYPE_CHECKING
+
+from dotenv import dotenv_values, load_dotenv
+from tornado.web import HTTPError
+from traitlets.config import LoggingConfigurable
+
+from .secrets_types import SecretsList
+from .secrets_utils import build_updated_dotenv
+
+if TYPE_CHECKING:
+ import logging
+ from typing import Any
+
+ from jupyter_server.services.contents.filemanager import AsyncFileContentsManager
+
+ from ..extension import AiExtension
+
+
+class EnvSecretsManager(LoggingConfigurable):
+ """
+ The default secrets manager implementation.
+
+ TODO: Create a `BaseSecretsManager` class and add an
+ `AiExtension.secrets_manager_class` configurable trait to allow custom
+ implementations.
+
+ TODO: Add a `EnvSecretsManager.dotenv_path` configurable trait to allow
+ users to change the path of the `.env` file.
+
+ TODO: Add a `EnvSecretsManager.envvar_secrets` configurable trait to allow
+ users to pass a list of glob expressions that define which environment
+ variables are listed as secrets in the UI. This should default to `*TOKEN*,
+ *SECRET*, *KEY*`.
+ """
+
+ parent: AiExtension
+ """
+ The parent `AiExtension` class.
+
+ NOTE: This attribute is automatically set by the `LoggingConfigurable`
+ parent class. This annotation exists only to help type checkers like `mypy`.
+ """
+
+ log: logging.Logger
+ """
+    The logger used by this instance.
+
+ NOTE: This attribute is automatically set by the `LoggingConfigurable`
+ parent class. This annotation exists only to help type checkers like `mypy`.
+ """
+
+ _last_modified: datetime | None
+ """
+ The 'last modified' timestamp on the '.env' file retrieved in the previous
+ tick of the `_watch_dotenv()` background task.
+ """
+
+ _initial_env: dict[str, str]
+ """
+ Dictionary containing the initial environment variables passed to this
+ process. Set to `dict(os.environ)` exactly once on init.
+
+ This attribute should not be set more than once, since secrets loaded from
+ the `.env` file are added to `os.environ` after this class initializes.
+ """
+
+ _dotenv_env: dict[str, str]
+ """
+ Dictionary containing the environment variables defined in the `.env` file.
+ If no `.env` file exists, this will be an empty dictionary. This attribute
+ is continuously updated via the `_watch_dotenv()` background task.
+ """
+
+ _dotenv_lock: asyncio.Lock
+ """
+ Lock which must be held while reading or writing to the `.env` file from the
+ `ContentsManager`.
+ """
+
+ @property
+ def contents_manager(self) -> AsyncFileContentsManager:
+ return self.parent.serverapp.contents_manager
+
+ @property
+ def event_loop(self) -> asyncio.AbstractEventLoop:
+ return self.parent.event_loop
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ # Set instance attributes
+ self._last_modified = None
+ self._initial_env = dict(os.environ)
+ self._dotenv_env = {}
+ self._dotenv_lock = asyncio.Lock()
+
+ # Start `_watch_dotenv()` task to automatically update the environment
+ # variables when `.env` is modified
+ self._watch_dotenv_task = self.event_loop.create_task(self._watch_dotenv())
+
+ async def _watch_dotenv(self) -> None:
+ """
+ Watches the `.env` file and automatically responds to changes.
+ """
+ while True:
+ await asyncio.sleep(2)
+
+ # Fetch file content and its last modified timestamp
+ try:
+ async with self._dotenv_lock:
+ dotenv_file = await self.contents_manager.get(".env", content=True)
+ dotenv_content = dotenv_file.get("content")
+ assert isinstance(dotenv_content, str)
+            except HTTPError as e:
+                # Continue if the file does not exist; otherwise log the
+                # error and retry on the next tick.
+                if e.status_code == 404:
+                    self._handle_dotenv_notfound()
+                else:
+                    self.log.exception("HTTP error raised when fetching `.env`:")
+                continue
+ except Exception:
+ self.log.exception("Unknown exception in `_watch_dotenv()`:")
+ continue
+
+ # Continue if the `.env` file was already processed and its content
+ # is unchanged.
+ if self._last_modified == dotenv_file["last_modified"]:
+ continue
+
+ # When this line is reached, the .env file needs to be applied.
+ # Log a statement accordingly, ensure the new `.env` file is listed
+ # in `.gitignore`, and store the latest last modified timestamp.
+ if self._last_modified:
+ # Statement when .env file was modified:
+ self.log.info(
+ "Detected changes to the '.env' file. Re-applying '.env' to the environment..."
+ )
+ else:
+ # Statement when the .env file was just created, or when this is
+ # the first iteration and a .env file already exists:
+ self.log.info(
+ "Detected '.env' file at the workspace root. Applying '.env' to the environment..."
+ )
+ self.event_loop.create_task(self._ensure_dotenv_gitignored())
+ self._last_modified = dotenv_file["last_modified"]
+
+ # Apply the latest `.env` file to the environment.
+ # See `self._apply_dotenv()` for more info.
+ self._apply_dotenv(dotenv_content)
+
+ def _apply_dotenv(self, content: str) -> None:
+ """
+ Applies a `.env` file to the environment given its content. This method:
+
+ 1. Resets any variables removed from the `.env` file. See
+ `self._reset_envvars()` for more info.
+
+ 2. Stores the parsed `.env` file as a dictionary in `self._dotenv_env`.
+
+ 3. Sets the environment variables in `os.environ` as defined in the
+ `.env` file.
+ """
+ # Parse the latest `.env` file and store it in `self._dotenv_env`,
+ # tracking deleted environment variables in `deleted_envvars`.
+ new_dotenv_env = dotenv_values(stream=StringIO(content))
+        new_dotenv_env = {k: v for k, v in new_dotenv_env.items() if v is not None}
+ deleted_envvars = [k for k in self._dotenv_env if k not in new_dotenv_env]
+ self._dotenv_env = new_dotenv_env
+
+ # Apply the new `.env` file to the environment and reset all
+ # environment variables in `deleted_envvars`.
+ if deleted_envvars:
+ self._reset_envvars(deleted_envvars)
+ self.log.info(
+ f"Removed {len(deleted_envvars)} variables from the environment as they were removed from '.env'."
+ )
+ load_dotenv(stream=StringIO(content), override=True)
+ self.log.info("Applied '.env' to the environment.")
+
+    async def _ensure_dotenv_gitignored(self) -> None:
+ """
+ Ensures the `.env` file is listed in the `.gitignore` file at the
+ workspace root, creating/updating the `.gitignore` file to list `.env`
+ if needed.
+
+ This method is called by the `_watch_dotenv()` background task either on
+ the first iteration when the `.env` file already exists, or when the
+ `.env` file was created on a subsequent iteration.
+ """
+ # Fetch `.gitignore` file.
+ gitignore_file: dict[str, Any] | None = None
+ try:
+ gitignore_file = await self.contents_manager.get(".gitignore", content=True)
+ except HTTPError as e:
+ # Continue if file does not exist, otherwise re-raise
+ if e.status_code == 404:
+ pass
+ else:
+ raise e
+ except Exception:
+ self.log.exception("Unknown exception raised when fetching `.gitignore`:")
+
+ # Return early if the `.gitignore` file exists and already lists `.env`.
+ old_content: str = (gitignore_file or {}).get("content", "")
+ if ".env\n" in old_content:
+ return
+
+ # Otherwise, log something and create/update the `.gitignore` file to
+ # list `.env`.
+ self.log.info("Updating `.gitignore` file to include `.env`...")
+ new_lines = "# Ignore secrets in '.env'\n.env\n"
+ new_content = old_content + "\n" + new_lines if old_content else new_lines
+ try:
+ gitignore_file = await self.contents_manager.save(
+ {
+ "type": "file",
+ "format": "text",
+ "mimetype": "text/plain",
+ "content": new_content,
+ },
+ ".gitignore",
+ )
+        except Exception:
+            self.log.exception("Unknown exception raised when updating `.gitignore`:")
+            return
+        self.log.info("Updated `.gitignore` file to include `.env`.")
+
+ def _reset_envvars(self, names: list[str]) -> None:
+ """
+ Resets each environment variable in the given list. Each variable is
+ restored to its initial value in `self._initial_env` if present, and
+ deleted from `os.environ` otherwise.
+ """
+        for ev_name in names:
+            if ev_name in self._initial_env:
+                os.environ[ev_name] = self._initial_env[ev_name]
+            else:
+                # `pop()` avoids a KeyError if the variable was already unset.
+                os.environ.pop(ev_name, None)
+
+ def _handle_dotenv_notfound(self) -> None:
+ """
+ Method called by the `_watch_dotenv()` task when the `.env` file is
+ not found.
+ """
+ if self._last_modified:
+ self._last_modified = None
+ if self._dotenv_env:
+ self._reset_envvars(list(self._dotenv_env.keys()))
+ self._dotenv_env = {}
+
+ def list_secrets(self) -> SecretsList:
+ """
+ Lists the names of each environment variable from the workspace `.env`
+ file and the environment variables passed to the Python process. Notes:
+
+ 1. For envvars from the Python process (not set in `.env`), only
+ environment variables whose names contain "KEY" or "TOKEN" or "SECRET"
+ are included.
+
+ 2. Each envvar listed in `.env` is included in the returned list.
+ """
+ dotenv_secrets_names = set()
+ process_secrets_names = set()
+
+ # Add secrets from the initial environment
+ for name in self._initial_env.keys():
+ if "KEY" in name or "TOKEN" in name or "SECRET" in name:
+ process_secrets_names.add(name)
+
+ # Add secrets from .env, if any
+ for name in self._dotenv_env:
+ dotenv_secrets_names.add(name)
+
+ # Remove `TIKTOKEN_CACHE_DIR`, which is set in the initial environment
+ # by some other package and is not a secret.
+ # This gets included otherwise since it contains 'TOKEN' in its name.
+ process_secrets_names.discard("TIKTOKEN_CACHE_DIR")
+
+ return SecretsList(
+ editable_secrets=sorted(list(dotenv_secrets_names)),
+ static_secrets=sorted(list(process_secrets_names)),
+ )
+
+ async def update_secrets(
+ self,
+ updated_secrets: dict[str, str | None],
+ ) -> None:
+ """
+ Accepts a dictionary of secrets to update, adds/updates/deletes them
+ from `.env` accordingly, and applies the updated `.env` file to the
+ environment. Notes:
+
+ - A new `.env` file is created if it does not exist.
+
+ - If the value of a secret in `updated_secrets` is `None`, then the
+ secret is deleted from `.env`.
+
+ - Otherwise, the secret is added/updated in `.env`.
+
+ - A best effort is made at preserving the formatting in the `.env`
+          file. However, inline comments following an environment variable
+ definition on the same line will be deleted.
+ """
+ # Return early if passed an empty dictionary
+ if not updated_secrets:
+ return
+
+ # Hold the lock during the entire duration of the update
+ async with self._dotenv_lock:
+ # Fetch `.env` file content, storing its raw content in
+ # `dotenv_content` and its parsed value as a dict in `dotenv_env`.
+ dotenv_content: str = ""
+ try:
+ dotenv_file = await self.contents_manager.get(".env", content=True)
+ if "content" in dotenv_file:
+ dotenv_content = dotenv_file["content"]
+ assert isinstance(dotenv_content, str)
+ except HTTPError as e:
+ # Continue if file does not exist, otherwise re-raise
+ if e.status_code == 404:
+ pass
+ else:
+ raise e
+ except Exception:
+ self.log.exception(
+ "Unknown exception raised when reading `.env` in response to an update:"
+ )
+
+ # Build the new `.env` file using these variables.
+ # See `build_updated_dotenv()` for more info on how this is done.
+ new_dotenv_content = build_updated_dotenv(dotenv_content, updated_secrets)
+
+ # Return early if no changes are needed in `.env`.
+ if new_dotenv_content is None:
+ return
+
+ # Save new content
+ try:
+ dotenv_file = await self.contents_manager.save(
+ {
+ "type": "file",
+ "format": "text",
+ "mimetype": "text/plain",
+ "content": new_dotenv_content,
+ },
+ ".env",
+ )
+ last_modified = dotenv_file.get("last_modified")
+ assert isinstance(last_modified, datetime)
+            except Exception:
+                self.log.exception("Unknown exception raised when updating `.env`:")
+                return
+
+ # If this is a new file, ensure the `.env` file is listed in `.gitignore`.
+ # `self._last_modified == None` should imply the `.env` file did not exist.
+ if not self._last_modified:
+ self.event_loop.create_task(self._ensure_dotenv_gitignored())
+
+ # Update last modified timestamp and apply the new environment.
+ self._last_modified = last_modified
+ # This automatically sets `self._dotenv_env`.
+ self._apply_dotenv(new_dotenv_content)
+ self.log.info("Updated secrets in `.env`.")
+
+ def get_secret(self, secret_name: str) -> str | None:
+ """
+ Returns the value of a secret given its name. The returned secret must
+ NEVER be shared with frontend clients!
+ """
+ # TODO
+
+ def stop(self) -> None:
+ """
+ Stops this instance and any background tasks spawned by this instance.
+ This method should be called if and only if the server is shutting down.
+ """
+ self._watch_dotenv_task.cancel()
diff --git a/packages/jupyter-ai/jupyter_ai/secrets/secrets_rest_api.py b/packages/jupyter-ai/jupyter_ai/secrets/secrets_rest_api.py
new file mode 100644
index 000000000..c2492f315
--- /dev/null
+++ b/packages/jupyter-ai/jupyter_ai/secrets/secrets_rest_api.py
@@ -0,0 +1,48 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from jupyter_server.base.handlers import APIHandler as BaseAPIHandler
+from tornado.web import HTTPError, authenticated
+
+from .secrets_types import UpdateSecretsRequest
+
+if TYPE_CHECKING:
+ from .secrets_manager import EnvSecretsManager
+
+
+class SecretsRestAPI(BaseAPIHandler):
+ """
+ Defines the REST API served at the `/api/ai/secrets` endpoint.
+
+ Methods supported:
+
+ - `GET secrets/`: Returns a list of secrets.
+ - `PUT secrets/`: Add/update/delete a set of secrets.
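+
+    An example `PUT` request body (illustrative):
+    `{"updated_secrets": {"OPENAI_API_KEY": "sk-...", "COHERE_API_KEY": null}}`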
+ """
+
+ @property
+ def secrets_manager(self) -> EnvSecretsManager: # type:ignore[override]
+ return self.settings["jai_secrets_manager"]
+
+ @authenticated
+ def get(self):
+ response = self.secrets_manager.list_secrets()
+ self.set_status(200)
+ self.finish(response.model_dump_json())
+
+ @authenticated
+ async def put(self):
+ try:
+ # Validate the request body matches the `UpdateSecretsRequest`
+ # expected type
+ request = UpdateSecretsRequest(**self.get_json_body())
+
+ # Dispatch the request to the secrets manager
+ await self.secrets_manager.update_secrets(request.updated_secrets)
+ except Exception as e:
+ self.log.exception("Exception raised when handling PUT /api/ai/secrets/:")
+ raise HTTPError(500, str(e))
+
+ self.set_status(204)
+ self.finish()
diff --git a/packages/jupyter-ai/jupyter_ai/secrets/secrets_types.py b/packages/jupyter-ai/jupyter_ai/secrets/secrets_types.py
new file mode 100644
index 000000000..8959ee8dc
--- /dev/null
+++ b/packages/jupyter-ai/jupyter_ai/secrets/secrets_types.py
@@ -0,0 +1,35 @@
+from typing import Optional
+
+from pydantic import BaseModel
+
+
+class SecretsList(BaseModel):
+ """
+ The response type returned by `GET /api/ai/secrets`.
+
+    The response fields only include the names of each secret, and must never
+    include the value of any secret.
+ """
+
+ editable_secrets: list[str] = []
+ """
+ List of secrets set in the `.env` file. These secrets can be edited.
+ """
+
+ static_secrets: list[str] = []
+ """
+ List of secrets passed as environment variables to the Python process or
+ passed as traitlets configuration to JupyterLab. These secrets cannot be
+ edited.
+
+ Environment variables passed to the Python process are only included if
+ their name contains 'KEY' or 'TOKEN' or 'SECRET'.
+ """
+
+
+class UpdateSecretsRequest(BaseModel):
+ """
+ The request body expected by `PUT /api/ai/secrets`.
+ """
+
+ updated_secrets: dict[str, Optional[str]]
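+    """
+    Maps each secret name to its new value. A `None` value requests deletion of
+    the secret; a string value requests an add/update. Example (illustrative):
+    `{"OPENAI_API_KEY": "sk-...", "COHERE_API_KEY": None}`
+    """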
diff --git a/packages/jupyter-ai/jupyter_ai/secrets/secrets_utils.py b/packages/jupyter-ai/jupyter_ai/secrets/secrets_utils.py
new file mode 100644
index 000000000..f4887e708
--- /dev/null
+++ b/packages/jupyter-ai/jupyter_ai/secrets/secrets_utils.py
@@ -0,0 +1,176 @@
+from __future__ import annotations
+
+from io import StringIO
+from typing import TYPE_CHECKING
+
+from dotenv import dotenv_values
+from dotenv.parser import parse_stream
+
+if TYPE_CHECKING:
+ import logging
+
+ENVIRONMENT_VAR_REGEX = "^([a-zA-Z_][a-zA-Z_0-9]*)=?(['\"])?"
+"""
+Regex that matches an environment variable definition.
+"""
+
+
+def build_updated_dotenv(
+ dotenv_content: str,
+ updated_secrets: dict[str, str | None],
+ log: logging.Logger | None = None,
+) -> str | None:
+ """
+    Accepts the content of the existing `.env` file as a string, along with a
+    dictionary of secrets to update. `None` values indicate the secret should
+    be deleted. Otherwise, the secret will be added to or updated in `.env`.
+
+ This function returns the content of the updated `.env` file as a string,
+ and returns `None` if no updates to `.env` are required.
+
+ NOTE: This function currently deletes inline comments on environment
+ variable definitions. This may be fixed in a future update.
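+
+    Example (illustrative):
+        >>> build_updated_dotenv('FOO="1"', {"FOO": "2", "BAR": None})
+        'FOO="2"'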
+ """
+ # Return early if no updates were given.
+ if not updated_secrets:
+ return None
+
+ # Parse content of `.env` into a dictionary of environment variables
+ dotenv_env = dotenv_values(stream=StringIO(dotenv_content))
+
+ # Define `secrets_to_add`, `secrets_to_update`, and
+ # `secrets_to_remove`.
+ secrets_to_add: dict[str, str] = {}
+ secrets_to_update: dict[str, str] = {}
+ secrets_to_remove: set[str] = set()
+ if dotenv_env:
+ for name, value in updated_secrets.items():
+ # Case 1: secret should be added to `.env`
+            if value is not None and dotenv_env.get(name) is None:
+ secrets_to_add[name] = value
+ continue
+ # Case 2: secret should be updated in `.env`
+            if value is not None and dotenv_env.get(name) is not None:
+ secrets_to_update[name] = value
+ continue
+ # Case 3: secret should be removed from `.env`
+            if value is None and dotenv_env.get(name) is not None:
+ secrets_to_remove.add(name)
+ continue
+ else:
+ # Case 4: keys can only be added when a `.env` file is not
+ # present.
+ secrets_to_add = {k: v for k, v in updated_secrets.items() if v is not None}
+
+    # Return early if the update has no effect.
+ if not (secrets_to_add or secrets_to_update or secrets_to_remove):
+ return None
+
+ # First, handle the case of adding secrets to a new `.env` file.
+ if not dotenv_env:
+ new_content = ""
+ max_i = len(secrets_to_add) - 1
+ for i, (name, value) in enumerate(secrets_to_add.items()):
+ new_content += f'{name}="{value}"\n'
+ if i != max_i:
+ new_content += "\n"
+
+ return new_content
+
+ # Now handle the case of updating an existing `.env` file.
+ # To preserve formatting, multiline variables, and inline comments on
+    # variable definitions, we re-use the parser used by `python_dotenv`.
+ # It is not trivial to re-implement their parser.
+ #
+ # Algorithm overview:
+ #
+ # 1. The `parse_stream()` function returns an Iterator that yields 'Binding'
+ # objects that represent 'parsed chunks' of a `.env` file. Each chunk may
+ # contain:
+ #
+ # - An environment variable definition (`Binding.key is not None`),
+ # - An invalid line (`Binding.error == True`),
+ # - A standalone comment (if neither condition applies).
+ #
+ # 2. (Case 1) Invalid lines and environment variable bindings listed in
+ # `secrets_to_remove` are ignored.
+ #
+ # 3. (Case 2) Environment variable definitions listed in `secrets_to_update`
+ # are appended to `new_content` with the new value.
+ #
+ # 4. (Case 3) All other `Binding` objects are appended to `new_content` as-is.
+ #
+ # 5. Finally, new environment variables listed in `secrets_to_add` are
+ # appended at the end after the `.env` file is fully parsed.
+ new_content = ""
+ for binding in parse_stream(StringIO(dotenv_content)):
+ # Case 1
+ if binding.error or binding.key in secrets_to_remove:
+ continue
+ # Case 2
+ if binding.key in secrets_to_update:
+ name = binding.key
+ # extra logic to preserve formatting as best as we can
+ whitespace_before, whitespace_after = get_whitespace_around(
+ binding.original.string
+ )
+ value = secrets_to_update[name]
+ new_content += whitespace_before
+ new_content += f'{name}="{value}"'
+ new_content += whitespace_after
+ continue
+ # Case 3
+ new_content += binding.original.string
+
+ if secrets_to_add:
+ # Ensure new secrets get put at least 2 lines below the rest
+ if not new_content.endswith("\n"):
+ new_content += "\n\n"
+ elif not new_content.endswith("\n\n"):
+ new_content += "\n"
+
+ max_i = len(secrets_to_add) - 1
+ for i, (name, value) in enumerate(secrets_to_add.items()):
+ new_content += f'{name}="{value}"\n'
+ if i != max_i:
+ new_content += "\n"
+
+ return new_content
+
+
+def get_whitespace_around(text: str) -> tuple[str, str]:
+ """
+ Extract whitespace prefix and suffix from a string.
+
+ Args:
+ text: The input string
+
+ Returns:
+ A tuple of (prefix, suffix) where prefix is the leading whitespace
+ and suffix is the trailing whitespace
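+
+    Example (illustrative):
+        >>> get_whitespace_around("  hello ")
+        ('  ', ' ')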
+ """
+ if not text:
+ return ("", "")
+
+ # Find prefix (leading whitespace)
+ prefix_end = 0
+ for i, char in enumerate(text):
+ if not char.isspace():
+ prefix_end = i
+ break
+ else:
+ # String is all whitespace
+ return (text, "")
+
+ # Find suffix (trailing whitespace)
+ suffix_start = len(text)
+ for i in range(len(text) - 1, -1, -1):
+ if not text[i].isspace():
+ suffix_start = i + 1
+ break
+
+ prefix = text[:prefix_end]
+ suffix = text[suffix_start:]
+
+ return (prefix, suffix)
diff --git a/packages/jupyter-ai/jupyter_ai/secrets/test_secrets_utils.py b/packages/jupyter-ai/jupyter_ai/secrets/test_secrets_utils.py
new file mode 100644
index 000000000..2fb4ef2e9
--- /dev/null
+++ b/packages/jupyter-ai/jupyter_ai/secrets/test_secrets_utils.py
@@ -0,0 +1,127 @@
+from jupyter_ai.secrets.secrets_utils import build_updated_dotenv, get_whitespace_around
+
+
+class TestGetWhitespaceAround:
+ """Test cases for get_whitespace_around function."""
+
+ def test_empty_string(self):
+ """Test with empty string."""
+ prefix, suffix = get_whitespace_around("")
+ assert prefix == ""
+ assert suffix == ""
+
+ def test_no_whitespace(self):
+ """Test with string containing no whitespace."""
+ prefix, suffix = get_whitespace_around("hello")
+ assert prefix == ""
+ assert suffix == ""
+
+ def test_only_prefix_whitespace(self):
+ """Test with string containing only leading whitespace."""
+ prefix, suffix = get_whitespace_around(" hello")
+ assert prefix == " "
+ assert suffix == ""
+
+ def test_only_suffix_whitespace(self):
+ """Test with string containing only trailing whitespace."""
+ prefix, suffix = get_whitespace_around("hello ")
+ assert prefix == ""
+ assert suffix == " "
+
+ def test_both_prefix_and_suffix_whitespace(self):
+ """Test with string containing both leading and trailing whitespace."""
+ prefix, suffix = get_whitespace_around(" hello ")
+ assert prefix == " "
+ assert suffix == " "
+
+ def test_mixed_whitespace_types(self):
+ """Test with mixed whitespace types (spaces, tabs, newlines)."""
+ prefix, suffix = get_whitespace_around(" \t\nhello\n\t ")
+ assert prefix == " \t\n"
+ assert suffix == "\n\t "
+
+ def test_all_whitespace(self):
+ """Test with string containing only whitespace."""
+ prefix, suffix = get_whitespace_around(" ")
+ assert prefix == " "
+ assert suffix == ""
+
+ def test_single_character(self):
+ """Test with single non-whitespace character."""
+ prefix, suffix = get_whitespace_around("x")
+ assert prefix == ""
+ assert suffix == ""
+
+ def test_single_whitespace_character(self):
+ """Test with single whitespace character."""
+ prefix, suffix = get_whitespace_around(" ")
+ assert prefix == " "
+ assert suffix == ""
+
+
+class TestBuildUpdatedDotenv:
+ """Test cases for build_updated_dotenv function."""
+
+ def test_empty_updates(self):
+ """Test with no updates to make."""
+ result = build_updated_dotenv("KEY=value", {})
+ assert result is None
+
+ def test_add_to_empty_dotenv(self):
+ """Test adding secrets to empty dotenv content."""
+ result = build_updated_dotenv("", {"NEW_KEY": "new_value"})
+ assert result == 'NEW_KEY="new_value"\n'
+
+ def test_add_multiple_to_empty_dotenv(self):
+ """Test adding multiple secrets to empty dotenv content."""
+ result = build_updated_dotenv("", {"KEY1": "value1", "KEY2": "value2"})
+ expected_lines = result.strip().split("\n")
+ assert len(expected_lines) == 3 # Two keys plus one empty line
+ assert 'KEY1="value1"' in expected_lines
+ assert 'KEY2="value2"' in expected_lines
+
+ def test_update_existing_key(self):
+ """Test updating an existing key."""
+ dotenv_content = 'EXISTING_KEY="old_value"\n'
+ result = build_updated_dotenv(dotenv_content, {"EXISTING_KEY": "new_value"})
+ assert 'EXISTING_KEY="new_value"' in result
+
+ def test_add_new_key_to_existing_dotenv(self):
+ """Test adding a new key to existing dotenv content."""
+ dotenv_content = 'EXISTING_KEY="existing_value"\n'
+ result = build_updated_dotenv(dotenv_content, {"NEW_KEY": "new_value"})
+ assert 'EXISTING_KEY="existing_value"' in result
+ assert 'NEW_KEY="new_value"' in result
+
+ def test_remove_existing_key(self):
+ """Test removing an existing key."""
+ dotenv_content = 'KEY_TO_REMOVE="value"\nKEY_TO_KEEP="value"\n'
+ result = build_updated_dotenv(dotenv_content, {"KEY_TO_REMOVE": None})
+ assert "KEY_TO_REMOVE" not in result
+ assert 'KEY_TO_KEEP="value"' in result
+
+ def test_mixed_operations(self):
+ """Test adding, updating, and removing keys in one operation."""
+ dotenv_content = 'UPDATE_ME="old"\nREMOVE_ME="gone"\nKEEP_ME="same"\n'
+ updates = {"UPDATE_ME": "new", "REMOVE_ME": None, "ADD_ME": "added"}
+ result = build_updated_dotenv(dotenv_content, updates)
+
+ assert 'UPDATE_ME="new"' in result
+ assert "REMOVE_ME" not in result
+ assert 'KEEP_ME="same"' in result
+ assert 'ADD_ME="added"' in result
+
+ def test_preserve_comments_and_empty_lines(self):
+ """Test that comments and empty lines are preserved."""
+ dotenv_content = '# This is a comment\nKEY="value"\n\n# Another comment\n'
+ result = build_updated_dotenv(dotenv_content, {"NEW_KEY": "new_value"})
+
+ assert "# This is a comment" in result
+ assert "# Another comment" in result
+ assert 'KEY="value"' in result
+ assert 'NEW_KEY="new_value"' in result
+
+ def test_delete_last_secret(self):
+ dotenv_content = "KEY='value'"
+ result = build_updated_dotenv(dotenv_content, {"KEY": None})
+ assert isinstance(result, str) and result.strip() == ""
diff --git a/packages/jupyter-ai/jupyter_ai/tests/completions/test_handlers.py b/packages/jupyter-ai/jupyter_ai/tests/completions/test_handlers.py
index b3d00159e..b8b33045d 100644
--- a/packages/jupyter-ai/jupyter_ai/tests/completions/test_handlers.py
+++ b/packages/jupyter-ai/jupyter_ai/tests/completions/test_handlers.py
@@ -1,53 +1,33 @@
import json
-from types import SimpleNamespace
+
+# from types import SimpleNamespace
from typing import Union
import pytest
-from jupyter_ai.completions.handlers.default import DefaultInlineCompletionHandler
-from jupyter_ai.completions.models import (
+from jupyter_ai.completions.completion_types import (
InlineCompletionReply,
InlineCompletionRequest,
InlineCompletionStreamChunk,
)
-from jupyter_ai_magics import BaseProvider
-from langchain_community.llms import FakeListLLM
+from jupyter_ai.completions.handlers.default import DefaultInlineCompletionHandler
from pytest import fixture
from tornado.httputil import HTTPServerRequest
from tornado.web import Application
-class MockProvider(BaseProvider, FakeListLLM):
- id = "my_provider"
- name = "My Provider"
- model_id_key = "model"
- models = ["model"]
- raise_exc: bool = False
-
- def __init__(self, **kwargs):
- if "responses" not in kwargs:
- kwargs["responses"] = ["Test response"]
- super().__init__(**kwargs)
-
- async def _acall(self, *args, **kwargs):
- if self.raise_exc:
- raise Exception("Test exception")
- else:
- return super()._call(*args, **kwargs)
-
-
class MockCompletionHandler(DefaultInlineCompletionHandler):
def __init__(self, lm_provider=None, lm_provider_params=None, raise_exc=False):
self.request = HTTPServerRequest()
self.application = Application()
self.messages = []
self.tasks = []
- self.settings["jai_config_manager"] = SimpleNamespace(
- completions_lm_provider=lm_provider or MockProvider,
- completions_lm_provider_params=lm_provider_params or {"model_id": "model"},
- )
- self.settings["jai_event_loop"] = SimpleNamespace(
- create_task=lambda x: self.tasks.append(x)
- )
+ # self.settings["jai_config_manager"] = SimpleNamespace(
+ # completions_lm_provider=lm_provider or MockProvider,
+ # completions_lm_provider_params=lm_provider_params or {"model_id": "model"},
+ # )
+ # self.settings["jai_event_loop"] = SimpleNamespace(
+ # create_task=lambda x: self.tasks.append(x)
+ # )
self.settings["model_parameters"] = {}
self._llm_params = {}
self._llm = None
@@ -116,11 +96,11 @@ async def test_handle_request(inline_handler):
)
async def test_handle_request_with_spurious_fragments(response, expected_suggestion):
inline_handler = MockCompletionHandler(
- lm_provider=MockProvider,
- lm_provider_params={
- "model_id": "model",
- "responses": [response],
- },
+ # lm_provider=MockProvider,
+ # lm_provider_params={
+ # "model_id": "model",
+ # "responses": [response],
+ # },
)
dummy_request = InlineCompletionRequest(
number=1, prefix="", suffix="", mime="", stream=False
@@ -144,11 +124,11 @@ async def test_handle_request_with_spurious_fragments_stream(
response, expected_suggestion
):
inline_handler = MockCompletionHandler(
- lm_provider=MockProvider,
- lm_provider_params={
- "model_id": "model",
- "responses": [response],
- },
+ # lm_provider=MockProvider,
+ # lm_provider_params={
+ # "model_id": "model",
+ # "responses": [response],
+ # },
)
dummy_request = InlineCompletionRequest(
number=1, prefix="", suffix="", mime="", stream=True
@@ -164,11 +144,11 @@ async def test_handle_request_with_spurious_fragments_stream(
async def test_handle_stream_request():
inline_handler = MockCompletionHandler(
- lm_provider=MockProvider,
- lm_provider_params={
- "model_id": "model",
- "responses": ["test"],
- },
+ # lm_provider=MockProvider,
+ # lm_provider_params={
+ # "model_id": "model",
+ # "responses": ["test"],
+ # },
)
dummy_request = InlineCompletionRequest(
number=1, prefix="", suffix="", mime="", stream=True
@@ -198,12 +178,12 @@ async def test_handle_stream_request():
async def test_handle_request_with_error(inline_handler):
inline_handler = MockCompletionHandler(
- lm_provider=MockProvider,
- lm_provider_params={
- "model_id": "model",
- "responses": ["test"],
- "raise_exc": True,
- },
+ # lm_provider=MockProvider,
+ # lm_provider_params={
+ # "model_id": "model",
+ # "responses": ["test"],
+ # "raise_exc": True,
+ # },
)
dummy_request = InlineCompletionRequest(
number=1, prefix="", suffix="", mime="", stream=True
diff --git a/packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py b/packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py
index 66c48f789..c142fa500 100644
--- a/packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py
+++ b/packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py
@@ -11,7 +11,6 @@
KeyInUseError,
WriteConflictError,
)
-from jupyter_ai_magics.utils import get_em_providers, get_lm_providers
from pydantic import ValidationError
@@ -43,12 +42,8 @@ def config_file_with_model_fields(jp_data_dir):
def common_cm_kwargs(config_path):
"""Kwargs that are commonly used when initializing the CM."""
log = logging.getLogger()
- lm_providers = get_lm_providers()
- em_providers = get_em_providers()
return {
"log": log,
- "lm_providers": lm_providers,
- "em_providers": em_providers,
"config_path": config_path,
"allowed_providers": None,
"blocked_providers": None,
@@ -467,8 +462,6 @@ def test_config_manager_does_not_write_to_defaults(config_file_with_model_fields
config_path = config_file_with_model_fields
log = logging.getLogger()
- lm_providers = get_lm_providers()
- em_providers = get_em_providers()
defaults = {
"model_provider_id": None,
@@ -480,8 +473,6 @@ def test_config_manager_does_not_write_to_defaults(config_file_with_model_fields
cm = ConfigManager(
log=log,
- lm_providers=lm_providers,
- em_providers=em_providers,
config_path=config_path,
defaults=defaults,
)
diff --git a/packages/jupyter-ai/jupyter_ai/tests/test_extension.py b/packages/jupyter-ai/jupyter_ai/tests/test_extension.py
index 3d731fb5a..16948b663 100644
--- a/packages/jupyter-ai/jupyter_ai/tests/test_extension.py
+++ b/packages/jupyter-ai/jupyter_ai/tests/test_extension.py
@@ -1,11 +1,8 @@
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
-from unittest import mock
import pytest
from jupyter_ai.extension import AiExtension
-from jupyter_ai_magics import BaseProvider
-from langchain_core.messages import BaseMessage
pytest_plugins = ["pytest_jupyter.jupyter_server"]
@@ -57,36 +54,4 @@ def jp_server_config(jp_server_config):
@pytest.fixture
def ai_extension(jp_serverapp):
- ai = AiExtension()
- # `BaseProvider.server_settings` can be only initialized once; however, the tests
- # may run in parallel setting it with race condition; because we are not testing
- # the `BaseProvider.server_settings` here, we can just mock the setter
- settings_mock = mock.PropertyMock()
- with mock.patch.object(BaseProvider, "server_settings", settings_mock):
- yield ai
-
-
-@pytest.mark.parametrize(
- "max_history,messages_to_add,expected_size",
- [
- # for max_history = 1 we expect to see up to 2 messages (1 human and 1 AI message)
- (1, 4, 2),
- # if there is less than `max_history` messages, all should be returned
- (1, 1, 1),
- # if no limit is set, all messages should be returned
- (None, 9, 9),
- ],
-)
-@pytest.mark.skip("TODO v3: replace this with a unit test for YChatHistory")
-def test_max_chat_history(ai_extension, max_history, messages_to_add, expected_size):
- ai = ai_extension
- ai.default_max_chat_history = max_history
- ai.initialize_settings()
- for i in range(messages_to_add):
- message = BaseMessage(
- content=f"Test message {i}",
- type="test",
- )
- ai.settings["llm_chat_memory"].add_message(message)
-
- assert len(ai.settings["llm_chat_memory"].messages) == expected_size
+    return AiExtension()
diff --git a/packages/jupyter-ai/pyproject.toml b/packages/jupyter-ai/pyproject.toml
index 075b68003..57c037b71 100644
--- a/packages/jupyter-ai/pyproject.toml
+++ b/packages/jupyter-ai/pyproject.toml
@@ -36,6 +36,9 @@ dependencies = [
# NOTE: Make sure to update the corresponding dependency in
# `packages/jupyter-ai/package.json` to match the version range below
"jupyterlab-chat>=0.16.0,<0.17.0",
+ "litellm>=1.73,<2",
+ "jinja2>=3.0,<4",
+ "python_dotenv>=1,<2",
]
dynamic = ["version", "description", "authors", "urls", "keywords"]
@@ -58,7 +61,7 @@ test = [
dev = ["jupyter_ai_magics[dev]"]
-all = ["jupyter_ai_magics[all]", "pypdf", "arxiv"]
+all = ["jupyter_ai_magics[all]"]
[tool.hatch.version]
source = "nodejs"
diff --git a/packages/jupyter-ai/src/components/chat-settings.tsx b/packages/jupyter-ai/src/components/chat-settings.tsx
index 1f3337f85..ea4018d3e 100644
--- a/packages/jupyter-ai/src/components/chat-settings.tsx
+++ b/packages/jupyter-ai/src/components/chat-settings.tsx
@@ -1,109 +1,35 @@
-import React, { useEffect, useState, useMemo } from 'react';
+import React, { useEffect, useState } from 'react';
import { Box } from '@mui/system';
-import {
- Alert,
- Button,
- IconButton,
- FormControl,
- FormControlLabel,
- FormLabel,
- MenuItem,
- Radio,
- RadioGroup,
- TextField,
- Tooltip,
- CircularProgress
-} from '@mui/material';
+import { IconButton, Tooltip } from '@mui/material';
import SettingsIcon from '@mui/icons-material/Settings';
import WarningAmberIcon from '@mui/icons-material/WarningAmber';
-import { Select } from './select';
-import { AiService } from '../handler';
-import { ModelFields } from './settings/model-fields';
-import { ServerInfoState, useServerInfo } from './settings/use-server-info';
-import { ExistingApiKeys } from './settings/existing-api-keys';
import { IRenderMimeRegistry } from '@jupyterlab/rendermime';
-import { minifyUpdate } from './settings/minify';
-import { useStackingAlert } from './mui-extras/stacking-alert';
-import { RendermimeMarkdown } from './settings/rendermime-markdown';
import { IJaiCompletionProvider } from '../tokens';
-import { getProviderId, getModelLocalId } from '../utils';
+import { ModelIdInput } from './settings/model-id-input';
+// import { ModelParametersInput } from './settings/model-parameters-input';
+import { SecretsSection } from './settings/secrets-section';
type ChatSettingsProps = {
rmRegistry: IRenderMimeRegistry;
completionProvider: IJaiCompletionProvider | null;
openInlineCompleterSettings: () => void;
- // The temporary input options, should be removed when jupyterlab chat is
- // the only chat.
- inputOptions?: boolean;
};
/**
* Component that returns the settings view in the chat panel.
*/
export function ChatSettings(props: ChatSettingsProps): JSX.Element {
- // state fetched on initial render
- const server = useServerInfo();
-
- // initialize alert helper
- const alert = useStackingAlert();
- const apiKeysAlert = useStackingAlert();
-
- // user inputs
-  const [lmProvider, setLmProvider] =
-    useState<AiService.ListProvidersEntry | null>(null);
-  const [emProvider, setEmProvider] =
-    useState<AiService.ListProvidersEntry | null>(null);
-  const [clmProvider, setClmProvider] =
-    useState<AiService.ListProvidersEntry | null>(null);
- const [showLmLocalId, setShowLmLocalId] = useState(false);
- const [showEmLocalId, setShowEmLocalId] = useState(false);
- const [showClmLocalId, setShowClmLocalId] = useState(false);
-  const [chatHelpMarkdown, setChatHelpMarkdown] = useState<string | null>(null);
- const [embeddingHelpMarkdown, setEmbeddingHelpMarkdown] = useState<
- string | null
- >(null);
- const [completionHelpMarkdown, setCompletionHelpMarkdown] = useState<
- string | null
- >(null);
- const [lmLocalId, setLmLocalId] = useState('');
- const [emLocalId, setEmLocalId] = useState('');
- const [clmLocalId, setClmLocalId] = useState('');
-
- const lmGlobalId = useMemo(() => {
- if (!lmProvider) {
- return null;
- }
-
- return lmProvider.id + ':' + lmLocalId;
- }, [lmProvider, lmLocalId]);
-
- const emGlobalId = useMemo(() => {
- if (!emProvider) {
- return null;
- }
-
- return emProvider.id + ':' + emLocalId;
- }, [emProvider, emLocalId]);
-
- const clmGlobalId = useMemo(() => {
- if (!clmProvider) {
- return null;
- }
-
- return clmProvider.id + ':' + clmLocalId;
- }, [clmProvider, clmLocalId]);
-
-  const [apiKeys, setApiKeys] = useState<Record<string, string>>({});
-  const [sendWse, setSendWse] = useState(false);
-  const [lmFields, setLmFields] = useState<Record<string, any>>({});
-  const [emFields, setEmFields] = useState<Record<string, any>>({});
-  const [clmFields, setClmFields] = useState<Record<string, any>>({});
+  const [completionModel, setCompletionModel] = useState<string | null>(null);
const [isCompleterEnabled, setIsCompleterEnabled] = useState(
props.completionProvider && props.completionProvider.isEnabled()
);
+ /**
+ * Effect: Listen to JupyterLab completer settings updates on initial render
+ * and update the `isCompleterEnabled` state variable accordingly.
+ */
useEffect(() => {
const refreshCompleterState = () => {
setIsCompleterEnabled(
@@ -118,551 +44,76 @@ export function ChatSettings(props: ChatSettingsProps): JSX.Element {
};
}, [props.completionProvider]);
- // whether the form is currently saving
- const [saving, setSaving] = useState(false);
-
- /**
- * Effect: initialize inputs after fetching server info.
- */
- useEffect(() => {
- if (server.state !== ServerInfoState.Ready) {
- return;
- }
-
- setLmLocalId(server.chat.lmLocalId);
- setEmLocalId(server.chat.emLocalId);
- setClmLocalId(server.completions.lmLocalId);
- setSendWse(server.config.send_with_shift_enter);
- setChatHelpMarkdown(server.chat.lmProvider?.help ?? null);
- setEmbeddingHelpMarkdown(server.chat.emProvider?.help ?? null);
- setCompletionHelpMarkdown(server.completions.lmProvider?.help ?? null);
- if (server.chat.lmProvider?.registry) {
- setShowLmLocalId(true);
- }
- if (server.chat.emProvider?.registry) {
- setShowEmLocalId(true);
- }
- if (server.completions.lmProvider?.registry) {
- setShowClmLocalId(true);
- }
- setLmProvider(server.chat.lmProvider);
- setClmProvider(server.completions.lmProvider);
- setEmProvider(server.chat.emProvider);
- }, [server]);
-
- /**
- * Effect: re-initialize apiKeys object whenever the selected LM/EM changes.
- * Properties with a value of '' indicate necessary user input.
- */
- useEffect(() => {
- if (server.state !== ServerInfoState.Ready) {
- return;
- }
-
-    const newApiKeys: Record<string, string> = {};
- const lmAuth = lmProvider?.auth_strategy;
- const emAuth = emProvider?.auth_strategy;
- if (
- lmAuth?.type === 'env' &&
- !server.config.api_keys.includes(lmAuth.name)
- ) {
- newApiKeys[lmAuth.name] = '';
- }
- if (lmAuth?.type === 'multienv') {
- lmAuth.names.forEach(apiKey => {
- if (!server.config.api_keys.includes(apiKey)) {
- newApiKeys[apiKey] = '';
- }
- });
- }
-
- if (
- emAuth?.type === 'env' &&
- !server.config.api_keys.includes(emAuth.name)
- ) {
- newApiKeys[emAuth.name] = '';
- }
- if (emAuth?.type === 'multienv') {
- emAuth.names.forEach(apiKey => {
- if (!server.config.api_keys.includes(apiKey)) {
- newApiKeys[apiKey] = '';
- }
- });
- }
-
- setApiKeys(newApiKeys);
- }, [lmProvider, emProvider, server]);
-
- /**
- * Effect: re-initialize fields object whenever the selected LM changes.
- */
- useEffect(() => {
- if (server.state !== ServerInfoState.Ready || !lmGlobalId) {
- return;
- }
-
-    const currFields: Record<string, any> =
- server.config.fields?.[lmGlobalId] ?? {};
- setLmFields(currFields);
-
- if (!emGlobalId) {
- return;
- }
-
-    const initEmbeddingModelFields: Record<string, any> =
- server.config.embeddings_fields?.[emGlobalId] ?? {};
- setEmFields(initEmbeddingModelFields);
-
- if (!clmGlobalId) {
- return;
- }
-
-    const initCompleterModelFields: Record<string, any> =
- server.config.completions_fields?.[clmGlobalId] ?? {};
- setClmFields(initCompleterModelFields);
- }, [server, lmGlobalId, emGlobalId, clmGlobalId]);
-
- const handleSave = async () => {
- // compress fields with JSON values
- if (server.state !== ServerInfoState.Ready) {
- return;
- }
-
- for (const fieldKey in lmFields) {
- const fieldVal = lmFields[fieldKey];
- if (typeof fieldVal !== 'string' || !fieldVal.trim().startsWith('{')) {
- continue;
- }
-
- try {
- const parsedFieldVal = JSON.parse(fieldVal);
- const compressedFieldVal = JSON.stringify(parsedFieldVal);
- lmFields[fieldKey] = compressedFieldVal;
- } catch (e) {
- continue;
- }
- }
-
- for (const fieldKey in emFields) {
- const fieldVal = emFields[fieldKey];
- if (typeof fieldVal !== 'string' || !fieldVal.trim().startsWith('{')) {
- continue;
- }
-
- try {
- const parsedFieldVal = JSON.parse(fieldVal);
- const compressedFieldVal = JSON.stringify(parsedFieldVal);
- emFields[fieldKey] = compressedFieldVal;
- } catch (e) {
- continue;
- }
- }
-
- for (const fieldKey in clmFields) {
- const fieldVal = clmFields[fieldKey];
- if (typeof fieldVal !== 'string' || !fieldVal.trim().startsWith('{')) {
- continue;
- }
-
- try {
- const parsedFieldVal = JSON.parse(fieldVal);
- const compressedFieldVal = JSON.stringify(parsedFieldVal);
- clmFields[fieldKey] = compressedFieldVal;
- } catch (e) {
- continue;
- }
- }
-
- let updateRequest: AiService.UpdateConfigRequest = {
- model_provider_id: lmGlobalId,
- embeddings_provider_id: emGlobalId,
- api_keys: apiKeys,
- fields: lmGlobalId ? { [lmGlobalId]: lmFields } : {},
- completions_fields: clmGlobalId ? { [clmGlobalId]: clmFields } : {},
- embeddings_fields: emGlobalId ? { [emGlobalId]: emFields } : {},
- completions_model_provider_id: clmGlobalId,
- send_with_shift_enter: sendWse
- };
- updateRequest = minifyUpdate(server.config, updateRequest);
- updateRequest.last_read = server.config.last_read;
-
- setSaving(true);
- try {
- await apiKeysAlert.clear();
- await AiService.updateConfig(updateRequest);
- } catch (e) {
- console.error(e);
- const msg =
- e instanceof Error || typeof e === 'string'
- ? e.toString()
- : 'An unknown error occurred. Check the console for more details.';
- alert.show('error', msg);
- return;
- } finally {
- setSaving(false);
- }
- await server.refetchAll();
- alert.show('success', 'Settings saved successfully.');
- };
-
- if (server.state === ServerInfoState.Loading) {
-    return (
-      <Box sx={{ display: 'flex', justifyContent: 'center', padding: 4 }}>
-        <CircularProgress />
-      </Box>
-    );
- );
- }
-
- if (server.state === ServerInfoState.Error) {
- return (
-
-
- {server.error ||
- 'An unknown error occurred. Check the console for more details.'}
-
-
- );
- }
-
return (
- {/* Chat language model section */}
-
+        <p>
+          This section shows the secrets set in the <code>.env</code> file at the
+          workspace root. For most chat models, an API key secret in{' '}
+          <code>.env</code> is required for Jupyternaut to reply in the chat. See
+          the{' '}
+          <a href="https://jupyter-ai.readthedocs.io/en/latest/">
+            documentation
+          </a>{' '}
+          for information on which API key is required for your model provider.
+        </p>
+        <p>
+          Click "Add secret" to add a secret to the <code>.env</code> file.
+          Secrets can also be updated by editing the <code>.env</code> file
+          directly in JupyterLab.
+        </p>
+        <p>
+          The secrets below are set by the environment variables and the
+          traitlets configuration passed to the server process. These secrets
+          can only be changed either upon restarting the server or by
+          contacting your server administrator.
+        </p>