diff --git a/.gitignore b/.gitignore
index d156b06d5..d90ad5b40 100644
--- a/.gitignore
+++ b/.gitignore
@@ -142,3 +142,6 @@ packages/**/_version.py
 # Ignore local chat files & local .jupyter/ dir
 *.chat
 .jupyter/
+
+# Ignore secrets in '.env'
+.env
diff --git a/docs/source/users/index.md b/docs/source/users/index.md
index 1ef341e22..750960332 100644
--- a/docs/source/users/index.md
+++ b/docs/source/users/index.md
@@ -762,7 +762,7 @@ We currently support the following language model providers:
 To configure a default model you can use the IPython `%config` magic:
 
 ```python
-%config AiMagics.default_language_model = "anthropic:claude-v1.2"
+%config AiMagics.initial_language_model = "anthropic:claude-v1.2"
 ```
 
 Then subsequent magics can be invoked without typing in the model:
@@ -772,10 +772,10 @@ Then subsequent magics can be invoked without typing in the model:
 Write a poem about C++.
 ```
 
-You can configure the default model for all notebooks by specifying `c.AiMagics.default_language_model` tratilet in `ipython_config.py`, for example:
+You can configure the default model for all notebooks by specifying the `c.AiMagics.initial_language_model` traitlet in `ipython_config.py`, for example:
 
 ```python
-c.AiMagics.default_language_model = "anthropic:claude-v1.2"
+c.AiMagics.initial_language_model = "anthropic:claude-v1.2"
 ```
 
 The location of `ipython_config.py` file is documented in [IPython configuration reference](https://ipython.readthedocs.io/en/stable/config/intro.html).
@@ -965,18 +965,18 @@ produced the following Python error:
 Write a new version of this code that does not produce that error.
 ```
 
-As a shortcut for explaining errors, you can use the `%ai error` command, which will explain the most recent error using the model of your choice.
+As a shortcut for explaining and fixing errors, you can use the `%ai fix` command, which will explain and propose a fix for the most recent error using the model of your choice.
 
 ```
-%ai error anthropic:claude-v1.2
+%ai fix anthropic:claude-v1.2
 ```
 
 ### Creating and managing aliases
 
-You can create an alias for a model using the `%ai register` command. For example, the command:
+You can create an alias for a model using the `%ai alias` command. For example, the command:
 
 ```
-%ai register claude anthropic:claude-v1.2
+%ai alias claude anthropic:claude-v1.2
 ```
 
 will register the alias `claude` as pointing to the `anthropic` provider's `claude-v1.2` model. You can then use this alias as you would use any other model name:
@@ -1001,10 +1001,10 @@ prompt = PromptTemplate(
 chain = LLMChain(llm=llm, prompt=prompt)
 ```
 
-… and then use `%ai register` to give it a name:
+… and then use `%ai alias` to give it a name:
 
 ```
-%ai register companyname chain
+%ai alias companyname chain
 ```
 
 You can change an alias's target using the `%ai update` command:
 
 ```
 %ai update claude anthropic:claude-instant-v1.0
 ```
 
-You can delete an alias using the `%ai delete` command:
+You can delete an alias using the `%ai dealias` command:
 
 ```
-%ai delete claude
+%ai dealias claude
 ```
 
 You can see a list of all aliases by running the `%ai list` command.
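For quick reference, the renamed magics in the documentation hunks above (`default_language_model` → `initial_language_model`, `%ai error` → `%ai fix`, `%ai register`/`%ai delete` → `%ai alias`/`%ai dealias`) compose into a new notebook workflow. A minimal sketch of a session under the new names, reusing the `anthropic:claude-v1.2` model ID from the docs above (the model ID is illustrative; any supported model works):

```python
# Minimal sketch of the renamed magics workflow (model ID illustrative).
%config AiMagics.initial_language_model = "anthropic:claude-v1.2"  # was: default_language_model
%ai alias claude anthropic:claude-v1.2   # was: %ai register
%ai fix claude                           # was: %ai error
%ai dealias claude                       # was: %ai delete
```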
@@ -1103,7 +1103,7 @@ the selections they make in the settings panel will take precedence over these v Specify default language model ```bash -jupyter lab --AiExtension.default_language_model=bedrock-chat:anthropic.claude-v2 +jupyter lab --AiExtension.initial_language_model=bedrock-chat:anthropic.claude-v2 ``` Specify default embedding model diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py b/packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py index b535d4a88..9d0529e0b 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py @@ -1,89 +1,20 @@ +from __future__ import annotations + from typing import TYPE_CHECKING -from ._import_utils import import_attr as _import_attr from ._version import __version__ if TYPE_CHECKING: - # same as dynamic imports but understood by mypy - from .embedding_providers import ( - BaseEmbeddingsProvider, - GPT4AllEmbeddingsProvider, - HfHubEmbeddingsProvider, - QianfanEmbeddingsEndpointProvider, - ) - from .exception import store_exception - from .magics import AiMagics - from .providers import ( - AI21Provider, - BaseProvider, - GPT4AllProvider, - HfHubProvider, - QianfanProvider, - TogetherAIProvider, - ) -else: - _exports_by_module = { - "embedding_providers": [ - "BaseEmbeddingsProvider", - "GPT4AllEmbeddingsProvider", - "HfHubEmbeddingsProvider", - "QianfanEmbeddingsEndpointProvider", - ], - "exception": ["store_exception"], - "magics": ["AiMagics"], - # expose model providers on the package root - "providers": [ - "AI21Provider", - "BaseProvider", - "GPT4AllProvider", - "HfHubProvider", - "QianfanProvider", - "TogetherAIProvider", - ], - } + from IPython.core.interactiveshell import InteractiveShell - _modules_by_export = { - import_name: module - for module, imports in _exports_by_module.items() - for import_name in imports - } - - def __getattr__(export_name: str) -> object: - module_name = _modules_by_export.get(export_name) - result = _import_attr(export_name, module_name, __spec__.parent) - globals()[export_name] = result - return result +def load_ipython_extension(ipython: InteractiveShell): + from .exception import store_exception + from .magics import AiMagics -def load_ipython_extension(ipython): - ipython.register_magics(__getattr__("AiMagics")) - ipython.set_custom_exc((BaseException,), __getattr__("store_exception")) + ipython.register_magics(AiMagics) + ipython.set_custom_exc((BaseException,), store_exception) -def unload_ipython_extension(ipython): +def unload_ipython_extension(ipython: InteractiveShell): ipython.set_custom_exc((BaseException,), ipython.CustomTB) - - -# required to preserve backward compatibility with `from jupyter_ai_magics import *` -__all__ = [ - "__version__", - "load_ipython_extension", - "unload_ipython_extension", - "BaseEmbeddingsProvider", - "GPT4AllEmbeddingsProvider", - "HfHubEmbeddingsProvider", - "QianfanEmbeddingsEndpointProvider", - "store_exception", - "AiMagics", - "AI21Provider", - "BaseProvider", - "GPT4AllProvider", - "HfHubProvider", - "QianfanProvider", - "TogetherAIProvider", -] - - -def __dir__(): - # Allows more editors (e.g. 
IPython) to complete on `jupyter_ai_magics.` - return list(__all__) diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/_import_utils.py b/packages/jupyter-ai-magics/jupyter_ai_magics/_import_utils.py deleted file mode 100644 index f251e452c..000000000 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/_import_utils.py +++ /dev/null @@ -1,58 +0,0 @@ -""" -MIT License - -Copyright (c) LangChain, Inc. - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. -""" - -from importlib import import_module -from typing import Union - - -def import_attr( - attr_name: str, - module_name: Union[str, None], - package: Union[str, None], -) -> object: - """Import an attribute from a module located in a package. - - This utility function is used in custom __getattr__ methods within __init__.py - files to dynamically import attributes. - - Args: - attr_name: The name of the attribute to import. - module_name: The name of the module to import from. If None, the attribute - is imported from the package itself. - package: The name of the package where the module is located. 
- """ - if module_name == "__module__" or module_name is None: - try: - result = import_module(f".{attr_name}", package=package) - except ModuleNotFoundError: - msg = f"module '{package!r}' has no attribute {attr_name!r}" - raise AttributeError(msg) from None - else: - try: - module = import_module(f".{module_name}", package=package) - except ModuleNotFoundError: - msg = f"module '{package!r}.{module_name!r}' not found" - raise ImportError(msg) from None - result = getattr(module, attr_name) - return result diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/aliases.py b/packages/jupyter-ai-magics/jupyter_ai_magics/aliases.py deleted file mode 100644 index e34a31aa0..000000000 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/aliases.py +++ /dev/null @@ -1,10 +0,0 @@ -MODEL_ID_ALIASES = { - "gpt2": "huggingface_hub:gpt2", - "gpt3": "openai:davinci-002", - "chatgpt": "openai-chat:gpt-3.5-turbo", - "gpt4": "openai-chat:gpt-4", - "ernie-bot": "qianfan:ERNIE-Bot", - "ernie-bot-4": "qianfan:ERNIE-Bot-4", - "titan": "bedrock:amazon.titan-tg1-large", - "openrouter-claude": "openrouter:anthropic/claude-3.5-sonnet:beta", -} diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/base_provider.py b/packages/jupyter-ai-magics/jupyter_ai_magics/base_provider.py deleted file mode 100644 index 8884aa15d..000000000 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/base_provider.py +++ /dev/null @@ -1,479 +0,0 @@ -import asyncio -import functools -from collections.abc import AsyncIterator, Coroutine -from concurrent.futures import ThreadPoolExecutor -from types import MappingProxyType -from typing import ( - Any, - ClassVar, - Literal, - Optional, - Union, -) - -from langchain.prompts import ( - ChatPromptTemplate, - HumanMessagePromptTemplate, - MessagesPlaceholder, - PromptTemplate, - SystemMessagePromptTemplate, -) -from langchain.schema import LLMResult -from langchain.schema.output_parser import StrOutputParser -from langchain.schema.runnable import Runnable -from langchain_core.language_models.chat_models import BaseChatModel -from langchain_core.language_models.llms import BaseLLM -from pydantic import BaseModel, ConfigDict - -from . import completion_utils as completion -from .models.completion import ( - InlineCompletionList, - InlineCompletionReply, - InlineCompletionRequest, - InlineCompletionStreamChunk, -) - -CHAT_SYSTEM_PROMPT = """ -You are Jupyternaut, a conversational assistant living in JupyterLab to help users. -You are not a language model, but rather an application built on a foundation model from {provider_name} called {local_model_id}. -You are talkative and you provide lots of specific details from the foundation model's context. -You may use Markdown to format your response. -If your response includes code, they must be enclosed in Markdown fenced code blocks (with triple backticks before and after). -If your response includes mathematical notation, they must be expressed in LaTeX markup and enclosed in LaTeX delimiters. -All dollar quantities (of USD) must be formatted in LaTeX, with the `$` symbol escaped by a single backslash `\\`. -- Example prompt: `If I have \\\\$100 and spend \\\\$20, how much money do I have left?` -- **Correct** response: `You have \\(\\$80\\) remaining.` -- **Incorrect** response: `You have $80 remaining.` -If you do not know the answer to a question, answer truthfully by responding that you do not know. -The following is a friendly conversation between you and a human. 
-""".strip() - -CHAT_DEFAULT_TEMPLATE = """ -{% if context %} -Context: -{{context}} - -{% endif %} -Current conversation: -{{history}} -Human: {{input}} -AI:""" - -HUMAN_MESSAGE_TEMPLATE = """ -{% if context %} -Context: -{{context}} - -{% endif %} -{{input}} -""" - -COMPLETION_SYSTEM_PROMPT = """ -You are an application built to provide helpful code completion suggestions. -You should only produce code. Keep comments to minimum, use the -programming language comment syntax. Produce clean code. -The code is written in JupyterLab, a data analysis and code development -environment which can execute code extended with additional syntax for -interactive features, such as magics. -""".strip() - -# only add the suffix bit if present to save input tokens/computation time -COMPLETION_DEFAULT_TEMPLATE = """ -The document is called `{{filename}}` and written in {{language}}. -{% if suffix %} -The code after the completion request is: - -``` -{{suffix}} -``` -{% endif %} - -Complete the following code: - -``` -{{prefix}}""" - - -class EnvAuthStrategy(BaseModel): - """ - Describes a provider that uses a single authentication token, which is - passed either as an environment variable or as a keyword argument. - """ - - type: Literal["env"] = "env" - - name: str - """The name of the environment variable, e.g. `'ANTHROPIC_API_KEY'`.""" - - keyword_param: Optional[str] = None - """ - If unset (default), the authentication token is provided as a keyword - argument with the parameter equal to the environment variable name in - lowercase. If set to some string `k`, the authentication token will be - passed using the keyword parameter `k`. - """ - - -class MultiEnvAuthStrategy(BaseModel): - """Require multiple auth tokens via multiple environment variables.""" - - type: Literal["multienv"] = "multienv" - names: list[str] - - -class AwsAuthStrategy(BaseModel): - """Require AWS authentication via Boto3""" - - type: Literal["aws"] = "aws" - - -AuthStrategy = Optional[ - Union[ - EnvAuthStrategy, - MultiEnvAuthStrategy, - AwsAuthStrategy, - ] -] - - -class Field(BaseModel): - key: str - label: str - # "text" accepts any text - format: Literal["json", "jsonpath", "text"] - - -class TextField(Field): - type: Literal["text"] = "text" - - -class MultilineTextField(Field): - type: Literal["text-multiline"] = "text-multiline" - - -class IntegerField(BaseModel): - type: Literal["integer"] = "integer" - key: str - label: str - - -Field = Union[TextField, MultilineTextField, IntegerField] - - -class BaseProvider(BaseModel): - # pydantic v2 model config - # upstream docs: https://docs.pydantic.dev/latest/api/config/#pydantic.config.ConfigDict.extra - model_config = ConfigDict(extra="allow") - - # - # class attrs - # - id: ClassVar[str] = ... - """ID for this provider class.""" - - name: ClassVar[str] = ... - """User-facing name of this provider.""" - - models: ClassVar[list[str]] = ... - """List of supported models by their IDs. For registry providers, this will - be just ["*"].""" - - help: ClassVar[Optional[str]] = None - """Text to display in lieu of a model list for a registry provider that does - not provide a list of models.""" - - model_id_key: ClassVar[Optional[str]] = None - """ - Optional field which specifies the key under which `model_id` is passed to - the parent LangChain class. - - If unset, this defaults to "model_id". - """ - - model_id_label: ClassVar[Optional[str]] = None - """ - Optional field which sets the label shown in the UI allowing users to - select/type a model ID. 
- - If unset, the label shown in the UI defaults to "Model ID". - """ - - pypi_package_deps: ClassVar[list[str]] = [] - """List of PyPi package dependencies.""" - - auth_strategy: ClassVar[AuthStrategy] = None - """Authentication/authorization strategy. Declares what credentials are - required to use this model provider. Generally should not be `None`.""" - - registry: ClassVar[bool] = False - """Whether this provider is a registry provider.""" - - fields: ClassVar[list[Field]] = [] - """User inputs expected by this provider when initializing it. Each `Field` `f` - should be passed in the constructor as a keyword argument, keyed by `f.key`.""" - - manages_history: ClassVar[bool] = False - """Whether this provider manages its own conversation history upstream. If - set to `True`, Jupyter AI will not pass the chat history to this provider - when invoked.""" - - unsupported_slash_commands: ClassVar[set] = set() - """ - A set of slash commands unsupported by this provider. Unsupported slash - commands are not shown in the help message, and cannot be used while this - provider is selected. - """ - - server_settings: ClassVar[Optional[MappingProxyType]] = None - """Settings passed on from jupyter-ai package. - - The same server settings are shared between all providers. - Providers are not allowed to mutate this dictionary. - """ - - @classmethod - def chat_models(self): - """Models which are suitable for chat.""" - return self.models - - @classmethod - def completion_models(self): - """Models which are suitable for completions.""" - return self.models - - # - # instance attrs - # - model_id: str - prompt_templates: dict[str, PromptTemplate] - """Prompt templates for each output type. Can be overridden with - `update_prompt_template`. The function `prompt_template`, in the base class, - refers to this.""" - - def __init__(self, *args, **kwargs): - try: - assert kwargs["model_id"] - except: - raise AssertionError( - "model_id was not specified. Please specify it as a keyword argument." - ) - - model_kwargs = {} - if self.__class__.model_id_key != "model_id": - model_kwargs[self.__class__.model_id_key] = kwargs["model_id"] - - model_kwargs["prompt_templates"] = { - "code": PromptTemplate.from_template( - "{prompt}\n\nProduce output as source code only, " - "with no text or explanation before or after it." - ), - "html": PromptTemplate.from_template( - "{prompt}\n\nProduce output in HTML format only, " - "with no markup before or afterward." - ), - "image": PromptTemplate.from_template( - "{prompt}\n\nProduce output as an image only, " - "with no text before or after it." - ), - "markdown": PromptTemplate.from_template( - "{prompt}\n\nProduce output in markdown format only." - ), - "md": PromptTemplate.from_template( - "{prompt}\n\nProduce output in markdown format only." - ), - "math": PromptTemplate.from_template( - "{prompt}\n\nProduce output in LaTeX format only, " - "with $$ at the beginning and end." - ), - "json": PromptTemplate.from_template( - "{prompt}\n\nProduce output in JSON format only, " - "with nothing before or after it." - ), - "text": PromptTemplate.from_template("{prompt}"), # No customization - } - super().__init__(*args, **kwargs, **model_kwargs) - - async def _call_in_executor(self, *args, **kwargs) -> Coroutine[Any, Any, str]: - """ - Calls self._call() asynchronously in a separate thread for providers - without an async implementation. Requires the event loop to be running. 
- """ - executor = ThreadPoolExecutor(max_workers=1) - loop = asyncio.get_running_loop() - _call_with_args = functools.partial(self._call, *args, **kwargs) - return await loop.run_in_executor(executor, _call_with_args) - - async def _generate_in_executor( - self, *args, **kwargs - ) -> Coroutine[Any, Any, LLMResult]: - """ - Calls self._generate() asynchronously in a separate thread for providers - without an async implementation. Requires the event loop to be running. - """ - executor = ThreadPoolExecutor(max_workers=1) - loop = asyncio.get_running_loop() - _call_with_args = functools.partial(self._generate, *args, **kwargs) - return await loop.run_in_executor(executor, _call_with_args) - - @classmethod - def is_api_key_exc(cls, _: Exception): - """ - Determine if the exception is an API key error. Can be implemented by subclasses. - """ - return False - - def update_prompt_template(self, format: str, template: str): - """ - Changes the class-level prompt template for a given format. - """ - self.prompt_templates[format] = PromptTemplate.from_template(template) - - def get_prompt_template(self, format) -> PromptTemplate: - """ - Produce a prompt template suitable for use with a particular model, to - produce output in a desired format. - """ - - if format in self.prompt_templates: - return self.prompt_templates[format] - else: - return self.prompt_templates["text"] # Default to plain format - - def get_chat_prompt_template(self) -> PromptTemplate: - """ - Produce a prompt template optimised for chat conversation. - The template should take two variables: history and input. - """ - name = self.__class__.name - if self.is_chat_provider: - return ChatPromptTemplate.from_messages( - [ - SystemMessagePromptTemplate.from_template( - CHAT_SYSTEM_PROMPT - ).format(provider_name=name, local_model_id=self.model_id), - MessagesPlaceholder(variable_name="history"), - HumanMessagePromptTemplate.from_template( - HUMAN_MESSAGE_TEMPLATE, - template_format="jinja2", - ), - ] - ) - else: - return PromptTemplate( - input_variables=["history", "input", "context"], - template=CHAT_SYSTEM_PROMPT.format( - provider_name=name, local_model_id=self.model_id - ) - + "\n\n" - + CHAT_DEFAULT_TEMPLATE, - template_format="jinja2", - ) - - def get_completion_prompt_template(self) -> PromptTemplate: - """ - Produce a prompt template optimised for inline code or text completion. - The template should take variables: prefix, suffix, language, filename. 
- """ - if self.is_chat_provider: - return ChatPromptTemplate.from_messages( - [ - SystemMessagePromptTemplate.from_template(COMPLETION_SYSTEM_PROMPT), - HumanMessagePromptTemplate.from_template( - COMPLETION_DEFAULT_TEMPLATE, template_format="jinja2" - ), - ] - ) - else: - return PromptTemplate( - input_variables=["prefix", "suffix", "language", "filename"], - template=COMPLETION_SYSTEM_PROMPT - + "\n\n" - + COMPLETION_DEFAULT_TEMPLATE, - template_format="jinja2", - ) - - @property - def is_chat_provider(self): - return isinstance(self, BaseChatModel) - - @property - def allows_concurrency(self): - return True - - @property - def _supports_sync_streaming(self): - if self.is_chat_provider: - return not (self.__class__._stream is BaseChatModel._stream) - else: - return not (self.__class__._stream is BaseLLM._stream) - - @property - def _supports_async_streaming(self): - if self.is_chat_provider: - return not (self.__class__._astream is BaseChatModel._astream) - else: - return not (self.__class__._astream is BaseLLM._astream) - - @property - def supports_streaming(self): - return self._supports_sync_streaming or self._supports_async_streaming - - async def generate_inline_completions( - self, request: InlineCompletionRequest - ) -> InlineCompletionReply: - chain = self._create_completion_chain() - model_arguments = completion.template_inputs_from_request(request) - suggestion = await chain.ainvoke(input=model_arguments) - suggestion = completion.post_process_suggestion(suggestion, request) - return InlineCompletionReply( - list=InlineCompletionList(items=[{"insertText": suggestion}]), - reply_to=request.number, - ) - - async def stream_inline_completions( - self, request: InlineCompletionRequest - ) -> AsyncIterator[InlineCompletionStreamChunk]: - chain = self._create_completion_chain() - token = completion.token_from_request(request, 0) - model_arguments = completion.template_inputs_from_request(request) - suggestion = processed_suggestion = "" - - # send an incomplete `InlineCompletionReply`, indicating to the - # client that LLM output is about to streamed across this connection. 
- yield InlineCompletionReply( - list=InlineCompletionList( - items=[ - { - # insert text starts empty as we do not pre-generate any part - "insertText": "", - "isIncomplete": True, - "token": token, - } - ] - ), - reply_to=request.number, - ) - - async for fragment in chain.astream(input=model_arguments): - suggestion += fragment - processed_suggestion = completion.post_process_suggestion( - suggestion, request - ) - yield InlineCompletionStreamChunk( - type="stream", - response={"insertText": processed_suggestion, "token": token}, - reply_to=request.number, - done=False, - ) - - # finally, send a message confirming that we are done - yield InlineCompletionStreamChunk( - type="stream", - response={"insertText": processed_suggestion, "token": token}, - reply_to=request.number, - done=True, - ) - - def _create_completion_chain(self) -> Runnable: - prompt_template = self.get_completion_prompt_template() - return prompt_template | self | StrOutputParser() diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/embedding_providers.py b/packages/jupyter-ai-magics/jupyter_ai_magics/embedding_providers.py deleted file mode 100644 index 38053635f..000000000 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/embedding_providers.py +++ /dev/null @@ -1,119 +0,0 @@ -from typing import ClassVar, Optional - -from jupyter_ai_magics.base_provider import ( - AuthStrategy, - EnvAuthStrategy, - Field, - MultiEnvAuthStrategy, -) -from langchain_community.embeddings import ( - GPT4AllEmbeddings, - HuggingFaceHubEmbeddings, - QianfanEmbeddingsEndpoint, -) -from pydantic import BaseModel, ConfigDict - - -class BaseEmbeddingsProvider(BaseModel): - """Base class for embedding providers""" - - # pydantic v2 model config - # upstream docs: https://docs.pydantic.dev/latest/api/config/#pydantic.config.ConfigDict.extra - model_config = ConfigDict(extra="allow") - - id: ClassVar[str] = ... - """ID for this provider class.""" - - name: ClassVar[str] = ... - """User-facing name of this provider.""" - - models: ClassVar[list[str]] = ... - """List of supported models by their IDs. For registry providers, this will - be just ["*"].""" - - help: ClassVar[Optional[str]] = None - """Text to display in lieu of a model list for a registry provider that does - not provide a list of models.""" - - model_id_key: ClassVar[str] = ... - """Kwarg expected by the upstream LangChain provider.""" - - pypi_package_deps: ClassVar[list[str]] = [] - """List of PyPi package dependencies.""" - - auth_strategy: ClassVar[AuthStrategy] = None - """Authentication/authorization strategy. Declares what credentials are - required to use this model provider. Generally should not be `None`.""" - - model_id: str - - registry: ClassVar[bool] = False - """Whether this provider is a registry provider.""" - - fields: ClassVar[list[Field]] = [] - """Fields expected by this provider in its constructor. Each `Field` `f` - should be passed as a keyword argument, keyed by `f.key`.""" - - def __init__(self, *args, **kwargs): - try: - assert kwargs["model_id"] - except: - raise AssertionError( - "model_id was not specified. Please specify it as a keyword argument." 
- ) - - model_kwargs = {} - if self.__class__.model_id_key != "model_id": - model_kwargs[self.__class__.model_id_key] = kwargs["model_id"] - - super().__init__(*args, **kwargs, **model_kwargs) - - -class HfHubEmbeddingsProvider(BaseEmbeddingsProvider, HuggingFaceHubEmbeddings): - id = "huggingface_hub" - name = "Hugging Face Hub" - models = ["*"] - model_id_key = "repo_id" - help = ( - "See [https://huggingface.co/docs/chat-ui/en/configuration/embeddings](https://huggingface.co/docs/chat-ui/en/configuration/embeddings) for reference. " - "Pass an embedding model's name; for example, `sentence-transformers/all-MiniLM-L6-v2`." - ) - # ipywidgets needed to suppress tqdm warning - # https://stackoverflow.com/questions/67998191 - # tqdm is a dependency of huggingface_hub - pypi_package_deps = ["huggingface_hub", "ipywidgets"] - auth_strategy = EnvAuthStrategy(name="HUGGINGFACEHUB_API_TOKEN") - registry = True - - -class GPT4AllEmbeddingsProvider(BaseEmbeddingsProvider, GPT4AllEmbeddings): - def __init__(self, **kwargs): - from gpt4all import GPT4All - - model_name = kwargs.get("model_id").split(":")[-1] - - # GPT4AllEmbeddings doesn't allow any kwargs at the moment - # This will cause the class to start downloading the model - # if the model file is not present. Calling retrieve_model - # here will throw an exception if the file is not present. - GPT4All.retrieve_model(model_name=model_name, allow_download=False) - - kwargs["allow_download"] = False - super().__init__(**kwargs) - - id = "gpt4all" - name = "GPT4All Embeddings" - models = ["all-MiniLM-L6-v2-f16"] - model_id_key = "model_id" - pypi_package_deps = ["gpt4all"] - - -class QianfanEmbeddingsEndpointProvider( - BaseEmbeddingsProvider, QianfanEmbeddingsEndpoint -): - id = "qianfan" - name = "ERNIE-Bot" - models = ["ERNIE-Bot", "ERNIE-Bot-4"] - model_id_key = "model" - pypi_package_deps = ["qianfan"] - auth_strategy = MultiEnvAuthStrategy(names=["QIANFAN_AK", "QIANFAN_SK"]) diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py b/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py index 5f94c2ecf..88ba32feb 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py @@ -1,28 +1,22 @@ import base64 import json -import keyword -import os import re import sys import warnings -from typing import Optional +from typing import Any, Optional import click +import litellm import traitlets from IPython.core.magic import Magics, line_cell_magic, magics_class from IPython.display import HTML, JSON, Markdown, Math -from jupyter_ai_magics.aliases import MODEL_ID_ALIASES -from jupyter_ai_magics.utils import decompose_model_id, get_lm_providers -from langchain.chains import LLMChain -from langchain.schema import HumanMessage -from langchain_core.messages import AIMessage +from jupyter_ai.model_providers.model_list import CHAT_MODELS from ._version import __version__ -from .base_provider import BaseProvider from .parsers import ( CellArgs, DeleteArgs, - ErrorArgs, + FixArgs, HelpArgs, ListArgs, RegisterArgs, @@ -90,7 +84,7 @@ def _repr_mimebundle_(self, include=None, exclude=None): To see a list of models you can use, run `%ai list`""" -AI_COMMANDS = {"delete", "error", "help", "list", "register", "update"} +AI_COMMANDS = {"dealias", "fix", "help", "list", "alias", "update"} # Strings for listing providers and models # Avoid composing strings, to make localization easier in the future @@ -125,8 +119,10 @@ class CellMagicError(BaseException): @magics_class class 
AiMagics(Magics):
-    aliases = traitlets.Dict(
-        default_value=MODEL_ID_ALIASES,
+    # NOTE: renamed from `aliases` to `initial_aliases`.
+    # This trait only sets the "starting set" of aliases.
+    initial_aliases = traitlets.Dict(
+        default_value={},
         value_trait=traitlets.Unicode(),
         key_trait=traitlets.Unicode(),
         help="""Aliases for model identifiers.
@@ -137,7 +133,7 @@ class AiMagics(Magics):
         config=True,
     )
 
-    default_language_model = traitlets.Unicode(
+    initial_language_model = traitlets.Unicode(
         default_value=None,
         allow_none=True,
         help="""Default language model to use, as string in the format
@@ -155,17 +151,20 @@ class AiMagics(Magics):
         config=True,
     )
 
+    transcript: list[dict[str, str]]
+    """
+    The conversation history as a list of messages. Each message is a simple
+    dictionary with the following structure:
+
+    - `"role"`: `"user"`, `"assistant"`, or `"system"`
+    - `"content"`: the content of the message
+    """
+
     def __init__(self, shell):
         super().__init__(shell)
         self.transcript = []
 
-        # suppress warning when using old Anthropic provider
-        warnings.filterwarnings(
-            "ignore",
-            message="This Anthropic LLM is deprecated. Please use "
-            "`from langchain.chat_models import ChatAnthropic` instead",
-        )
-
+        # TODO: check if this is necessary
         # suppress warning about our exception handler
         warnings.filterwarnings(
             "ignore",
@@ -175,304 +174,176 @@ def __init__(self, shell):
             "show full tracebacks.",
         )
 
-        self.providers = get_lm_providers()
-
+        # TODO: use LiteLLM aliases to provide this
+        # https://docs.litellm.ai/docs/completion/model_alias
         # initialize a registry of custom model/chain names
-        self.custom_model_registry = self.aliases
-
-    def _ai_bulleted_list_models_for_provider(self, provider_id, Provider):
-        output = ""
-        if len(Provider.models) == 1 and Provider.models[0] == "*":
-            if Provider.help is None:
-                output += f"* {PROVIDER_NO_MODELS}\n"
-            else:
-                output += f"* {Provider.help}\n"
-        else:
-            for model_id in Provider.models:
-                output += f"* {provider_id}:{model_id}\n"
-        output += "\n"  # End of bulleted list
-
-        return output
+        self.aliases = self.initial_aliases.copy()
 
-    def _ai_inline_list_models_for_provider(self, provider_id, Provider):
-        output = ""
-
-    # Is the required environment variable set?
-    def _ai_env_status_for_provider_markdown(self, provider_id):
-        na_message = "Not applicable. 
| " + NA_MESSAGE - - if ( - provider_id not in self.providers - or self.providers[provider_id].auth_strategy == None - ): - return na_message # No emoji - - not_set_title = ENV_NOT_SET - set_title = ENV_SET - env_status_ok = False - - auth_strategy = self.providers[provider_id].auth_strategy - if auth_strategy.type == "env": - var_name = auth_strategy.name - env_var_display = f"`{var_name}`" - env_status_ok = var_name in os.environ - elif auth_strategy.type == "multienv": - # Check multiple environment variables - var_names = self.providers[provider_id].auth_strategy.names - formatted_names = [f"`{name}`" for name in var_names] - env_var_display = ", ".join(formatted_names) - env_status_ok = all(var_name in os.environ for var_name in var_names) - not_set_title = MULTIENV_NOT_SET - set_title = MULTIENV_SET - else: # No environment variables - return na_message - - output = f"{env_var_display} | " - if env_status_ok: - output += f'' - else: - output += f'' - - return output - - def _ai_env_status_for_provider_text(self, provider_id): - # only handle providers with "env" or "multienv" auth strategy - auth_strategy = getattr(self.providers[provider_id], "auth_strategy", None) - if not auth_strategy or ( - auth_strategy.type != "env" and auth_strategy.type != "multienv" - ): - return "" - - prefix = ENV_REQUIRES if auth_strategy.type == "env" else MULTIENV_REQUIRES - envvars = ( - [auth_strategy.name] - if auth_strategy.type == "env" - else auth_strategy.names[:] - ) - - for i in range(len(envvars)): - envvars[i] += " (set)" if envvars[i] in os.environ else " (not set)" - - return prefix + " " + ", ".join(envvars) + "\n" - - # Is this a name of a Python variable that can be called as a LangChain chain? - def _is_langchain_chain(self, name): - # Reserved word in Python? - if keyword.iskeyword(name): - return False + args = line_magic_parser( + raw_args, + prog_name=r"%ai", + standalone_mode=False, + default_map={"fix": default_map}, + ) + except Exception as e: + if "model_id" in str(e) and "string_type" in str(e): + error_msg = "No Model ID entered, please enter it in the following format: `%%ai `" + print(error_msg, file=sys.stderr) + return + if not args: + print("No valid %ai magics arguments given, run `%ai help` for all options.", file=sys.stderr) + return + raise e + + if args == 0 and self.initial_language_model is None: + # this happens when `--help` is called on the root command, in which + # case we want to exit early. + return - acceptable_name = re.compile("^[a-zA-Z0-9_]+$") - if not acceptable_name.match(name): - return False + # If a value error occurs, don't print the full stacktrace + try: + if args.type == "fix": + return self.handle_fix(args) + if args.type == "help": + return self.handle_help(args) + if args.type == "list": + return self.handle_list(args) + if args.type == "alias": + return self.handle_alias(args) + if args.type == "dealias": + return self.handle_dealias(args) + if args.type == "version": + return self.handle_version(args) + if args.type == "reset": + return self.handle_reset(args) + except ValueError as e: + print(e, file=sys.stderr) + return - ipython = self.shell - return name in ipython.user_ns and isinstance(ipython.user_ns[name], LLMChain) + # hint to the IDE that this object must be of type `CellArgs` + args: CellArgs = args - # Is this an acceptable name for an alias? - def _validate_name(self, register_name): - # A registry name contains ASCII letters, numbers, hyphens, underscores, - # and periods. 
No other characters, including a colon, are permitted - acceptable_name = re.compile("^[a-zA-Z0-9._-]+$") - if not acceptable_name.match(register_name): - raise ValueError( - "A registry name may contain ASCII letters, numbers, hyphens, underscores, " - + "and periods. No other characters, including a colon, are permitted" + if not cell: + raise CellMagicError( + """To invoke a language model, you must use the `%%ai` + cell magic. The `%ai` line magic is only for use with + subcommands.""" ) - # Initially set or update an alias to a target - def _safely_set_target(self, register_name, target): - # If target is a string, treat this as an alias to another model. - if self._is_langchain_chain(target): - ip = self.shell - self.custom_model_registry[register_name] = ip.user_ns[target] - else: - # Ensure that the destination is properly formatted - if ":" not in target: - raise ValueError( - "Target model must be an LLMChain object or a model name in PROVIDER_ID:MODEL_NAME format" - ) + prompt = cell.strip() - self.custom_model_registry[register_name] = target + return self.run_ai_cell(args, prompt) - def handle_delete(self, args: DeleteArgs): - if args.name in AI_COMMANDS: - raise ValueError( - f"Reserved command names, including {args.name}, cannot be deleted" - ) + def run_ai_cell(self, args: CellArgs, prompt: str): + """ + Handles the `%%ai` cell magic. This is the main method that invokes the + language model. + """ + # Interpolate local variables into prompt. + # For example, if a user runs `a = "hello"` and then runs `%%ai {a}`, it + # should be equivalent to running `%%ai hello`. + ip = self.shell + prompt = prompt.format_map(FormatDict(ip.user_ns)) - if args.name not in self.custom_model_registry: - raise ValueError(f"There is no alias called {args.name}") + # Prepare messages for the model + messages = [] - del self.custom_model_registry[args.name] - output = f"Deleted alias `{args.name}`" - return TextOrMarkdown(output, output) + # Add conversation history if available + if self.transcript: + messages.extend(self.transcript[-2 * self.max_history :]) - def handle_register(self, args: RegisterArgs): - # Existing command names are not allowed - if args.name in AI_COMMANDS: - raise ValueError(f"The name {args.name} is reserved for a command") + # Add current prompt + messages.append({"role": "user", "content": prompt}) - # Existing registered names are not allowed - if args.name in self.custom_model_registry: - raise ValueError( - f"The name {args.name} is already associated with a custom model; " - + "use %ai update to change its target" + # Resolve model_id: check if it's in CHAT_MODELS or an alias + model_id = args.model_id + if model_id not in CHAT_MODELS: + # Check if it's an alias + if model_id in self.aliases: + model_id = self.aliases[model_id] + else: + error_msg = f"Model ID '{model_id}' is not a known model or alias. Run '%ai list' to see available models and aliases." + print(error_msg, file=sys.stderr) # Log to stderr + return + try: + # Call litellm completion + response = litellm.completion( + model=model_id, messages=messages, stream=False ) - # Does the new name match expected format? 
- self._validate_name(args.name) + # Extract output text from response + output = response.choices[0].message.content - self._safely_set_target(args.name, args.target) - output = f"Registered new alias `{args.name}`" - return TextOrMarkdown(output, output) + # Append exchange to transcript + self._append_exchange(prompt, output) - def handle_update(self, args: UpdateArgs): - if args.name in AI_COMMANDS: - raise ValueError( - f"Reserved command names, including {args.name}, cannot be updated" - ) + # Set model ID in metadata + metadata = {"jupyter_ai_v3": {"model_id": args.model_id}} - if args.name not in self.custom_model_registry: - raise ValueError(f"There is no alias called {args.name}") + # Return output given the format + return self.display_output(output, args.format, metadata) - self._safely_set_target(args.name, args.target) - output = f"Updated target of alias `{args.name}`" - return TextOrMarkdown(output, output) + except Exception as e: + error_msg = f"Error calling language model: {str(e)}" + print(error_msg, file=sys.stderr) + return error_msg - def _ai_list_command_markdown(self, single_provider=None): - output = ( - "| Provider | Environment variable | Set? | Models |\n" - + "|----------|----------------------|------|--------|\n" - ) - if single_provider is not None and single_provider not in self.providers: - return f"There is no model provider with ID `{single_provider}`." - - for provider_id, Provider in self.providers.items(): - if single_provider is not None and provider_id != single_provider: - continue - - output += ( - f"| `{provider_id}` | " - + self._ai_env_status_for_provider_markdown(provider_id) - + " | " - + self._ai_inline_list_models_for_provider(provider_id, Provider) - + " |\n" - ) + def display_output(self, output, display_format, metadata: dict[str, Any]) -> Any: + """ + Returns an IPython 'display object' that determines how an output is + rendered. This is complex, so here are some notes: - # Also list aliases. - if single_provider is None and len(self.custom_model_registry) > 0: - output += ( - "\nAliases and custom commands:\n\n" - + "| Name | Target |\n" - + "|------|--------|\n" - ) - for key, value in self.custom_model_registry.items(): - output += f"| `{key}` | " - if isinstance(value, str): - output += f"`{value}`" - else: - output += "*custom chain*" - - output += " |\n" - - return output - - def _ai_list_command_text(self, single_provider=None): - output = "" - if single_provider is not None and single_provider not in self.providers: - return f"There is no model provider with ID '{single_provider}'." - - for provider_id, Provider in self.providers.items(): - if single_provider is not None and provider_id != single_provider: - continue - - output += ( - f"{provider_id}\n" - + self._ai_env_status_for_provider_text( - provider_id - ) # includes \n if nonblank - + self._ai_bulleted_list_models_for_provider(provider_id, Provider) - ) + - The display object returned is controlled by the `display_format` + argument. See `DISPLAYS_BY_FORMAT` for the list of valid formats. - # Also list aliases. - if single_provider is None and len(self.custom_model_registry) > 0: - output += "\nAliases and custom commands:\n" - for key, value in self.custom_model_registry.items(): - output += f"{key} - " - if isinstance(value, str): - output += value - else: - output += "custom chain" + - In most use-cases, this method returns a `TextOrMarkdown` object. 
The
+      reason this exists is that IPython may be run from a terminal shell
+      (via the `ipython` command) or from a web browser in a Jupyter Notebook.
+
+    - `TextOrMarkdown` shows text when viewed from a command line, and rendered
+      Markdown when viewed from a web browser.
+
+    - See `DISPLAYS_BY_FORMAT` for the list of display objects that can be
+      returned by `jupyter_ai_magics`.
+
+    TODO: Use a string enum to store the list of valid formats.
+
+    TODO: What is the shared type that all display objects implement? We
+    implement `_repr_mimebundle_()`, but that doesn't seem to be implemented
+    on all display objects. So the return type is `Any` for now. 
+ """ # build output display DisplayClass = DISPLAYS_BY_FORMAT[display_format] @@ -484,7 +355,7 @@ def display_output(self, output, display_format, md): output = re.sub(r"\n```$", "", output) self.shell.set_next_input(output, replace=False) return HTML( - "AI generated code inserted below ⬇️", metadata=md + "AI generated code inserted below ⬇️", metadata=metadata ) if DisplayClass is None: @@ -492,194 +363,176 @@ def display_output(self, output, display_format, md): if display_format == "json": # JSON display expects a dict, not a JSON string output = json.loads(output) - output_display = DisplayClass(output, metadata=md) + output_display = DisplayClass(output, metadata=metadata) # finally, display output display return output_display - def handle_help(self, _: HelpArgs): + def _append_exchange(self, prompt: str, output: str): + """ + Appends an exchange between a user and a language model to + `self.transcript`. This transcript will be included in future `%ai` + calls to preserve conversation history. + """ + self.transcript.append({"role": "user", "content": prompt}) + self.transcript.append({"role": "assistant", "content": output}) + # Keep only the most recent `self.max_history * 2` messages + max_len = self.max_history * 2 + if len(self.transcript) > max_len: + self.transcript = self.transcript[-max_len:] + + def handle_help(self, _: HelpArgs) -> None: + """ + Handles `%ai help`. Prints a help message via `click.echo()`. + """ # The line parser's help function prints both cell and line help - with click.Context(line_magic_parser, info_name="%ai") as ctx: + with click.Context(line_magic_parser, info_name=r"%ai") as ctx: click.echo(line_magic_parser.get_help(ctx)) - def handle_list(self, args: ListArgs): - return TextOrMarkdown( - self._ai_list_command_text(args.provider_id), - self._ai_list_command_markdown(args.provider_id), - ) - - def handle_version(self, args: VersionArgs): - return __version__ - - def handle_reset(self, args: ResetArgs): - self.transcript = [] + def handle_dealias(self, args: DeleteArgs) -> TextOrMarkdown: + """ + Handles `%ai dealias`. Deletes a model alias. + """ - def run_ai_cell(self, args: CellArgs, prompt: str): - provider_id, local_model_id = self._decompose_model_id(args.model_id) - - # If this is a custom chain, send the message to the custom chain. 
- if args.model_id in self.custom_model_registry and isinstance( - self.custom_model_registry[args.model_id], LLMChain - ): - # Get the output, either as raw text or as the contents of the 'text' key of a dict - invoke_output = self.custom_model_registry[args.model_id].invoke(prompt) - if isinstance(invoke_output, dict): - invoke_output = invoke_output.get("text") - - return self.display_output( - invoke_output, - args.format, - {"jupyter_ai": {"custom_chain_id": args.model_id}}, + if args.name in AI_COMMANDS: + raise ValueError( + f"Reserved command names, including {args.name}, cannot be deleted" ) - Provider = self._get_provider(provider_id) - if Provider is None: - return TextOrMarkdown( - CANNOT_DETERMINE_MODEL_TEXT.format(args.model_id) - + "\n\n" - + "If you were trying to run a command, run '%ai help' to see a list of commands.", - CANNOT_DETERMINE_MODEL_MARKDOWN.format(args.model_id) - + "\n\n" - + "If you were trying to run a command, run `%ai help` to see a list of commands.", - ) + if args.name not in self.aliases: + raise ValueError(f"There is no alias called {args.name}") - # validate presence of authn credentials - auth_strategy = self.providers[provider_id].auth_strategy - if auth_strategy: - if auth_strategy.type == "env" and auth_strategy.name not in os.environ: - raise OSError( - f"Authentication environment variable {auth_strategy.name} is not set.\n" - f"An authentication token is required to use models from the {Provider.name} provider.\n" - f"Please specify it via `%env {auth_strategy.name}=token`. " - ) from None - if auth_strategy.type == "multienv": - # Multiple environment variables must be set - missing_vars = [ - var for var in auth_strategy.names if var not in os.environ - ] - raise OSError( - f"Authentication environment variables {missing_vars} are not set.\n" - f"Multiple authentication tokens are required to use models from the {Provider.name} provider.\n" - f"Please specify them all via `%env` commands. " - ) from None - - # configure and instantiate provider - provider_params = {"model_id": local_model_id} - # for SageMaker, validate that required params are specified - if provider_id == "sagemaker-endpoint": - if ( - args.region_name is None - or args.request_schema is None - or args.response_path is None - ): - raise ValueError( - "When using the sagemaker-endpoint provider, you must specify all of " - + "the --region-name, --request-schema, and --response-path options." - ) - provider_params["region_name"] = args.region_name - provider_params["request_schema"] = args.request_schema - provider_params["response_path"] = args.response_path + del self.aliases[args.name] + output = f"Deleted alias `{args.name}`" + return TextOrMarkdown(output, output) - model_parameters = json.loads(args.model_parameters) + def handle_reset(self, args: ResetArgs) -> None: + """ + Handles `%ai reset`. Clears the history. + """ + self.transcript = [] - provider = Provider(**provider_params, **model_parameters) + def handle_fix(self, args: FixArgs) -> Any: + """ + Handles `%ai fix`. Meant to provide fixes for any exceptions raised in + the kernel while running cells. - # Apply a prompt template. - prompt = provider.get_prompt_template(args.format).format(prompt=prompt) + TODO: annotate a valid return type when we find a type that is shared by + all display objects. + """ + no_errors_message = "There have been no errors since the kernel started." - # interpolate user namespace into prompt + # Find the most recent error. 
ip = self.shell - prompt = prompt.format_map(FormatDict(ip.user_ns)) + if "Err" not in ip.user_ns: + return TextOrMarkdown(no_errors_message, no_errors_message) - context = self.transcript[-2 * self.max_history :] if self.max_history else [] - if provider.is_chat_provider: - result = provider.generate([[*context, HumanMessage(content=prompt)]]) - else: - # generate output from model via provider - if context: - inputs = "\n\n".join( - [ - ( - f"AI: {message.content}" - if message.type == "ai" - else f"{message.type.title()}: {message.content}" - ) - for message in context + [HumanMessage(content=prompt)] - ] - ) + err = ip.user_ns["Err"] + # Start from the previous execution count + excount = ip.execution_count - 1 + last_error = None + while excount >= 0 and last_error is None: + if excount in err: + last_error = err[excount] else: - inputs = prompt - result = provider.generate([inputs]) - - output = result.generations[0][0].text + excount = excount - 1 - # append exchange to transcript - self._append_exchange(prompt, output) + if last_error is None: + return TextOrMarkdown(no_errors_message, no_errors_message) - md = {"jupyter_ai": {"provider_id": provider_id, "model_id": local_model_id}} + prompt = f"Explain the following error and propose a fix:\n\n{last_error}" + # Set CellArgs based on FixArgs + values = args.model_dump() + values["type"] = "root" + cell_args = CellArgs(**values) + print("I will attempt to explain and fix the error. ") - return self.display_output(output, args.format, md) + return self.run_ai_cell(cell_args, prompt) - @line_cell_magic - def ai(self, line, cell=None): - raw_args = line.split(" ") - default_map = {"model_id": self.default_language_model} - if cell: - args = cell_magic_parser( - raw_args, - prog_name="%%ai", - standalone_mode=False, - default_map={"cell_magic_parser": default_map}, - ) - else: - args = line_magic_parser( - raw_args, - prog_name="%ai", - standalone_mode=False, - default_map={"error": default_map}, - ) + def handle_alias(self, args: RegisterArgs) -> TextOrMarkdown: + """ + Handles `%ai alias`. Adds an alias for a model ID for future calls. + """ + # Existing command names are not allowed + if args.name in AI_COMMANDS: + raise ValueError(f"The name {args.name} is reserved for a command") - if args == 0 and self.default_language_model is None: - # this happens when `--help` is called on the root command, in which - # case we want to exit early. - return + # Store the alias + self.aliases[args.name] = args.target - # If a value error occurs, don't print the full stacktrace - try: - if args.type == "error": - return self.handle_error(args) - if args.type == "help": - return self.handle_help(args) - if args.type == "list": - return self.handle_list(args) - if args.type == "register": - return self.handle_register(args) - if args.type == "delete": - return self.handle_delete(args) - if args.type == "update": - return self.handle_update(args) - if args.type == "version": - return self.handle_version(args) - if args.type == "reset": - return self.handle_reset(args) - except ValueError as e: - print(e, file=sys.stderr) - return + output = f"Registered new alias `{args.name}`" + return TextOrMarkdown(output, output) - # hint to the IDE that this object must be of type `RootArgs` - args: CellArgs = args + def handle_version(self, args: VersionArgs) -> str: + """ + Handles `%ai version`. Returns the current version of + `jupyter_ai_magics`. 
+ """ + return __version__ - if not cell: - raise CellMagicError( - """[0.8+]: To invoke a language model, you must use the `%%ai` - cell magic. The `%ai` line magic is only for use with - subcommands.""" - ) + def handle_list(self, args: ListArgs): + """ + Handles `%ai list`. + - `%ai list` shows all providers by default, and ask the user to run %ai list . + - `%ai list ` shows all models available from one provider. It should also note that the list is not comprehensive, and include a reference to the upstream LiteLLM docs. + - `%ai list all` should list all models. + """ + # Get list of available models from litellm + models = CHAT_MODELS + + # If provider_id is None, only return provider IDs + if getattr(args, 'provider_id', None) is None: + # Extract unique provider IDs from model IDs + provider_ids = set() + for model in models: + if '/' in model: + provider_ids.add(model.split('/')[0]) + + # Format output for both text and markdown + text_output = "Available providers\n\n (Run `%ai list ` to see models for a specific provider)\n\n" + markdown_output = "## Available providers\n\n (Run `%ai list ` to see models for a specific provider)\n\n" + + for provider_id in sorted(provider_ids): + text_output += f"* {provider_id}\n" + markdown_output += f"* `{provider_id}`\n" + + return TextOrMarkdown(text_output, markdown_output) + + elif getattr(args, 'provider_id', None) == 'all': + # Otherwise show all models and aliases + text_output = "All available models\n\n (The list is not comprehensive, a list of models is available at https://docs.litellm.ai/docs/providers)\n\n" + markdown_output = "## All available models \n\n (The list is not comprehensive, a list of models is available at https://docs.litellm.ai/docs/providers)\n\n" + + for model in models: + text_output += f"* {model}\n" + markdown_output += f"* `{model}`\n" + + # Also list any custom aliases + if len(self.aliases) > 0: + text_output += "\nAliases:\n" + markdown_output += "\n### Aliases\n\n" + for alias, target in self.aliases.items(): + text_output += f"* {alias} -> {target}\n" + markdown_output += f"* `{alias}` -> `{target}`\n" + + return TextOrMarkdown(text_output, markdown_output) + + else: + # If a specific provider_id is given, filter models by that provider + provider_id = args.provider_id + filtered_models = [m for m in models if m.startswith(provider_id + "/")] + + if not filtered_models: + return TextOrMarkdown( + f"No models found for provider '{provider_id}'.", + f"No models found for provider `{provider_id}`.", + ) - prompt = cell.strip() + text_output = f"Available models for provider '{provider_id}'\n\n (The list is not comprehensive, a list of models is available at https://docs.litellm.ai/docs/providers/{provider_id})\n\n" + markdown_output = f"## Available models for provider `{provider_id}`\n\n (The list is not comprehensive, a list of models is available at https://docs.litellm.ai/docs/providers/{provider_id})\n\n" - # interpolate user namespace into prompt - ip = self.shell - prompt = prompt.format_map(FormatDict(ip.user_ns)) + for model in filtered_models: + text_output += f"* {model}\n" + markdown_output += f"* `{model}`\n" - return self.run_ai_cell(args, prompt) + return TextOrMarkdown(text_output, markdown_output) \ No newline at end of file diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/old_help_strings.txt b/packages/jupyter-ai-magics/jupyter_ai_magics/old_help_strings.txt new file mode 100644 index 000000000..795a49a67 --- /dev/null +++ 
diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/old_help_strings.txt b/packages/jupyter-ai-magics/jupyter_ai_magics/old_help_strings.txt
new file mode 100644
index 000000000..795a49a67
--- /dev/null
+++ b/packages/jupyter-ai-magics/jupyter_ai_magics/old_help_strings.txt
@@ -0,0 +1,20 @@
+Bedrock:
+
+"- For Cross-Region Inference use the appropriate `Inference profile ID` (Model ID with a region prefix, e.g., `us.meta.llama3-2-11b-instruct-v1:0`). See the [inference profiles documentation](https://docs.aws.amazon.com/bedrock/latest/userguide/inference-profiles-support.html). \n"
+"- For custom/provisioned models, specify the model ARN (Amazon Resource Name) as the model ID. For more information, see the [Amazon Bedrock model IDs documentation](https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html).\n\n"
+"The model provider must also be specified below. This is the provider of your foundation model *in lowercase*, e.g., `amazon`, `anthropic`, `cohere`, `meta`, or `mistral`."
+
+SageMaker Endpoints:
+
+"Specify an endpoint name as the model ID. "
+"In addition, you must specify a region name, request schema, and response path. "
+"For more information, see the documentation about [SageMaker endpoints deployment](https://docs.aws.amazon.com/sagemaker/latest/dg/realtime-endpoints-deploy-models.html) "
+"and about [using magic commands with SageMaker endpoints](https://jupyter-ai.readthedocs.io/en/latest/users/index.html#using-magic-commands-with-sagemaker-endpoints)."
+
+Vertex AI:
+
+"To use Vertex AI Generative AI you must have the langchain-google-vertexai Python package installed and either:\n\n"
+"- Have credentials configured for your environment (gcloud, workload identity, etc...)\n"
+"- Store the path to a service account JSON file as the GOOGLE_APPLICATION_CREDENTIALS environment variable\n\n"
+"This codebase uses the google.auth library which first looks for the application credentials variable mentioned above, and then looks for system-level auth. "
+"For more information, see the [Vertex AI authentication documentation](https://python.langchain.com/docs/integrations/llms/google_vertex_ai_palm/)."
diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/old_prompt_templates.txt b/packages/jupyter-ai-magics/jupyter_ai_magics/old_prompt_templates.txt
new file mode 100644
index 000000000..8ed599766
--- /dev/null
+++ b/packages/jupyter-ai-magics/jupyter_ai_magics/old_prompt_templates.txt
@@ -0,0 +1,31 @@
+# for magics
+
+model_kwargs["prompt_templates"] = {
+    "code": PromptTemplate.from_template(
+        "{prompt}\n\nProduce output as source code only, "
+        "with no text or explanation before or after it."
+    ),
+    "html": PromptTemplate.from_template(
+        "{prompt}\n\nProduce output in HTML format only, "
+        "with no markup before or afterward."
+    ),
+    "image": PromptTemplate.from_template(
+        "{prompt}\n\nProduce output as an image only, "
+        "with no text before or after it."
+    ),
+    "markdown": PromptTemplate.from_template(
+        "{prompt}\n\nProduce output in markdown format only."
+    ),
+    "md": PromptTemplate.from_template(
+        "{prompt}\n\nProduce output in markdown format only."
+    ),
+    "math": PromptTemplate.from_template(
+        "{prompt}\n\nProduce output in LaTeX format only, "
+        "with $$ at the beginning and end."
+    ),
+    "json": PromptTemplate.from_template(
+        "{prompt}\n\nProduce output in JSON format only, "
+        "with nothing before or after it."
+    ),
+    "text": PromptTemplate.from_template("{prompt}"),  # No customization
+}
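For context on the archived templates above: each `--format` value selected one of these LangChain `PromptTemplate`s and interpolated the user's prompt into it. A sketch of the pre-change behavior, assuming an environment where `langchain` is still installed (this PR removes it as a dependency):

```python
# Pre-PR behavior only: `-f code` selected the "code" template and
# interpolated the user's prompt into it. Requires `langchain`.
from langchain.prompts import PromptTemplate

code_template = PromptTemplate.from_template(
    "{prompt}\n\nProduce output as source code only, "
    "with no text or explanation before or after it."
)
print(code_template.format(prompt="Write a bubble sort in Python."))
```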
+ ), + "text": PromptTemplate.from_template("{prompt}"), # No customization +} diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/parsers.py b/packages/jupyter-ai-magics/jupyter_ai_magics/parsers.py index de99fc8bd..0c0aeb15f 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/parsers.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/parsers.py @@ -54,8 +54,8 @@ class CellArgs(BaseModel): # Should match CellArgs -class ErrorArgs(BaseModel): - type: Literal["error"] = "error" +class FixArgs(BaseModel): + type: Literal["fix"] = "fix" model_id: str format: FORMAT_CHOICES_TYPE model_parameters: Optional[str] = None @@ -79,13 +79,13 @@ class ListArgs(BaseModel): class RegisterArgs(BaseModel): - type: Literal["register"] = "register" + type: Literal["alias"] = "alias" name: str target: str class DeleteArgs(BaseModel): - type: Literal["delete"] = "delete" + type: Literal["dealias"] = "dealias" name: str @@ -182,7 +182,7 @@ def line_magic_parser(): """ -@line_magic_parser.command(name="error") +@line_magic_parser.command(name="fix") @click.argument("model_id", required=False) @click.option( "-f", @@ -219,14 +219,14 @@ def line_magic_parser(): default="{}", ) @click.pass_context -def error_subparser(context: click.Context, **kwargs): +def fix_subparser(context: click.Context, **kwargs): """ - Explains the most recent error. Takes the same options (except -r) as + Explains and fixes the most recent error. Takes the same options (except -r) as the basic `%%ai` command. """ if not kwargs["model_id"] and context.default_map: - kwargs["model_id"] = context.default_map["error_subparser"]["model_id"] - return ErrorArgs(**kwargs) + kwargs["model_id"] = context.default_map["fix_subparser"]["model_id"] + return FixArgs(**kwargs) @line_magic_parser.command(name="version") @@ -248,13 +248,16 @@ def help_subparser(): ) @click.argument("provider_id", required=False) def list_subparser(**kwargs): - """List language models, optionally scoped to PROVIDER_ID.""" + """List language models, optionally scoped to PROVIDER_ID.\n\n + If no PROVIDER_ID is given, all providers are listed. + If PROVIDER_ID is given, only models from that provider are listed. + If PROVIDER_ID is 'all', models from all providers are listed.""" return ListArgs(**kwargs) @line_magic_parser.command( - name="register", - short_help="Register a new alias. See `%ai register --help` for options.", + name="alias", + short_help="Register a new alias. See `%ai alias --help` for options.", ) @click.argument("name") @click.argument("target") @@ -264,7 +267,7 @@ def register_subparser(**kwargs): @line_magic_parser.command( - name="delete", short_help="Delete an alias. See `%ai delete --help` for options." + name="dealias", short_help="Delete an alias. See `%ai dealias --help` for options." ) @click.argument("name") def register_subparser(**kwargs): @@ -272,10 +275,6 @@ def register_subparser(**kwargs): return DeleteArgs(**kwargs) -@line_magic_parser.command( - name="update", - short_help="Update the target of an alias. 
See `%ai update --help` for options.", -) @click.argument("name") @click.argument("target") def register_subparser(**kwargs): diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/anthropic.py b/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/anthropic.py deleted file mode 100644 index ad0f18c13..000000000 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/anthropic.py +++ /dev/null @@ -1,38 +0,0 @@ -from langchain_anthropic import ChatAnthropic - -from ..base_provider import BaseProvider, EnvAuthStrategy - - -class ChatAnthropicProvider( - BaseProvider, ChatAnthropic -): # https://docs.anthropic.com/en/docs/about-claude/models - id = "anthropic-chat" - name = "Anthropic" - models = [ - "claude-2.0", - "claude-2.1", - "claude-3-opus-20240229", - "claude-3-sonnet-20240229", - "claude-3-haiku-20240307", - "claude-3-5-haiku-20241022", - "claude-3-5-sonnet-20240620", - "claude-3-5-sonnet-20241022", - ] - model_id_key = "model" - pypi_package_deps = ["anthropic"] - auth_strategy = EnvAuthStrategy(name="ANTHROPIC_API_KEY") - - @property - def allows_concurrency(self): - return False - - @classmethod - def is_api_key_exc(cls, e: Exception): - """ - Determine if the exception is an Anthropic API key error. - """ - import anthropic - - if isinstance(e, anthropic.AuthenticationError): - return e.status_code == 401 and "Invalid API Key" in str(e) - return False diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/aws.py b/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/aws.py deleted file mode 100644 index cc6552365..000000000 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/aws.py +++ /dev/null @@ -1,218 +0,0 @@ -import copy -import json -from collections.abc import Coroutine -from typing import Any - -from jsonpath_ng import parse -from langchain_aws import BedrockEmbeddings, BedrockLLM, ChatBedrock, SagemakerEndpoint -from langchain_aws.llms.sagemaker_endpoint import LLMContentHandler -from langchain_core.outputs import LLMResult - -from ..base_provider import AwsAuthStrategy, BaseProvider, MultilineTextField, TextField -from ..embedding_providers import BaseEmbeddingsProvider - - -# See model ID list here: https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html -class BedrockProvider(BaseProvider, BedrockLLM): - id = "bedrock" - name = "Amazon Bedrock" - models = [ - "amazon.titan-text-express-v1", - "amazon.titan-text-lite-v1", - "amazon.titan-text-premier-v1:0", - "ai21.j2-ultra-v1", - "ai21.j2-mid-v1", - "ai21.jamba-instruct-v1:0", - "cohere.command-light-text-v14", - "cohere.command-text-v14", - "cohere.command-r-v1:0", - "cohere.command-r-plus-v1:0", - "meta.llama2-13b-chat-v1", - "meta.llama2-70b-chat-v1", - "meta.llama3-8b-instruct-v1:0", - "meta.llama3-70b-instruct-v1:0", - "meta.llama3-1-8b-instruct-v1:0", - "meta.llama3-1-70b-instruct-v1:0", - "meta.llama3-1-405b-instruct-v1:0", - "mistral.mistral-7b-instruct-v0:2", - "mistral.mixtral-8x7b-instruct-v0:1", - "mistral.mistral-large-2402-v1:0", - "mistral.mistral-large-2407-v1:0", - ] - model_id_key = "model_id" - pypi_package_deps = ["langchain-aws"] - auth_strategy = AwsAuthStrategy() - fields = [ - TextField( - key="credentials_profile_name", - label="AWS profile (optional)", - format="text", - ), - TextField(key="region_name", label="Region name (optional)", format="text"), - ] - - async def _acall(self, *args, **kwargs) -> Coroutine[Any, Any, str]: - return await self._call_in_executor(*args, **kwargs) - - -# 
See model ID list here: https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html -class BedrockChatProvider(BaseProvider, ChatBedrock): - id = "bedrock-chat" - name = "Amazon Bedrock Chat" - models = [ - "amazon.titan-text-express-v1", - "amazon.titan-text-lite-v1", - "amazon.titan-text-premier-v1:0", - "anthropic.claude-v2", - "anthropic.claude-v2:1", - "anthropic.claude-instant-v1", - "anthropic.claude-3-sonnet-20240229-v1:0", - "anthropic.claude-3-haiku-20240307-v1:0", - "anthropic.claude-3-opus-20240229-v1:0", - "anthropic.claude-3-5-haiku-20241022-v1:0", - "anthropic.claude-3-5-sonnet-20240620-v1:0", - "anthropic.claude-3-5-sonnet-20241022-v2:0", - "meta.llama2-13b-chat-v1", - "meta.llama2-70b-chat-v1", - "meta.llama3-8b-instruct-v1:0", - "meta.llama3-70b-instruct-v1:0", - "meta.llama3-1-8b-instruct-v1:0", - "meta.llama3-1-70b-instruct-v1:0", - "meta.llama3-1-405b-instruct-v1:0", - "mistral.mistral-7b-instruct-v0:2", - "mistral.mixtral-8x7b-instruct-v0:1", - "mistral.mistral-large-2402-v1:0", - "mistral.mistral-large-2407-v1:0", - ] - model_id_key = "model_id" - pypi_package_deps = ["langchain-aws"] - auth_strategy = AwsAuthStrategy() - fields = [ - TextField( - key="credentials_profile_name", - label="AWS profile (optional)", - format="text", - ), - TextField(key="region_name", label="Region name (optional)", format="text"), - ] - - async def _acall(self, *args, **kwargs) -> Coroutine[Any, Any, str]: - return await self._call_in_executor(*args, **kwargs) - - async def _agenerate(self, *args, **kwargs) -> Coroutine[Any, Any, LLMResult]: - return await self._generate_in_executor(*args, **kwargs) - - @property - def allows_concurrency(self): - return not "anthropic" in self.model_id - - -class BedrockCustomProvider(BaseProvider, ChatBedrock): - id = "bedrock-custom" - name = "Amazon Bedrock (custom/provisioned)" - models = ["*"] - model_id_key = "model_id" - model_id_label = "Model ID" - pypi_package_deps = ["langchain-aws"] - auth_strategy = AwsAuthStrategy() - fields = [ - TextField(key="provider", label="Provider (required)", format="text"), - TextField(key="region_name", label="Region name (optional)", format="text"), - TextField( - key="credentials_profile_name", - label="AWS profile (optional)", - format="text", - ), - ] - help = ( - "- For Cross-Region Inference use the appropriate `Inference profile ID` (Model ID with a region prefix, e.g., `us.meta.llama3-2-11b-instruct-v1:0`). See the [inference profiles documentation](https://docs.aws.amazon.com/bedrock/latest/userguide/inference-profiles-support.html). \n" - "- For custom/provisioned models, specify the model ARN (Amazon Resource Name) as the model ID. For more information, see the [Amazon Bedrock model IDs documentation](https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html).\n\n" - "The model provider must also be specified below. This is the provider of your foundation model *in lowercase*, e.g., `amazon`, `anthropic`, `cohere`, `meta`, or `mistral`." 
- ) - registry = True - - -# See model ID list here: https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html -class BedrockEmbeddingsProvider(BaseEmbeddingsProvider, BedrockEmbeddings): - id = "bedrock" - name = "Bedrock" - models = [ - "amazon.titan-embed-text-v1", - "amazon.titan-embed-text-v2:0", - "cohere.embed-english-v3", - "cohere.embed-multilingual-v3", - ] - model_id_key = "model_id" - pypi_package_deps = ["langchain-aws"] - auth_strategy = AwsAuthStrategy() - - -class JsonContentHandler(LLMContentHandler): - content_type = "application/json" - accepts = "application/json" - - def __init__(self, request_schema, response_path): - self.request_schema = json.loads(request_schema) - self.response_path = response_path - self.response_parser = parse(response_path) - - def replace_values(self, old_val, new_val, d: dict[str, Any]): - """Replaces values of a dictionary recursively.""" - for key, val in d.items(): - if val == old_val: - d[key] = new_val - if isinstance(val, dict): - self.replace_values(old_val, new_val, val) - - return d - - def transform_input(self, prompt: str, model_kwargs: dict) -> bytes: - request_obj = copy.deepcopy(self.request_schema) - self.replace_values("", prompt, request_obj) - request = json.dumps(request_obj).encode("utf-8") - return request - - def transform_output(self, output: bytes) -> str: - response_json = json.loads(output.read().decode("utf-8")) - matches = self.response_parser.find(response_json) - return matches[0].value - - -class SmEndpointProvider(BaseProvider, SagemakerEndpoint): - id = "sagemaker-endpoint" - name = "SageMaker endpoint" - models = ["*"] - model_id_key = "endpoint_name" - model_id_label = "Endpoint name" - # This all needs to be on one line of markdown, for use in a table - help = ( - "Specify an endpoint name as the model ID. " - "In addition, you must specify a region name, request schema, and response path. " - "For more information, see the documentation about [SageMaker endpoints deployment](https://docs.aws.amazon.com/sagemaker/latest/dg/realtime-endpoints-deploy-models.html) " - "and about [using magic commands with SageMaker endpoints](https://jupyter-ai.readthedocs.io/en/latest/users/index.html#using-magic-commands-with-sagemaker-endpoints)." 
- ) - - pypi_package_deps = ["langchain-aws"] - auth_strategy = AwsAuthStrategy() - registry = True - fields = [ - TextField(key="region_name", label="Region name (required)", format="text"), - MultilineTextField( - key="request_schema", label="Request schema (required)", format="json" - ), - TextField( - key="response_path", label="Response path (required)", format="jsonpath" - ), - ] - - def __init__(self, *args, **kwargs): - request_schema = kwargs.pop("request_schema") - response_path = kwargs.pop("response_path") - content_handler = JsonContentHandler( - request_schema=request_schema, response_path=response_path - ) - - super().__init__(*args, **kwargs, content_handler=content_handler) - - async def _acall(self, *args, **kwargs) -> Coroutine[Any, Any, str]: - return await self._call_in_executor(*args, **kwargs) diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/cohere.py b/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/cohere.py deleted file mode 100644 index 9ebc287a2..000000000 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/cohere.py +++ /dev/null @@ -1,40 +0,0 @@ -from langchain_cohere import ChatCohere, CohereEmbeddings - -from ..base_provider import BaseProvider, EnvAuthStrategy -from ..embedding_providers import BaseEmbeddingsProvider - - -class CohereProvider(BaseProvider, ChatCohere): - id = "cohere" - name = "Cohere" - # https://docs.cohere.com/docs/models - # note: This provider uses the Chat endpoint instead of the Generate - # endpoint, which is now deprecated. - models = [ - "command", - "command-nightly", - "command-light", - "command-light-nightly", - "command-r-plus", - "command-r", - ] - model_id_key = "model" - pypi_package_deps = ["langchain_cohere"] - auth_strategy = EnvAuthStrategy(name="COHERE_API_KEY") - - -class CohereEmbeddingsProvider(BaseEmbeddingsProvider, CohereEmbeddings): - id = "cohere" - name = "Cohere" - models = [ - "embed-english-v2.0", - "embed-english-light-v2.0", - "embed-multilingual-v2.0", - "embed-english-v3.0", - "embed-english-light-v3.0", - "embed-multilingual-v3.0", - "embed-multilingual-light-v3.0", - ] - model_id_key = "model" - pypi_package_deps = ["langchain_cohere"] - auth_strategy = EnvAuthStrategy(name="COHERE_API_KEY") diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/gemini.py b/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/gemini.py deleted file mode 100644 index 0a5a99139..000000000 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/gemini.py +++ /dev/null @@ -1,18 +0,0 @@ -from jupyter_ai_magics.base_provider import BaseProvider, EnvAuthStrategy -from langchain_google_genai import GoogleGenerativeAI - - -# See list of model ids here: https://ai.google.dev/gemini-api/docs/models/gemini -class GeminiProvider(BaseProvider, GoogleGenerativeAI): - id = "gemini" - name = "Gemini" - models = [ - "gemini-2.5-pro", - "gemini-2.5-flash", - "gemini-2.0-flash-lite", - "gemini-1.5-pro", - "gemini-1.5-flash", - ] - model_id_key = "model" - auth_strategy = EnvAuthStrategy(name="GOOGLE_API_KEY") - pypi_package_deps = ["langchain-google-genai"] diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/mistralai.py b/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/mistralai.py deleted file mode 100644 index cfc316477..000000000 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/mistralai.py +++ /dev/null @@ -1,32 +0,0 @@ -from jupyter_ai_magics.base_provider import 
BaseProvider, EnvAuthStrategy -from langchain_mistralai import ChatMistralAI, MistralAIEmbeddings - -from ..embedding_providers import BaseEmbeddingsProvider - - -class MistralAIProvider(BaseProvider, ChatMistralAI): - id = "mistralai" - name = "MistralAI" - models = [ - "open-mistral-7b", - "open-mixtral-8x7b", - "open-mixtral-8x22b", - "mistral-small-latest", - "mistral-medium-latest", - "mistral-large-latest", - "codestral-latest", - ] - model_id_key = "model" - auth_strategy = EnvAuthStrategy(name="MISTRAL_API_KEY") - pypi_package_deps = ["langchain-mistralai"] - - -class MistralAIEmbeddingsProvider(BaseEmbeddingsProvider, MistralAIEmbeddings): - id = "mistralai" - name = "MistralAI" - models = [ - "mistral-embed", - ] - model_id_key = "model" - pypi_package_deps = ["langchain-mistralai"] - auth_strategy = EnvAuthStrategy(name="MISTRAL_API_KEY") diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/nvidia.py b/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/nvidia.py deleted file mode 100644 index 46e92fa06..000000000 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/nvidia.py +++ /dev/null @@ -1,23 +0,0 @@ -from jupyter_ai_magics.base_provider import BaseProvider, EnvAuthStrategy -from langchain_nvidia_ai_endpoints import ChatNVIDIA - - -class ChatNVIDIAProvider(BaseProvider, ChatNVIDIA): - id = "nvidia-chat" - name = "NVIDIA" - models = [ - "playground_llama2_70b", - "playground_nemotron_steerlm_8b", - "playground_mistral_7b", - "playground_nv_llama2_rlhf_70b", - "playground_llama2_13b", - "playground_steerlm_llama_70b", - "playground_llama2_code_13b", - "playground_yi_34b", - "playground_mixtral_8x7b", - "playground_neva_22b", - "playground_llama2_code_34b", - ] - model_id_key = "model" - auth_strategy = EnvAuthStrategy(name="NVIDIA_API_KEY") - pypi_package_deps = ["langchain_nvidia_ai_endpoints"] diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/ollama.py b/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/ollama.py deleted file mode 100644 index 58edd9de3..000000000 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/ollama.py +++ /dev/null @@ -1,35 +0,0 @@ -from langchain_ollama import ChatOllama, OllamaEmbeddings - -from ..base_provider import BaseProvider, TextField -from ..embedding_providers import BaseEmbeddingsProvider - - -class OllamaProvider(BaseProvider, ChatOllama): - id = "ollama" - name = "Ollama" - model_id_key = "model" - help = ( - "See [https://www.ollama.com/library](https://www.ollama.com/library) for a list of models. " - "Pass a model's name; for example, `deepseek-coder-v2`." - ) - models = ["*"] - registry = True - fields = [ - TextField(key="base_url", label="Base API URL (optional)", format="text"), - ] - - -class OllamaEmbeddingsProvider(BaseEmbeddingsProvider, OllamaEmbeddings): - id = "ollama" - name = "Ollama" - # source: https://ollama.com/library - model_id_key = "model" - help = ( - "See [https://ollama.com/search?c=embedding](https://ollama.com/search?c=embedding) for a list of models. " - "Pass an embedding model's name; for example, `mxbai-embed-large`." 
- ) - models = ["*"] - registry = True - fields = [ - TextField(key="base_url", label="Base API URL (optional)", format="text"), - ] diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/openai.py b/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/openai.py deleted file mode 100644 index db0113c45..000000000 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/openai.py +++ /dev/null @@ -1,170 +0,0 @@ -from langchain_openai import ( - AzureChatOpenAI, - AzureOpenAIEmbeddings, - ChatOpenAI, - OpenAI, - OpenAIEmbeddings, -) - -from ..base_provider import BaseProvider, EnvAuthStrategy, TextField -from ..embedding_providers import BaseEmbeddingsProvider - - -class OpenAIProvider(BaseProvider, OpenAI): - id = "openai" - name = "OpenAI" - models = ["babbage-002", "davinci-002", "gpt-3.5-turbo-instruct"] - model_id_key = "model_name" - pypi_package_deps = ["langchain_openai"] - auth_strategy = EnvAuthStrategy(name="OPENAI_API_KEY") - - @classmethod - def is_api_key_exc(cls, e: Exception): - """ - Determine if the exception is an OpenAI API key error. - """ - import openai - - if isinstance(e, openai.AuthenticationError): - error_details = e.json_body.get("error", {}) - return error_details.get("code") == "invalid_api_key" - return False - - -# https://platform.openai.com/docs/models/ -class ChatOpenAIProvider(BaseProvider, ChatOpenAI): - id = "openai-chat" - name = "OpenAI" - models = [ - "gpt-3.5-turbo", - "gpt-3.5-turbo-1106", - "gpt-4", - "gpt-4-turbo", - "gpt-4-turbo-preview", - "gpt-4-0613", - "gpt-4-0125-preview", - "gpt-4-1106-preview", - "gpt-4o", - "gpt-4o-2024-11-20", - "gpt-4o-mini", - "chatgpt-4o-latest", - "gpt-4.1", - "gpt-4.1-mini", - "gpt-4.1-nano", - "o1", - "o3-mini", - "o4-mini", - ] - model_id_key = "model_name" - pypi_package_deps = ["langchain_openai"] - auth_strategy = EnvAuthStrategy(name="OPENAI_API_KEY") - - fields = [ - TextField( - key="openai_api_base", label="Base API URL (optional)", format="text" - ), - TextField( - key="openai_organization", label="Organization (optional)", format="text" - ), - TextField(key="openai_proxy", label="Proxy (optional)", format="text"), - ] - - @classmethod - def is_api_key_exc(cls, e: Exception): - """ - Determine if the exception is an OpenAI API key error. - """ - import openai - - if isinstance(e, openai.AuthenticationError): - error_details = e.json_body.get("error", {}) - return error_details.get("code") == "invalid_api_key" - return False - - -class ChatOpenAICustomProvider(BaseProvider, ChatOpenAI): - id = "openai-chat-custom" - name = "OpenAI (general interface)" - models = ["*"] - model_id_key = "model_name" - model_id_label = "Model ID" - pypi_package_deps = ["langchain_openai"] - auth_strategy = EnvAuthStrategy(name="OPENAI_API_KEY") - fields = [ - TextField( - key="openai_api_base", label="Base API URL (optional)", format="text" - ), - TextField( - key="openai_organization", label="Organization (optional)", format="text" - ), - TextField(key="openai_proxy", label="Proxy (optional)", format="text"), - ] - help = "Supports non-OpenAI models that use the OpenAI API interface. Replace the OpenAI API key with the API key for the chosen provider." 
- registry = True - - -class AzureChatOpenAIProvider(BaseProvider, AzureChatOpenAI): - id = "azure-chat-openai" - name = "Azure OpenAI" - models = ["*"] - model_id_key = "azure_deployment" - model_id_label = "Deployment name" - pypi_package_deps = ["langchain_openai"] - # Confusingly, langchain uses both OPENAI_API_KEY and AZURE_OPENAI_API_KEY for azure - # https://github.com/langchain-ai/langchain/blob/f2579096993ae460516a0aae1d3e09f3eb5c1772/libs/partners/openai/langchain_openai/llms/azure.py#L85 - auth_strategy = EnvAuthStrategy( - name="AZURE_OPENAI_API_KEY", keyword_param="openai_api_key" - ) - registry = True - - fields = [ - TextField(key="azure_endpoint", label="Base API URL (required)", format="text"), - TextField(key="api_version", label="API version (required)", format="text"), - ] - - -class OpenAIEmbeddingsProvider(BaseEmbeddingsProvider, OpenAIEmbeddings): - id = "openai" - name = "OpenAI" - models = [ - "text-embedding-ada-002", - "text-embedding-3-small", - "text-embedding-3-large", - ] - model_id_key = "model" - pypi_package_deps = ["langchain_openai"] - auth_strategy = EnvAuthStrategy(name="OPENAI_API_KEY") - - -class OpenAIEmbeddingsCustomProvider(BaseEmbeddingsProvider, OpenAIEmbeddings): - id = "openai-custom" - name = "OpenAI (general interface)" - models = ["*"] - model_id_key = "model" - pypi_package_deps = ["langchain_openai"] - auth_strategy = EnvAuthStrategy(name="OPENAI_API_KEY") - registry = True - fields = [ - TextField( - key="openai_api_base", label="Base API URL (optional)", format="text" - ), - ] - help = "Supports non-OpenAI embedding models that use the OpenAI API interface. Replace the OpenAI API key with the API key for the chosen provider." - - -class AzureOpenAIEmbeddingsProvider(BaseEmbeddingsProvider, AzureOpenAIEmbeddings): - id = "azure" - name = "Azure OpenAI" - models = [ - "text-embedding-ada-002", - "text-embedding-3-small", - "text-embedding-3-large", - ] - model_id_key = "azure_deployment" - pypi_package_deps = ["langchain_openai"] - auth_strategy = EnvAuthStrategy( - name="AZURE_OPENAI_API_KEY", keyword_param="openai_api_key" - ) - fields = [ - TextField(key="azure_endpoint", label="Base API URL (optional)", format="text"), - ] diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/openrouter.py b/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/openrouter.py deleted file mode 100644 index ac07ec129..000000000 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/openrouter.py +++ /dev/null @@ -1,51 +0,0 @@ -from jupyter_ai_magics import BaseProvider -from jupyter_ai_magics.base_provider import EnvAuthStrategy, TextField -from langchain_core.utils import get_from_dict_or_env -from langchain_openai import ChatOpenAI - - -class ChatOpenRouter(ChatOpenAI): - @property - def lc_secrets(self) -> dict[str, str]: - return {"openai_api_key": "OPENROUTER_API_KEY"} - - -class OpenRouterProvider(BaseProvider, ChatOpenRouter): - id = "openrouter" - name = "OpenRouter" - models = [ - "*" - ] # OpenRouter supports multiple models, so we use "*" to indicate it's a registry - model_id_key = "model_name" - pypi_package_deps = ["langchain_openai"] - auth_strategy = EnvAuthStrategy(name="OPENROUTER_API_KEY") - registry = True - - fields = [ - TextField( - key="openai_api_base", label="API Base URL (optional)", format="text" - ), - ] - - def __init__(self, **kwargs): - openrouter_api_key = get_from_dict_or_env( - kwargs, key="openrouter_api_key", env_key="OPENROUTER_API_KEY", default=None - ) - 
openrouter_api_base = kwargs.pop( - "openai_api_base", "https://openrouter.ai/api/v1" - ) - kwargs.pop("openrouter_api_key", None) - super().__init__( - openai_api_key=openrouter_api_key, - openai_api_base=openrouter_api_base, - **kwargs, - ) - - @classmethod - def is_api_key_exc(cls, e: Exception): - import openai - - if isinstance(e, openai.AuthenticationError): - error_details = e.json_body.get("error", {}) - return error_details.get("code") == "invalid_api_key" - return False diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/vertexai.py b/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/vertexai.py deleted file mode 100644 index d22210a68..000000000 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/vertexai.py +++ /dev/null @@ -1,39 +0,0 @@ -from jupyter_ai_magics.base_provider import BaseProvider -from langchain_google_vertexai import VertexAI, VertexAIEmbeddings - - -class VertexAIProvider(BaseProvider, VertexAI): - id = "vertexai" - name = "Vertex AI" - models = [ - "gemini-2.5-pro", - "gemini-2.5-flash", - ] - model_id_key = "model" - auth_strategy = None - pypi_package_deps = ["langchain-google-vertexai"] - help = ( - "To use Vertex AI Generative AI you must have the langchain-google-vertexai Python package installed and either:\n\n" - "- Have credentials configured for your environment (gcloud, workload identity, etc...)\n" - "- Store the path to a service account JSON file as the GOOGLE_APPLICATION_CREDENTIALS environment variable\n\n" - "This codebase uses the google.auth library which first looks for the application credentials variable mentioned above, and then looks for system-level auth. " - "For more information, see the [Vertex AI authentication documentation](https://python.langchain.com/docs/integrations/llms/google_vertex_ai_palm/)." - ) - - -class VertexAIEmbeddingsProvider(BaseProvider, VertexAIEmbeddings): - id = "vertexai" - name = "Vertex AI" - models = [ - "text-embedding-004", - ] - model_id_key = "model" - auth_strategy = None - pypi_package_deps = ["langchain-google-vertexai"] - help = ( - "To use Vertex AI Generative AI you must have the langchain-google-vertexai Python package installed and either:\n\n" - "- Have credentials configured for your environment (gcloud, workload identity, etc...)\n" - "- Store the path to a service account JSON file as the GOOGLE_APPLICATION_CREDENTIALS environment variable\n\n" - "This codebase uses the google.auth library which first looks for the application credentials variable mentioned above, and then looks for system-level auth. " - "For more information, see the [Vertex AI authentication documentation](https://python.langchain.com/docs/integrations/llms/google_vertex_ai_palm/)." 
- ) diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py deleted file mode 100644 index cb03951fa..000000000 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py +++ /dev/null @@ -1,247 +0,0 @@ -import base64 -import io -import json -from collections.abc import Coroutine -from typing import ( - Any, - Optional, -) - -from langchain.prompts import ( - PromptTemplate, -) -from langchain_community.chat_models import QianfanChatEndpoint -from langchain_community.llms import AI21, GPT4All, HuggingFaceEndpoint, Together - -# for backward compatibility the import group below includes -# imports of objects which used to be defined here but were -# only used by BaseProvider that has now moved to `.base_provider`. -from .base_provider import ( - CHAT_DEFAULT_TEMPLATE, - CHAT_SYSTEM_PROMPT, - COMPLETION_DEFAULT_TEMPLATE, - COMPLETION_SYSTEM_PROMPT, - HUMAN_MESSAGE_TEMPLATE, - AuthStrategy, - AwsAuthStrategy, - BaseProvider, - EnvAuthStrategy, - Field, - IntegerField, - MultiEnvAuthStrategy, - MultilineTextField, - TextField, -) - - -class AI21Provider(BaseProvider, AI21): - id = "ai21" - name = "AI21" - models = [ - "j1-large", - "j1-grande", - "j1-jumbo", - "j1-grande-instruct", - "j2-large", - "j2-grande", - "j2-jumbo", - "j2-grande-instruct", - "j2-jumbo-instruct", - ] - model_id_key = "model" - pypi_package_deps = ["ai21"] - auth_strategy = EnvAuthStrategy(name="AI21_API_KEY") - - async def _acall(self, *args, **kwargs) -> Coroutine[Any, Any, str]: - return await self._call_in_executor(*args, **kwargs) - - @classmethod - def is_api_key_exc(cls, e: Exception): - """ - Determine if the exception is an AI21 API key error. - """ - if isinstance(e, ValueError): - return "status code 401" in str(e) - return False - - -class GPT4AllProvider(BaseProvider, GPT4All): - def __init__(self, **kwargs): - model = kwargs.get("model_id") - if model == "ggml-gpt4all-l13b-snoozy": - kwargs["backend"] = "llama" - else: - kwargs["backend"] = "gptj" - - kwargs["allow_download"] = False - n_threads = kwargs.get("n_threads", None) - if n_threads is not None: - kwargs["n_threads"] = max(int(n_threads), 1) - super().__init__(**kwargs) - - id = "gpt4all" - name = "GPT4All" - models = [ - "ggml-gpt4all-j-v1.2-jazzy", - "ggml-gpt4all-j-v1.3-groovy", - # this one needs llama backend and has licence restriction - "ggml-gpt4all-l13b-snoozy", - "mistral-7b-openorca.Q4_0", - "mistral-7b-instruct-v0.1.Q4_0", - "gpt4all-falcon-q4_0", - "wizardlm-13b-v1.2.Q4_0", - "nous-hermes-llama2-13b.Q4_0", - "gpt4all-13b-snoozy-q4_0", - "mpt-7b-chat-merges-q4_0", - "orca-mini-3b-gguf2-q4_0", - "starcoder-q4_0", - "rift-coder-v0-7b-q4_0", - "em_german_mistral_v01.Q4_0", - ] - model_id_key = "model" - pypi_package_deps = ["gpt4all"] - auth_strategy = None - fields = [IntegerField(key="n_threads", label="CPU thread count (optional)")] - - async def _acall(self, *args, **kwargs) -> Coroutine[Any, Any, str]: - return await self._call_in_executor(*args, **kwargs) - - @property - def allows_concurrency(self): - # At present, GPT4All providers fail with concurrent messages. See #481. 
- return False - - -# References for using HuggingFaceEndpoint and InferenceClient: -# https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient -# https://github.com/langchain-ai/langchain/blob/master/libs/community/langchain_community/llms/huggingface_endpoint.py -class HfHubProvider(BaseProvider, HuggingFaceEndpoint): - id = "huggingface_hub" - name = "Hugging Face Hub" - models = ["*"] - model_id_key = "repo_id" - help = ( - "See [https://huggingface.co/models](https://huggingface.co/models) for a list of models. " - "Pass a model's repository ID as the model ID; for example, `huggingface_hub:ExampleOwner/example-model`." - ) - # ipywidgets needed to suppress tqdm warning - # https://stackoverflow.com/questions/67998191 - # tqdm is a dependency of huggingface_hub - pypi_package_deps = ["huggingface_hub", "ipywidgets"] - auth_strategy = EnvAuthStrategy(name="HUGGINGFACEHUB_API_TOKEN") - registry = True - - # Handle text and image outputs - def _call( - self, prompt: str, stop: Optional[list[str]] = None, **kwargs: Any - ) -> str: - """Call out to Hugging Face Hub's inference endpoint. - - Args: - prompt: The prompt to pass into the model. - stop: Optional list of stop words to use when generating. - - Returns: - The string or image generated by the model. - - Example: - .. code-block:: python - - response = hf("Tell me a joke.") - """ - invocation_params = self._invocation_params(stop, **kwargs) - invocation_params["stop"] = invocation_params[ - "stop_sequences" - ] # porting 'stop_sequences' into the 'stop' argument - response = self.client.post( - json={"inputs": prompt, "parameters": invocation_params}, - stream=False, - task=self.task, - ) - - try: - if "generated_text" in str(response): - # text2 text or text-generation task - response_text = json.loads(response.decode())[0]["generated_text"] - # Maybe the generation has stopped at one of the stop sequences: - # then we remove this stop sequence from the end of the generated text - for stop_seq in invocation_params["stop_sequences"]: - if response_text[-len(stop_seq) :] == stop_seq: - response_text = response_text[: -len(stop_seq)] - return response_text - else: - # text-to-image task - # https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_to_image.example - # Custom code for responding to image generation responses - image = self.client.text_to_image(prompt) - imageFormat = image.format # Presume it's a PIL ImageFile - mimeType = "" - if imageFormat == "JPEG": - mimeType = "image/jpeg" - elif imageFormat == "PNG": - mimeType = "image/png" - elif imageFormat == "GIF": - mimeType = "image/gif" - else: - raise ValueError(f"Unrecognized image format {imageFormat}") - buffer = io.BytesIO() - image.save(buffer, format=imageFormat) - # # Encode image data to Base64 bytes, then decode bytes to str - return ( - mimeType + ";base64," + base64.b64encode(buffer.getvalue()).decode() - ) - except: - raise ValueError( - "Task not supported, only text-generation and text-to-image tasks are valid." 
- ) - - async def _acall(self, *args, **kwargs) -> Coroutine[Any, Any, str]: - return await self._call_in_executor(*args, **kwargs) - - -class TogetherAIProvider(BaseProvider, Together): - id = "togetherai" - name = "Together AI" - model_id_key = "model" - models = [ - "Austism/chronos-hermes-13b", - "DiscoResearch/DiscoLM-mixtral-8x7b-v2", - "EleutherAI/llemma_7b", - "Gryphe/MythoMax-L2-13b", - "Meta-Llama/Llama-Guard-7b", - "Nexusflow/NexusRaven-V2-13B", - "NousResearch/Nous-Capybara-7B-V1p9", - "NousResearch/Nous-Hermes-2-Yi-34B", - "NousResearch/Nous-Hermes-Llama2-13b", - "NousResearch/Nous-Hermes-Llama2-70b", - ] - pypi_package_deps = ["together"] - auth_strategy = EnvAuthStrategy(name="TOGETHER_API_KEY") - - def __init__(self, **kwargs): - model = kwargs.get("model_id") - - if model not in self.models: - kwargs["responses"] = [ - "Model not supported! Please check model list with %ai list" - ] - - super().__init__(**kwargs) - - def get_prompt_template(self, format) -> PromptTemplate: - if format == "code": - return PromptTemplate.from_template( - "{prompt}\n\nProduce output as source code only, " - "with no text or explanation before or after it." - ) - return super().get_prompt_template(format) - - -# Baidu QianfanChat provider. temporarily living as a separate class until -class QianfanProvider(BaseProvider, QianfanChatEndpoint): - id = "qianfan" - name = "ERNIE-Bot" - models = ["ERNIE-Bot", "ERNIE-Bot-4"] - model_id_key = "model_name" - pypi_package_deps = ["qianfan"] - auth_strategy = MultiEnvAuthStrategy(names=["QIANFAN_AK", "QIANFAN_SK"]) diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/py.typed b/packages/jupyter-ai-magics/jupyter_ai_magics/py.typed deleted file mode 100644 index e69de29bb..000000000 diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/tests/test_base_provider.py b/packages/jupyter-ai-magics/jupyter_ai_magics/tests/test_base_provider.py deleted file mode 100644 index 4029cd5bd..000000000 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/tests/test_base_provider.py +++ /dev/null @@ -1,27 +0,0 @@ -from typing import ClassVar, Optional - -from pydantic import BaseModel - -from ..providers import BaseProvider - - -def test_provider_classvars(): - """ - Asserts that class attributes are not omitted due to parent classes defining - an instance field of the same name. This was a bug present in Pydantic v1, - which led to an issue documented in #558. - - This bug is fixed as of `pydantic==2.10.2`, but we will keep this test in - case this behavior changes in future releases. 
- """ - - class Parent(BaseModel): - test: Optional[str] = None - - class Base(BaseModel): - test: ClassVar[str] - - class Child(Base, Parent): - test: ClassVar[str] = "expected" - - assert Child.test == "expected" diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/tests/test_imports.py b/packages/jupyter-ai-magics/jupyter_ai_magics/tests/test_imports.py deleted file mode 100644 index b6ffb4bf5..000000000 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/tests/test_imports.py +++ /dev/null @@ -1,37 +0,0 @@ -import ast -import inspect -import sys - -import jupyter_ai_magics - - -def test_uses_lazy_imports(): - assert "jupyter_ai_magics.exception" not in sys.modules - jupyter_ai_magics.exception - assert "jupyter_ai_magics.exception" in sys.modules - - -def test_all_includes_all_dynamic_imports(): - dynamic_imports = set(jupyter_ai_magics._modules_by_export.keys()) - assert dynamic_imports - set(jupyter_ai_magics.__all__) == set() - - -def test_dir_returns_all(): - assert set(jupyter_ai_magics.__all__) == set(dir(jupyter_ai_magics)) - - -def test_all_type_checked(): - dynamic_imports = set(jupyter_ai_magics._modules_by_export.keys()) - tree = ast.parse(inspect.getsource(jupyter_ai_magics)) - imports_in_type_checking = { - alias.asname if alias.asname else alias.name - for node in ast.walk(tree) - if isinstance(node, ast.If) - and isinstance(node.test, ast.Name) - and node.test.id == "TYPE_CHECKING" - for stmt in node.body - if isinstance(stmt, ast.ImportFrom) - for alias in stmt.names - } - - assert imports_in_type_checking == set(dynamic_imports) diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/tests/test_magics.py b/packages/jupyter-ai-magics/jupyter_ai_magics/tests/test_magics.py index eb087a376..a393029d6 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/tests/test_magics.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/tests/test_magics.py @@ -1,11 +1,7 @@ -import os -from unittest.mock import Mock, patch +from unittest.mock import patch -import pytest from IPython import InteractiveShell -from IPython.core.display import Markdown from jupyter_ai_magics.magics import AiMagics -from langchain_core.messages import AIMessage, HumanMessage from pytest import fixture from traitlets.config.loader import Config @@ -18,14 +14,14 @@ def ip() -> InteractiveShell: def test_aliases_config(ip): - ip.config.AiMagics.aliases = {"my_custom_alias": "my_provider:my_model"} + ip.config.AiMagics.initial_aliases = {"my_custom_alias": "my_provider:my_model"} ip.extension_manager.load_extension("jupyter_ai_magics") providers_list = ip.run_line_magic("ai", "list").text assert "my_custom_alias" in providers_list def test_default_model_cell(ip): - ip.config.AiMagics.default_language_model = "my-favourite-llm" + ip.config.AiMagics.initial_language_model = "my-favourite-llm" ip.extension_manager.load_extension("jupyter_ai_magics") with patch.object(AiMagics, "run_ai_cell", return_value=None) as mock_run: ip.run_cell_magic("ai", "", cell="Write code for me please") @@ -35,7 +31,7 @@ def test_default_model_cell(ip): def test_non_default_model_cell(ip): - ip.config.AiMagics.default_language_model = "my-favourite-llm" + ip.config.AiMagics.initial_language_model = "my-favourite-llm" ip.extension_manager.load_extension("jupyter_ai_magics") with patch.object(AiMagics, "run_ai_cell", return_value=None) as mock_run: ip.run_cell_magic("ai", "some-different-llm", cell="Write code for me please") @@ -45,79 +41,18 @@ def test_non_default_model_cell(ip): def test_default_model_error_line(ip): - 
ip.config.AiMagics.default_language_model = "my-favourite-llm" + ip.config.AiMagics.initial_language_model = "my-favourite-llm" ip.extension_manager.load_extension("jupyter_ai_magics") - with patch.object(AiMagics, "handle_error", return_value=None) as mock_run: - ip.run_cell_magic("ai", "error", cell=None) + with patch.object(AiMagics, "handle_fix", return_value=None) as mock_run: + ip.run_cell_magic("ai", "fix", cell=None) assert mock_run.called cell_args = mock_run.call_args.args[0] assert cell_args.model_id == "my-favourite-llm" -PROMPT = HumanMessage( - content=("Write code for me please\n\nProduce output in markdown format only.") -) -RESPONSE = AIMessage(content="Leet code") -AI1 = AIMessage("ai1") -H1 = HumanMessage("h1") -AI2 = AIMessage("ai2") -H2 = HumanMessage("h2") -AI3 = AIMessage("ai3") - - -@pytest.mark.parametrize( - ["transcript", "max_history", "expected_context"], - [ - ([], 3, [PROMPT]), - ([AI1], 0, [PROMPT]), - ([AI1], 1, [AI1, PROMPT]), - ([H1, AI1], 0, [PROMPT]), - ([H1, AI1], 1, [H1, AI1, PROMPT]), - ([AI1, H1, AI2], 0, [PROMPT]), - ([AI1, H1, AI2], 1, [H1, AI2, PROMPT]), - ([AI1, H1, AI2], 2, [AI1, H1, AI2, PROMPT]), - ([H1, AI1, H2, AI2], 0, [PROMPT]), - ([H1, AI1, H2, AI2], 1, [H2, AI2, PROMPT]), - ([H1, AI1, H2, AI2], 2, [H1, AI1, H2, AI2, PROMPT]), - ([AI1, H1, AI2, H2, AI3], 0, [PROMPT]), - ([AI1, H1, AI2, H2, AI3], 1, [H2, AI3, PROMPT]), - ([AI1, H1, AI2, H2, AI3], 2, [H1, AI2, H2, AI3, PROMPT]), - ([AI1, H1, AI2, H2, AI3], 3, [AI1, H1, AI2, H2, AI3, PROMPT]), - ], -) -def test_max_history(ip, transcript, max_history, expected_context): - ip.extension_manager.load_extension("jupyter_ai_magics") - ai_magics = ip.magics_manager.registry["AiMagics"] - ai_magics.transcript = transcript.copy() - ai_magics.max_history = max_history - provider = ai_magics._get_provider("openrouter") - with ( - patch.object(provider, "generate") as generate, - patch.dict(os.environ, OPENROUTER_API_KEY="123"), - ): - generate.return_value.generations = [[Mock(text="Leet code")]] - result = ip.run_cell_magic( - "ai", - "openrouter:anthropic/claude-3.5-sonnet", - cell="Write code for me please", - ) - provider.generate.assert_called_once_with([expected_context]) - assert isinstance(result, Markdown) - assert result.data == "Leet code" - assert result.filename is None - assert result.metadata == { - "jupyter_ai": { - "model_id": "anthropic/claude-3.5-sonnet", - "provider_id": "openrouter", - } - } - assert result.url is None - assert ai_magics.transcript == [*transcript, PROMPT, RESPONSE] - - def test_reset(ip): ip.extension_manager.load_extension("jupyter_ai_magics") ai_magics = ip.magics_manager.registry["AiMagics"] - ai_magics.transcript = [AI1, H1, AI2, H2, AI3] + ai_magics.transcript = [{"role": "user", "content": "hello"}] ip.run_line_magic("ai", "reset") assert ai_magics.transcript == [] diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/tests/test_utils.py b/packages/jupyter-ai-magics/jupyter_ai_magics/tests/test_utils.py deleted file mode 100644 index 860bb02bc..000000000 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/tests/test_utils.py +++ /dev/null @@ -1,38 +0,0 @@ -# Copyright (c) Jupyter Development Team. -# Distributed under the terms of the Modified BSD License. 
- -import pytest -from jupyter_ai_magics.utils import get_lm_providers - -KNOWN_LM_A = "openai" -KNOWN_LM_B = "huggingface_hub" - - -@pytest.mark.parametrize( - "restrictions", - [ - {"allowed_providers": None, "blocked_providers": None}, - {"allowed_providers": None, "blocked_providers": []}, - {"allowed_providers": None, "blocked_providers": [KNOWN_LM_B]}, - {"allowed_providers": [KNOWN_LM_A], "blocked_providers": []}, - {"allowed_providers": [KNOWN_LM_A], "blocked_providers": None}, - ], -) -def test_get_lm_providers_not_restricted(restrictions): - a_not_restricted = get_lm_providers(None, restrictions) - assert KNOWN_LM_A in a_not_restricted - - -@pytest.mark.parametrize( - "restrictions", - [ - {"allowed_providers": [], "blocked_providers": None}, - {"allowed_providers": [], "blocked_providers": [KNOWN_LM_A]}, - {"allowed_providers": None, "blocked_providers": [KNOWN_LM_A]}, - {"allowed_providers": [KNOWN_LM_B], "blocked_providers": []}, - {"allowed_providers": [KNOWN_LM_B], "blocked_providers": None}, - ], -) -def test_get_lm_providers_restricted(restrictions): - a_not_restricted = get_lm_providers(None, restrictions) - assert KNOWN_LM_A not in a_not_restricted diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/utils.py b/packages/jupyter-ai-magics/jupyter_ai_magics/utils.py deleted file mode 100644 index 50cd187bc..000000000 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/utils.py +++ /dev/null @@ -1,134 +0,0 @@ -import logging -from typing import Literal, Optional, Union - -from importlib_metadata import entry_points -from jupyter_ai_magics.aliases import MODEL_ID_ALIASES -from jupyter_ai_magics.base_provider import BaseProvider -from jupyter_ai_magics.embedding_providers import BaseEmbeddingsProvider - -Logger = Union[logging.Logger, logging.LoggerAdapter] -LmProvidersDict = dict[str, BaseProvider] -EmProvidersDict = dict[str, BaseEmbeddingsProvider] -AnyProvider = Union[BaseProvider, BaseEmbeddingsProvider] -ProviderDict = dict[str, AnyProvider] -ProviderRestrictions = dict[ - Literal["allowed_providers", "blocked_providers"], Optional[list[str]] -] - - -def get_lm_providers( - log: Optional[Logger] = None, restrictions: Optional[ProviderRestrictions] = None -) -> LmProvidersDict: - if not log: - log = logging.getLogger() - log.addHandler(logging.NullHandler()) - if not restrictions: - restrictions = {"allowed_providers": None, "blocked_providers": None} - providers = {} - eps = entry_points() - provider_ep_group = eps.select(group="jupyter_ai.model_providers") - for provider_ep in provider_ep_group: - try: - provider = provider_ep.load() - except ImportError as e: - log.warning( - f"Unable to load model provider `{provider_ep.name}`. Please install the `{e.name}` package." 
- ) - continue - except Exception as e: - log.warning( - f"Unable to load model provider `{provider_ep.name}`", exc_info=e - ) - continue - - if not is_provider_allowed(provider.id, restrictions): - log.info(f"Skipping blocked provider `{provider.id}`.") - continue - providers[provider.id] = provider - log.info(f"Registered model provider `{provider.id}`.") - - return providers - - -def get_em_providers( - log: Optional[Logger] = None, restrictions: Optional[ProviderRestrictions] = None -) -> EmProvidersDict: - if not log: - log = logging.getLogger() - log.addHandler(logging.NullHandler()) - if not restrictions: - restrictions = {"allowed_providers": None, "blocked_providers": None} - providers = {} - eps = entry_points() - model_provider_eps = eps.select(group="jupyter_ai.embeddings_model_providers") - for model_provider_ep in model_provider_eps: - try: - provider = model_provider_ep.load() - except Exception as e: - log.warning( - f"Unable to load embeddings model provider class from entry point `{model_provider_ep.name}`: %s.", - e, - ) - continue - if not is_provider_allowed(provider.id, restrictions): - log.info(f"Skipping blocked provider `{provider.id}`.") - continue - providers[provider.id] = provider - log.info(f"Registered embeddings model provider `{provider.id}`.") - - return providers - - -def decompose_model_id( - model_id: str, providers: dict[str, BaseProvider] -) -> tuple[str, str]: - """Breaks down a model ID into a two-tuple (provider_id, local_model_id). Returns (None, None) if indeterminate.""" - if model_id in MODEL_ID_ALIASES: - model_id = MODEL_ID_ALIASES[model_id] - - if ":" not in model_id: - # case: model ID was not provided with a prefix indicating the provider - # ID. try to infer the provider ID before returning (None, None). - - # naively search through the dictionary and return the first provider - # that provides a model of the same ID. 
- for provider_id, provider in providers.items(): - if model_id in provider.models: - return (provider_id, model_id) - - return (None, None) - - provider_id, local_model_id = model_id.split(":", 1) - return (provider_id, local_model_id) - - -def get_lm_provider( - model_id: str, lm_providers: LmProvidersDict -) -> tuple[str, type[BaseProvider]]: - """Gets a two-tuple (, ) specified by a - global model ID.""" - return _get_provider(model_id, lm_providers) - - -def get_em_provider( - model_id: str, em_providers: EmProvidersDict -) -> tuple[str, type[BaseEmbeddingsProvider]]: - """Gets a two-tuple (, ) specified by a - global model ID.""" - return _get_provider(model_id, em_providers) - - -def is_provider_allowed(provider_id: str, restrictions: ProviderRestrictions) -> bool: - allowed = restrictions["allowed_providers"] - blocked = restrictions["blocked_providers"] - if blocked is not None and provider_id in blocked: - return False - if allowed is not None and provider_id not in allowed: - return False - return True - - -def _get_provider(model_id: str, providers: ProviderDict): - provider_id, local_model_id = decompose_model_id(model_id, providers) - provider = providers.get(provider_id, None) - return local_model_id, provider diff --git a/packages/jupyter-ai-magics/pyproject.toml b/packages/jupyter-ai-magics/pyproject.toml index 7e54f8ee0..b1d7b6c9e 100644 --- a/packages/jupyter-ai-magics/pyproject.toml +++ b/packages/jupyter-ai-magics/pyproject.toml @@ -22,15 +22,10 @@ dynamic = ["version", "description", "authors", "urls", "keywords"] dependencies = [ "ipython", - "importlib_metadata>=5.2.0", - "langchain>=0.3.0,<0.4.0", - "langchain_community>=0.3.0,<0.4.0", # pydantic <2.10.0 raises a "protected namespaces" error in JAI # - See: https://docs.pydantic.dev/latest/api/config/#pydantic.config.ConfigDict.protected_namespaces "pydantic>=2.10.0,<3", "click>=8.1.0,<9", - "jsonpath-ng>=1.5.3,<2", - "langchain-google-vertexai", ] [project.optional-dependencies] @@ -45,64 +40,10 @@ dev = [ test = ["coverage", "pytest", "pytest-asyncio", "pytest-cov"] all = [ - "ai21", - "gpt4all", - "huggingface_hub", - "ipywidgets", - "langchain_anthropic", - "langchain_aws", - "langchain_cohere", - # Pin cohere to <5.16 to prevent langchain_cohere from breaking as it uses ChatResponse removed in cohere 5.16.0 - # TODO: remove cohere pin when langchain_cohere is updated to work with cohere >=5.16 - "cohere<5.16", - "langchain_google_genai", - "langchain_mistralai", - "langchain_nvidia_ai_endpoints", - "langchain_openai", - "langchain_ollama", - "pillow", + # Required for using Amazon Bedrock "boto3", - "qianfan", - "together", - "langchain-google-vertexai", ] -[project.entry-points."jupyter_ai.model_providers"] -ai21 = "jupyter_ai_magics:AI21Provider" -anthropic-chat = "jupyter_ai_magics.partner_providers.anthropic:ChatAnthropicProvider" -cohere = "jupyter_ai_magics.partner_providers.cohere:CohereProvider" -gpt4all = "jupyter_ai_magics:GPT4AllProvider" -huggingface_hub = "jupyter_ai_magics:HfHubProvider" -ollama = "jupyter_ai_magics.partner_providers.ollama:OllamaProvider" -openai = "jupyter_ai_magics.partner_providers.openai:OpenAIProvider" -openai-chat = "jupyter_ai_magics.partner_providers.openai:ChatOpenAIProvider" -openai-chat-custom = "jupyter_ai_magics.partner_providers.openai:ChatOpenAICustomProvider" -azure-chat-openai = "jupyter_ai_magics.partner_providers.openai:AzureChatOpenAIProvider" -sagemaker-endpoint = "jupyter_ai_magics.partner_providers.aws:SmEndpointProvider" -amazon-bedrock = 
"jupyter_ai_magics.partner_providers.aws:BedrockProvider" -amazon-bedrock-chat = "jupyter_ai_magics.partner_providers.aws:BedrockChatProvider" -amazon-bedrock-custom = "jupyter_ai_magics.partner_providers.aws:BedrockCustomProvider" -qianfan = "jupyter_ai_magics:QianfanProvider" -nvidia-chat = "jupyter_ai_magics.partner_providers.nvidia:ChatNVIDIAProvider" -together-ai = "jupyter_ai_magics:TogetherAIProvider" -gemini = "jupyter_ai_magics.partner_providers.gemini:GeminiProvider" -mistralai = "jupyter_ai_magics.partner_providers.mistralai:MistralAIProvider" -openrouter = "jupyter_ai_magics.partner_providers.openrouter:OpenRouterProvider" -vertexai = "jupyter_ai_magics.partner_providers.vertexai:VertexAIProvider" - -[project.entry-points."jupyter_ai.embeddings_model_providers"] -azure = "jupyter_ai_magics.partner_providers.openai:AzureOpenAIEmbeddingsProvider" -bedrock = "jupyter_ai_magics.partner_providers.aws:BedrockEmbeddingsProvider" -cohere = "jupyter_ai_magics.partner_providers.cohere:CohereEmbeddingsProvider" -mistralai = "jupyter_ai_magics.partner_providers.mistralai:MistralAIEmbeddingsProvider" -gpt4all = "jupyter_ai_magics:GPT4AllEmbeddingsProvider" -huggingface_hub = "jupyter_ai_magics:HfHubEmbeddingsProvider" -ollama = "jupyter_ai_magics.partner_providers.ollama:OllamaEmbeddingsProvider" -openai = "jupyter_ai_magics.partner_providers.openai:OpenAIEmbeddingsProvider" -openai-custom = "jupyter_ai_magics.partner_providers.openai:OpenAIEmbeddingsCustomProvider" -qianfan = "jupyter_ai_magics:QianfanEmbeddingsEndpointProvider" -vertexai = "jupyter_ai_magics.partner_providers.vertexai:VertexAIEmbeddingsProvider" - [tool.hatch.version] source = "nodejs" diff --git a/packages/jupyter-ai-module-cookiecutter/{{cookiecutter.root_dir_name}}/pyproject.toml b/packages/jupyter-ai-module-cookiecutter/{{cookiecutter.root_dir_name}}/pyproject.toml index d01a93dda..7480cee63 100644 --- a/packages/jupyter-ai-module-cookiecutter/{{cookiecutter.root_dir_name}}/pyproject.toml +++ b/packages/jupyter-ai-module-cookiecutter/{{cookiecutter.root_dir_name}}/pyproject.toml @@ -29,9 +29,6 @@ dependencies = ["jupyter_ai"] [project.optional-dependencies] test = ["coverage", "pytest", "pytest-asyncio", "pytest-cov"] -[project.entry-points."jupyter_ai.model_providers"] -test-provider = "{{ cookiecutter.python_name }}.provider:TestProvider" - [tool.hatch.build.hooks.version] path = "{{ cookiecutter.python_name }}/_version.py" diff --git a/packages/jupyter-ai-module-cookiecutter/{{cookiecutter.root_dir_name}}/{{cookiecutter.python_name}}/llm.py b/packages/jupyter-ai-module-cookiecutter/{{cookiecutter.root_dir_name}}/{{cookiecutter.python_name}}/llm.py deleted file mode 100644 index fc5134b35..000000000 --- a/packages/jupyter-ai-module-cookiecutter/{{cookiecutter.root_dir_name}}/{{cookiecutter.python_name}}/llm.py +++ /dev/null @@ -1,21 +0,0 @@ -from typing import Any, Optional - -from langchain_core.callbacks.manager import CallbackManagerForLLMRun -from langchain_core.language_models.llms import LLM - - -class TestLLM(LLM): - model_id: str - - @property - def _llm_type(self) -> str: - return "custom" - - def _call( - self, - prompt: str, - stop: Optional[list[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> str: - return f"Hello! I am a model for testing only. 
Model ID: {self.model_id}"
diff --git a/packages/jupyter-ai-module-cookiecutter/{{cookiecutter.root_dir_name}}/{{cookiecutter.python_name}}/provider.py b/packages/jupyter-ai-module-cookiecutter/{{cookiecutter.root_dir_name}}/{{cookiecutter.python_name}}/provider.py
deleted file mode 100644
index 694049f6b..000000000
--- a/packages/jupyter-ai-module-cookiecutter/{{cookiecutter.root_dir_name}}/{{cookiecutter.python_name}}/provider.py
+++ /dev/null
@@ -1,70 +0,0 @@
-from typing import ClassVar
-
-from jupyter_ai import AuthStrategy, BaseProvider, Field
-
-from .llm import TestLLM
-
-
-class TestProvider(BaseProvider, TestLLM):
-    """
-    A test model provider implementation for developers to build from. A model
-    provider inherits from 2 classes: 1) the `BaseProvider` class from
-    `jupyter_ai`, and 2) an LLM class from `langchain`, i.e. a class inheriting
-    from `LLM` or `BaseChatModel`.
-
-    Any custom model first requires a `langchain` LLM class implementation.
-    Please import one from `langchain`, or refer to the `langchain` docs for
-    instructions on how to write your own. We offer an example in `./llm.py` for
-    testing.
-
-    To create a custom model provider from an existing `langchain`
-    implementation, developers should edit this class' declaration to
-
-    ```
-    class TestModelProvider(BaseProvider, <langchain-llm-class>):
-        ...
-    ```
-
-    Developers should fill in each of the below required class attributes.
-    As the implementation is provided by the inherited LLM class, developers
-    generally don't need to implement any methods. See the built-in
-    implementations in `jupyter_ai_magics.providers.py` for further reference.
-
-    The provider is made available to Jupyter AI by the entry point declared in
-    `pyproject.toml`. If this class or parent module is renamed, make sure the
-    update the entry point there as well.
-    """
-
-    id: ClassVar[str] = "test-provider"
-    """ID for this provider class."""
-
-    name: ClassVar[str] = "Test Provider"
-    """User-facing name of this provider."""
-
-    models: ClassVar[list[str]] = ["test-model-1"]
-    """List of supported models by their IDs. For registry providers, this will
-    be just ["*"]."""
-
-    help: ClassVar[str] = None
-    """Text to display in lieu of a model list for a registry provider that does
-    not provide a list of models."""
-
-    model_id_key: ClassVar[str] = "model_id"
-    """Kwarg expected by the upstream LangChain provider."""
-
-    model_id_label: ClassVar[str] = "Model ID"
-    """Human-readable label of the model ID."""
-
-    pypi_package_deps: ClassVar[list[str]] = []
-    """List of PyPi package dependencies."""
-
-    auth_strategy: ClassVar[AuthStrategy] = None
-    """Authentication/authorization strategy. Declares what credentials are
-    required to use this model provider. Generally should not be `None`."""
-
-    registry: ClassVar[bool] = False
-    """Whether this provider is a registry provider."""
-
-    fields: ClassVar[list[Field]] = []
-    """User inputs expected by this provider when initializing it. 
Each `Field` `f` - should be passed in the constructor as a keyword argument, keyed by `f.key`.""" diff --git a/packages/jupyter-ai-test/jupyter_ai_test/test_llms.py b/packages/jupyter-ai-test/jupyter_ai_test/test_llms.py deleted file mode 100644 index b5ec80dc7..000000000 --- a/packages/jupyter-ai-test/jupyter_ai_test/test_llms.py +++ /dev/null @@ -1,59 +0,0 @@ -import time -from collections.abc import Iterator -from typing import Any, Optional - -from langchain_core.callbacks.manager import CallbackManagerForLLMRun -from langchain_core.language_models.llms import LLM -from langchain_core.outputs.generation import GenerationChunk - - -class TestLLM(LLM): - model_id: str = "test" - - @property - def _llm_type(self) -> str: - return "custom" - - def _call( - self, - prompt: str, - stop: Optional[list[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> str: - time.sleep(3) - return f"Hello! This is a dummy response from a test LLM." - - -class TestLLMWithStreaming(LLM): - model_id: str = "test" - - @property - def _llm_type(self) -> str: - return "custom" - - def _call( - self, - prompt: str, - stop: Optional[list[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> str: - time.sleep(3) - return f"Hello! This is a dummy response from a test LLM." - - def _stream( - self, - prompt: str, - stop: Optional[list[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> Iterator[GenerationChunk]: - time.sleep(1) - yield GenerationChunk( - text="Hello! This is a dummy response from a test LLM. I will now count from 1 to 5.\n\n", - generation_info={"test_metadata_field": "foobar"}, - ) - for i in range(1, 6): - time.sleep(0.2) - yield GenerationChunk(text=f"{i}, ") diff --git a/packages/jupyter-ai-test/jupyter_ai_test/test_providers.py b/packages/jupyter-ai-test/jupyter_ai_test/test_providers.py deleted file mode 100644 index 4a15dd31e..000000000 --- a/packages/jupyter-ai-test/jupyter_ai_test/test_providers.py +++ /dev/null @@ -1,115 +0,0 @@ -from typing import ClassVar - -from jupyter_ai import AuthStrategy, BaseProvider, Field - -from .test_llms import TestLLM, TestLLMWithStreaming - - -class TestProvider(BaseProvider, TestLLM): - id: ClassVar[str] = "test-provider" - """ID for this provider class.""" - - name: ClassVar[str] = "Test Provider" - """User-facing name of this provider.""" - - models: ClassVar[list[str]] = ["test"] - """List of supported models by their IDs. For registry providers, this will - be just ["*"].""" - - help: ClassVar[str] = None - """Text to display in lieu of a model list for a registry provider that does - not provide a list of models.""" - - model_id_key: ClassVar[str] = "model_id" - """Kwarg expected by the upstream LangChain provider.""" - - model_id_label: ClassVar[str] = "Model ID" - """Human-readable label of the model ID.""" - - pypi_package_deps: ClassVar[list[str]] = [] - """List of PyPi package dependencies.""" - - auth_strategy: ClassVar[AuthStrategy] = None - """Authentication/authorization strategy. Declares what credentials are - required to use this model provider. Generally should not be `None`.""" - - registry: ClassVar[bool] = False - """Whether this provider is a registry provider.""" - - fields: ClassVar[list[Field]] = [] - """User inputs expected by this provider when initializing it. 
Each `Field` `f` - should be passed in the constructor as a keyword argument, keyed by `f.key`.""" - - -class TestProviderWithStreaming(BaseProvider, TestLLMWithStreaming): - id: ClassVar[str] = "test-provider-with-streaming" - """ID for this provider class.""" - - name: ClassVar[str] = "Test Provider (streaming)" - """User-facing name of this provider.""" - - models: ClassVar[list[str]] = ["test"] - """List of supported models by their IDs. For registry providers, this will - be just ["*"].""" - - help: ClassVar[str] = None - """Text to display in lieu of a model list for a registry provider that does - not provide a list of models.""" - - model_id_key: ClassVar[str] = "model_id" - """Kwarg expected by the upstream LangChain provider.""" - - model_id_label: ClassVar[str] = "Model ID" - """Human-readable label of the model ID.""" - - pypi_package_deps: ClassVar[list[str]] = [] - """List of PyPi package dependencies.""" - - auth_strategy: ClassVar[AuthStrategy] = None - """Authentication/authorization strategy. Declares what credentials are - required to use this model provider. Generally should not be `None`.""" - - registry: ClassVar[bool] = False - """Whether this provider is a registry provider.""" - - fields: ClassVar[list[Field]] = [] - """User inputs expected by this provider when initializing it. Each `Field` `f` - should be passed in the constructor as a keyword argument, keyed by `f.key`.""" - - -class TestProviderAskLearnUnsupported(BaseProvider, TestLLMWithStreaming): - id: ClassVar[str] = "test-provider-ask-learn-unsupported" - """ID for this provider class.""" - - name: ClassVar[str] = "Test Provider (/learn and /ask unsupported)" - """User-facing name of this provider.""" - - models: ClassVar[list[str]] = ["test"] - """List of supported models by their IDs. For registry providers, this will - be just ["*"].""" - - help: ClassVar[str] = None - """Text to display in lieu of a model list for a registry provider that does - not provide a list of models.""" - - model_id_key: ClassVar[str] = "model_id" - """Kwarg expected by the upstream LangChain provider.""" - - model_id_label: ClassVar[str] = "Model ID" - """Human-readable label of the model ID.""" - - pypi_package_deps: ClassVar[list[str]] = [] - """List of PyPi package dependencies.""" - - auth_strategy: ClassVar[AuthStrategy] = None - """Authentication/authorization strategy. Declares what credentials are - required to use this model provider. Generally should not be `None`.""" - - registry: ClassVar[bool] = False - """Whether this provider is a registry provider.""" - - fields: ClassVar[list[Field]] = [] - """User inputs expected by this provider when initializing it. 
Each `Field` `f`
-    should be passed in the constructor as a keyword argument, keyed by `f.key`."""
-
-    unsupported_slash_commands = {"/learn", "/ask"}
diff --git a/packages/jupyter-ai-test/pyproject.toml b/packages/jupyter-ai-test/pyproject.toml
index 0757aa7d9..992022662 100644
--- a/packages/jupyter-ai-test/pyproject.toml
+++ b/packages/jupyter-ai-test/pyproject.toml
@@ -27,11 +27,6 @@ dependencies = ["jupyter_ai"]
 [project.optional-dependencies]
 test = ["coverage", "pytest", "pytest-asyncio", "pytest-cov"]
 
-[project.entry-points."jupyter_ai.model_providers"]
-test-provider = "jupyter_ai_test.test_providers:TestProvider"
-test-provider-with-streaming = "jupyter_ai_test.test_providers:TestProviderWithStreaming"
-test-provider-ask-learn-unsupported = "jupyter_ai_test.test_providers:TestProviderAskLearnUnsupported"
-
 [project.entry-points."jupyter_ai.personas"]
 debug-persona = "jupyter_ai_test.debug_persona:DebugPersona"
 
diff --git a/packages/jupyter-ai/jupyter_ai/__init__.py b/packages/jupyter-ai/jupyter_ai/__init__.py
index d1d6e4d37..830a04627 100644
--- a/packages/jupyter-ai/jupyter_ai/__init__.py
+++ b/packages/jupyter-ai/jupyter_ai/__init__.py
@@ -5,11 +5,10 @@
 
 # expose jupyter_ai_magics ipython extension
 # DO NOT REMOVE.
-from jupyter_ai_magics import load_ipython_extension, unload_ipython_extension
-
-# expose jupyter_ai_magics providers
-# DO NOT REMOVE.
-from jupyter_ai_magics.providers import *
+from jupyter_ai_magics import (  # type: ignore[import-untyped]
+    load_ipython_extension,
+    unload_ipython_extension,
+)
 
 from ._version import __version__
 from .extension import AiExtension
diff --git a/packages/jupyter-ai/jupyter_ai/completions/completion_prompts.py b/packages/jupyter-ai/jupyter_ai/completions/completion_prompts.py
new file mode 100644
index 000000000..10953a4e8
--- /dev/null
+++ b/packages/jupyter-ai/jupyter_ai/completions/completion_prompts.py
@@ -0,0 +1,24 @@
+COMPLETION_SYSTEM_PROMPT = """
+You are an application built to provide helpful code completion suggestions.
+You should only produce code. Keep comments to a minimum, using the
+programming language's comment syntax. Produce clean code.
+The code is written in JupyterLab, a data analysis and code development
+environment which can execute code extended with additional syntax for
+interactive features, such as magics.
+""".strip()
+
+# only add the suffix bit if present to save input tokens/computation time
+COMPLETION_DEFAULT_TEMPLATE = """
+The document is called `{{filename}}` and written in {{language}}. 
+{% if suffix %} +The code after the completion request is: + +``` +{{suffix}} +``` +{% endif %} + +Complete the following code: + +``` +{{prefix}}""" diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/models/completion.py b/packages/jupyter-ai/jupyter_ai/completions/completion_types.py similarity index 100% rename from packages/jupyter-ai-magics/jupyter_ai_magics/models/completion.py rename to packages/jupyter-ai/jupyter_ai/completions/completion_types.py diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/completion_utils.py b/packages/jupyter-ai/jupyter_ai/completions/completion_utils.py similarity index 97% rename from packages/jupyter-ai-magics/jupyter_ai_magics/completion_utils.py rename to packages/jupyter-ai/jupyter_ai/completions/completion_utils.py index 1330e8992..150ac40c0 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/completion_utils.py +++ b/packages/jupyter-ai/jupyter_ai/completions/completion_utils.py @@ -1,4 +1,4 @@ -from .models.completion import InlineCompletionRequest +from .completion_types import InlineCompletionRequest def token_from_request(request: InlineCompletionRequest, suggestion: int): diff --git a/packages/jupyter-ai/jupyter_ai/completions/handlers/base.py b/packages/jupyter-ai/jupyter_ai/completions/handlers/base.py index cf07550bc..fd83f43d8 100644 --- a/packages/jupyter-ai/jupyter_ai/completions/handlers/base.py +++ b/packages/jupyter-ai/jupyter_ai/completions/handlers/base.py @@ -5,21 +5,19 @@ from typing import Union import tornado -from jupyter_ai.completions.handlers.model_mixin import CompletionsModelMixin -from jupyter_ai.completions.models import ( +from jupyter_server.base.handlers import JupyterHandler +from pydantic import ValidationError + +from ..completion_types import ( CompletionError, InlineCompletionList, InlineCompletionReply, InlineCompletionRequest, InlineCompletionStreamChunk, ) -from jupyter_server.base.handlers import JupyterHandler -from pydantic import ValidationError -class BaseInlineCompletionHandler( - CompletionsModelMixin, JupyterHandler, tornado.websocket.WebSocketHandler -): +class BaseInlineCompletionHandler(JupyterHandler, tornado.websocket.WebSocketHandler): """A Tornado WebSocket handler that receives inline completion requests and fulfills them accordingly. 
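
    Subclasses fulfill requests by implementing `handle_request()` and
    `handle_stream_request()`, sending replies via `self.reply()`. A minimal
    hypothetical subclass (the `EchoCompletionHandler` name and the use of
    `request.prefix` are illustrative, not part of the codebase):

        class EchoCompletionHandler(BaseInlineCompletionHandler):
            async def handle_request(self, request: InlineCompletionRequest):
                # Echo the request's prefix back as the only suggestion.
                self.reply(
                    InlineCompletionReply(
                        list=InlineCompletionList(
                            items=[{"insertText": request.prefix}]
                        ),
                        reply_to=request.number,
                    )
                )
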
This class is instantiated once per WebSocket connection."""
diff --git a/packages/jupyter-ai/jupyter_ai/completions/handlers/default.py b/packages/jupyter-ai/jupyter_ai/completions/handlers/default.py
index 38676b998..93c496a8e 100644
--- a/packages/jupyter-ai/jupyter_ai/completions/handlers/default.py
+++ b/packages/jupyter-ai/jupyter_ai/completions/handlers/default.py
@@ -1,4 +1,4 @@
-from ..models import InlineCompletionRequest
+from ..completion_types import InlineCompletionRequest
 from .base import BaseInlineCompletionHandler
 
 
@@ -8,17 +8,79 @@ def __init__(self, *args, **kwargs):
 
     async def handle_request(self, request: InlineCompletionRequest):
         """Handles an inline completion request without streaming."""
-        llm = self.get_llm()
-        if not llm:
-            raise ValueError("Please select a model for inline completion.")
+        # TODO: migrate this to use LiteLLM
+        # llm = self.get_llm()
+        # if not llm:
+        #     raise ValueError("Please select a model for inline completion.")
 
-        reply = await llm.generate_inline_completions(request)
-        self.reply(reply)
+        # reply = await llm.generate_inline_completions(request)
+        # self.reply(reply)
 
     async def handle_stream_request(self, request: InlineCompletionRequest):
-        llm = self.get_llm()
-        if not llm:
-            raise ValueError("Please select a model for inline completion.")
+        # TODO: migrate this to use LiteLLM
+        # llm = self.get_llm()
+        # if not llm:
+        #     raise ValueError("Please select a model for inline completion.")
 
-        async for reply in llm.stream_inline_completions(request):
-            self.reply(reply)
+        # async for reply in llm.stream_inline_completions(request):
+        #     self.reply(reply)
+        pass
+
+
+# old methods on BaseProvider, for reference when migrating this to LiteLLM
+#
+# async def generate_inline_completions(
+#     self, request: InlineCompletionRequest
+# ) -> InlineCompletionReply:
+#     chain = self._create_completion_chain()
+#     model_arguments = completion.template_inputs_from_request(request)
+#     suggestion = await chain.ainvoke(input=model_arguments)
+#     suggestion = completion.post_process_suggestion(suggestion, request)
+#     return InlineCompletionReply(
+#         list=InlineCompletionList(items=[{"insertText": suggestion}]),
+#         reply_to=request.number,
+#     )
+
+# async def stream_inline_completions(
+#     self, request: InlineCompletionRequest
+# ) -> AsyncIterator[InlineCompletionStreamChunk]:
+#     chain = self._create_completion_chain()
+#     token = completion.token_from_request(request, 0)
+#     model_arguments = completion.template_inputs_from_request(request)
+#     suggestion = processed_suggestion = ""
+
+#     # send an incomplete `InlineCompletionReply`, indicating to the
+#     # client that LLM output is about to be streamed across this connection. 
+# yield InlineCompletionReply( +# list=InlineCompletionList( +# items=[ +# { +# # insert text starts empty as we do not pre-generate any part +# "insertText": "", +# "isIncomplete": True, +# "token": token, +# } +# ] +# ), +# reply_to=request.number, +# ) + +# async for fragment in chain.astream(input=model_arguments): +# suggestion += fragment +# processed_suggestion = completion.post_process_suggestion( +# suggestion, request +# ) +# yield InlineCompletionStreamChunk( +# type="stream", +# response={"insertText": processed_suggestion, "token": token}, +# reply_to=request.number, +# done=False, +# ) + +# # finally, send a message confirming that we are done +# yield InlineCompletionStreamChunk( +# type="stream", +# response={"insertText": processed_suggestion, "token": token}, +# reply_to=request.number, +# done=True, +# ) diff --git a/packages/jupyter-ai/jupyter_ai/completions/handlers/model_mixin.py b/packages/jupyter-ai/jupyter_ai/completions/handlers/model_mixin.py deleted file mode 100644 index 6534a037a..000000000 --- a/packages/jupyter-ai/jupyter_ai/completions/handlers/model_mixin.py +++ /dev/null @@ -1,78 +0,0 @@ -from logging import Logger -from typing import Any, Optional - -from jupyter_ai.config_manager import ConfigManager -from jupyter_ai_magics.providers import BaseProvider - - -class CompletionsModelMixin: - """Mixin class containing methods and attributes used by completions LLM handler.""" - - handler_kind: str - settings: dict - log: Logger - - @property - def jai_config_manager(self) -> ConfigManager: - return self.settings["jai_config_manager"] - - @property - def model_parameters(self) -> dict[str, dict[str, Any]]: - return self.settings["model_parameters"] - - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - self._llm: Optional[BaseProvider] = None - self._llm_params = None - - def get_llm(self) -> Optional[BaseProvider]: - lm_provider = self.jai_config_manager.completions_lm_provider - lm_provider_params = self.jai_config_manager.completions_lm_provider_params - - if not lm_provider or not lm_provider_params: - return None - - curr_lm_id = ( - f'{self._llm.id}:{lm_provider_params["model_id"]}' if self._llm else None - ) - next_lm_id = ( - f'{lm_provider.id}:{lm_provider_params["model_id"]}' - if lm_provider - else None - ) - - should_recreate_llm = False - if curr_lm_id != next_lm_id: - self.log.info( - f"Switching {self.handler_kind} language model from {curr_lm_id} to {next_lm_id}." - ) - should_recreate_llm = True - elif self._llm_params != lm_provider_params: - self.log.info( - f"{self.handler_kind} model params changed, updating the llm chain." 
- ) - should_recreate_llm = True - - if should_recreate_llm: - self._llm = self.create_llm(lm_provider, lm_provider_params) - self._llm_params = lm_provider_params - - return self._llm - - def get_model_parameters( - self, provider: type[BaseProvider], provider_params: dict[str, str] - ): - return self.model_parameters.get( - f"{provider.id}:{provider_params['model_id']}", {} - ) - - def create_llm( - self, provider: type[BaseProvider], provider_params: dict[str, str] - ) -> BaseProvider: - unified_parameters = { - **provider_params, - **(self.get_model_parameters(provider, provider_params)), - } - llm = provider(**unified_parameters) - - return llm diff --git a/packages/jupyter-ai/jupyter_ai/completions/models.py b/packages/jupyter-ai/jupyter_ai/completions/models.py deleted file mode 100644 index e9679379e..000000000 --- a/packages/jupyter-ai/jupyter_ai/completions/models.py +++ /dev/null @@ -1,17 +0,0 @@ -from jupyter_ai_magics.models.completion import ( - CompletionError, - InlineCompletionItem, - InlineCompletionList, - InlineCompletionReply, - InlineCompletionRequest, - InlineCompletionStreamChunk, -) - -__all__ = [ - "InlineCompletionRequest", - "InlineCompletionItem", - "CompletionError", - "InlineCompletionList", - "InlineCompletionReply", - "InlineCompletionStreamChunk", -] diff --git a/packages/jupyter-ai/jupyter_ai/config_manager.py b/packages/jupyter-ai/jupyter_ai/config_manager.py index db19ea29d..05bb1efe0 100644 --- a/packages/jupyter-ai/jupyter_ai/config_manager.py +++ b/packages/jupyter-ai/jupyter_ai/config_manager.py @@ -3,16 +3,9 @@ import os import time from copy import deepcopy -from typing import Optional, Union +from typing import Any, Optional, Union from deepmerge import always_merger -from jupyter_ai_magics.utils import ( - AnyProvider, - EmProvidersDict, - LmProvidersDict, - get_em_provider, - get_lm_provider, -) from jupyter_core.paths import jupyter_data_dir from traitlets import Integer, Unicode from traitlets.config import Configurable @@ -48,17 +41,6 @@ class BlockedModelError(Exception): pass -def _validate_provider_authn(config: JaiConfig, provider: type[AnyProvider]): - # TODO: handle non-env auth strategies - if not provider.auth_strategy or provider.auth_strategy.type != "env": - return - - if provider.auth_strategy.name not in config.api_keys: - raise AuthError( - f"Missing API key for '{provider.auth_strategy.name}' in the config." 
- ) - - def remove_none_entries(d: dict): """ Returns a deep copy of the given dictionary that excludes all top-level @@ -110,8 +92,6 @@ class ConfigManager(Configurable): def __init__( self, log: Logger, - lm_providers: LmProvidersDict, - em_providers: EmProvidersDict, defaults: dict, allowed_providers: Optional[list[str]] = None, blocked_providers: Optional[list[str]] = None, @@ -123,16 +103,14 @@ def __init__( super().__init__(*args, **kwargs) self.log = log - self._lm_providers = lm_providers - """List of LM providers.""" - self._em_providers = em_providers - """List of EM providers.""" - self._allowed_providers = allowed_providers self._blocked_providers = blocked_providers self._allowed_models = allowed_models self._blocked_models = blocked_models + self._lm_providers: dict[str, Any] = ( + {} + ) # Placeholder: should be set to actual language model providers self._defaults = remove_none_entries(defaults) self._last_read: Optional[int] = None @@ -175,69 +153,10 @@ def _process_existing_config(self): with open(self.config_path, encoding="utf-8") as f: existing_config = json.loads(f.read()) config = JaiConfig(**existing_config) - validated_config = self._validate_model_ids(config) # re-write to the file to validate the config and apply any # updates to the config file immediately - self._write_config(validated_config) - - def _validate_model_ids(self, config): - lm_provider_keys = ["model_provider_id", "completions_model_provider_id"] - em_provider_keys = ["embeddings_provider_id"] - clm_provider_keys = ["completions_model_provider_id"] - - # if the currently selected language or embedding model are - # forbidden, set them to `None` and log a warning. - for lm_key in lm_provider_keys: - lm_id = getattr(config, lm_key) - if lm_id is not None and not self._validate_model(lm_id, raise_exc=False): - self.log.warning( - f"Language model {lm_id} is forbidden by current allow/blocklists. Setting to None." - ) - setattr(config, lm_key, None) - for em_key in em_provider_keys: - em_id = getattr(config, em_key) - if em_id is not None and not self._validate_model(em_id, raise_exc=False): - self.log.warning( - f"Embedding model {em_id} is forbidden by current allow/blocklists. Setting to None." - ) - setattr(config, em_key, None) - for clm_key in clm_provider_keys: - clm_id = getattr(config, clm_key) - if clm_id is not None and not self._validate_model(clm_id, raise_exc=False): - self.log.warning( - f"Completion model {clm_id} is forbidden by current allow/blocklists. Setting to None." - ) - setattr(config, clm_key, None) - - # if the currently selected language or embedding model ids are - # not associated with models, set them to `None` and log a warning. - for lm_key in lm_provider_keys: - lm_id = getattr(config, lm_key) - if lm_id is not None and not get_lm_provider(lm_id, self._lm_providers)[1]: - self.log.warning( - f"No language model is associated with '{lm_id}'. Setting to None." - ) - setattr(config, lm_key, None) - for em_key in em_provider_keys: - em_id = getattr(config, em_key) - if em_id is not None and not get_em_provider(em_id, self._em_providers)[1]: - self.log.warning( - f"No embedding model is associated with '{em_id}'. Setting to None." - ) - setattr(config, em_key, None) - for clm_key in clm_provider_keys: - clm_id = getattr(config, clm_key) - if ( - clm_id is not None - and not get_lm_provider(clm_id, self._lm_providers)[1] - ): - self.log.warning( - f"No completion model is associated with '{clm_id}'. Setting to None." 
- ) - setattr(config, clm_key, None) - - return config + self._write_config(config) def _read_config(self) -> JaiConfig: """ @@ -268,78 +187,80 @@ def _validate_config(self, config: JaiConfig): user has specified authentication for all configured models that require it. """ + # TODO: re-implement this w/ liteLLM # validate language model config - if config.model_provider_id: - _, lm_provider = get_lm_provider( - config.model_provider_id, self._lm_providers - ) + # if config.model_provider_id: + # _, lm_provider = get_lm_provider( + # config.model_provider_id, self._lm_providers + # ) - # verify model is declared by some provider - if not lm_provider: - raise ValueError( - f"No language model is associated with '{config.model_provider_id}'." - ) + # # verify model is declared by some provider + # if not lm_provider: + # raise ValueError( + # f"No language model is associated with '{config.model_provider_id}'." + # ) - # verify model is not blocked - self._validate_model(config.model_provider_id) + # # verify model is not blocked + # self._validate_model(config.model_provider_id) - # verify model is authenticated - _validate_provider_authn(config, lm_provider) + # # verify model is authenticated + # _validate_provider_authn(config, lm_provider) - # verify fields exist for this model if needed - if lm_provider.fields and config.model_provider_id not in config.fields: - config.fields[config.model_provider_id] = {} + # # verify fields exist for this model if needed + # if lm_provider.fields and config.model_provider_id not in config.fields: + # config.fields[config.model_provider_id] = {} # validate completions model config - if config.completions_model_provider_id: - _, completions_provider = get_lm_provider( - config.completions_model_provider_id, self._lm_providers - ) - - # verify model is declared by some provider - if not completions_provider: - raise ValueError( - f"No language model is associated with '{config.completions_model_provider_id}'." - ) - - # verify model is not blocked - self._validate_model(config.completions_model_provider_id) - - # verify model is authenticated - _validate_provider_authn(config, completions_provider) - - # verify completions fields exist for this model if needed - if ( - completions_provider.fields - and config.completions_model_provider_id - not in config.completions_fields - ): - config.completions_fields[config.completions_model_provider_id] = {} - - # validate embedding model config - if config.embeddings_provider_id: - _, em_provider = get_em_provider( - config.embeddings_provider_id, self._em_providers - ) - - # verify model is declared by some provider - if not em_provider: - raise ValueError( - f"No embedding model is associated with '{config.embeddings_provider_id}'." - ) - - # verify model is not blocked - self._validate_model(config.embeddings_provider_id) - - # verify model is authenticated - _validate_provider_authn(config, em_provider) - - # verify embedding fields exist for this model if needed - if ( - em_provider.fields - and config.embeddings_provider_id not in config.embeddings_fields - ): - config.embeddings_fields[config.embeddings_provider_id] = {} + # if config.completions_model_provider_id: + # _, completions_provider = get_lm_provider( + # config.completions_model_provider_id, self._lm_providers + # ) + + # # verify model is declared by some provider + # if not completions_provider: + # raise ValueError( + # f"No language model is associated with '{config.completions_model_provider_id}'." 
+ # ) + + # # verify model is not blocked + # self._validate_model(config.completions_model_provider_id) + + # # verify model is authenticated + # _validate_provider_authn(config, completions_provider) + + # # verify completions fields exist for this model if needed + # if ( + # completions_provider.fields + # and config.completions_model_provider_id + # not in config.completions_fields + # ): + # config.completions_fields[config.completions_model_provider_id] = {} + + # # validate embedding model config + # if config.embeddings_provider_id: + # _, em_provider = get_em_provider( + # config.embeddings_provider_id, self._em_providers + # ) + + # # verify model is declared by some provider + # if not em_provider: + # raise ValueError( + # f"No embedding model is associated with '{config.embeddings_provider_id}'." + # ) + + # # verify model is not blocked + # self._validate_model(config.embeddings_provider_id) + + # # verify model is authenticated + # _validate_provider_authn(config, em_provider) + + # # verify embedding fields exist for this model if needed + # if ( + # em_provider.fields + # and config.embeddings_provider_id not in config.embeddings_fields + # ): + # config.embeddings_fields[config.embeddings_provider_id] = {} + return def _validate_model(self, model_id: str, raise_exc=True): """ @@ -350,7 +271,7 @@ def _validate_model(self, model_id: str, raise_exc=True): """ assert model_id is not None - components = model_id.split(":", 1) + components = model_id.split("/", 1) assert len(components) == 2 provider_id, _ = components @@ -399,29 +320,6 @@ def _write_config(self, new_config: JaiConfig): with open(self.config_path, "w") as f: json.dump(new_config.model_dump(), f, indent=self.indentation_depth) - def delete_api_key(self, key_name: str): - config_dict = self._read_config().model_dump() - required_keys = [] - for provider in [ - self.lm_provider, - self.em_provider, - self.completions_lm_provider, - ]: - if ( - provider - and provider.auth_strategy - and provider.auth_strategy.type == "env" - ): - required_keys.append(provider.auth_strategy.name) - - if key_name in required_keys: - raise KeyInUseError( - "This API key is currently in use by the language or embedding model. Please change the model before deleting the corresponding API key." - ) - - config_dict["api_keys"].pop(key_name, None) - self._write_config(JaiConfig(**config_dict)) - def update_config(self, config_update: UpdateConfigRequest): # type:ignore last_write = os.stat(self.config_path).st_mtime_ns if config_update.last_read and config_update.last_read < last_write: @@ -449,94 +347,56 @@ def get_config(self): ) @property - def lm_gid(self): + def chat_model(self) -> str | None: + """ + Returns the model ID of the chat model from AI settings, if any. 
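+
+        A hypothetical usage sketch (the model ID shown is illustrative; after
+        the LiteLLM migration, global model IDs follow the `provider/model`
+        syntax assumed by `_validate_model()` in this class):
+
+            model_id = config_manager.chat_model  # e.g. "openai/gpt-4o"
+            if model_id:
+                provider_id, model_name = model_id.split("/", 1)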
+        """
         config = self._read_config()
         return config.model_provider_id
 
     @property
-    def em_gid(self):
-        config = self._read_config()
-        return config.embeddings_provider_id
-
-    @property
-    def lm_provider(self):
-        return self._get_provider("model_provider_id", self._lm_providers)
-
-    @property
-    def em_provider(self):
-        return self._get_provider("embeddings_provider_id", self._em_providers)
-
-    @property
-    def completions_lm_provider(self):
-        return self._get_provider("completions_model_provider_id", self._lm_providers)
+    def chat_model_params(self) -> dict[str, Any]:
+        return self._provider_params("model_provider_id", self._lm_providers)
 
-    def _get_provider(self, key, listing):
+    def _provider_params(
+        self, provider_id_attr: str, providers: dict
+    ) -> dict[str, Any]:
+        """
+        Returns the parameters for the provider specified by the given attribute.
+        """
         config = self._read_config()
-        gid = getattr(config, key)
-        if gid is None:
-            return None
-
-        _, Provider = get_lm_provider(gid, listing)
-        return Provider
-
-    @property
-    def lm_provider_params(self):
+        provider_id = getattr(config, provider_id_attr, None)
+        if not provider_id or provider_id not in providers:
+            return {}
+        return providers[provider_id].get("params", {})
-        return self._provider_params("model_provider_id", self._lm_providers)
 
     @property
-    def em_provider_params(self):
-        return self._provider_params("embeddings_provider_id", self._em_providers)
+    def embedding_model(self) -> str | None:
+        """
+        Returns the model ID of the embedding model from AI settings, if any.
+        """
+        config = self._read_config()
+        return config.embeddings_provider_id
 
     @property
-    def completions_lm_provider_params(self):
-        return self._provider_params(
-            "completions_model_provider_id", self._lm_providers, completions=True
-        )
+    def embedding_model_params(self) -> dict[str, Any]:
+        # TODO
+        return {}
 
-    def _provider_params(self, key, listing, completions: bool = False):
-        # read config
+    @property
+    def completion_model(self) -> str | None:
+        """
+        Returns the model ID of the completion model from AI settings, if any.
+        """
         config = self._read_config()
+        return config.completions_model_provider_id
 
-        # get model ID (without provider ID component) from model universal ID
-        # (with provider component).
-        model_uid = getattr(config, key)
-        if not model_uid:
-            return None
-        model_id = model_uid.split(":", 1)[1]
-
-        # get config fields (e.g. base API URL, etc.)
-        if completions:
-            fields = config.completions_fields.get(model_uid, {})
-        elif key == "embeddings_provider_id":
-            fields = config.embeddings_fields.get(model_uid, {})
-        else:
-            fields = config.fields.get(model_uid, {})
-
-        # exclude empty fields
-        # TODO: modify the config manager to never save empty fields in the
-        # first place. 
-        fields = {
-            k: None if isinstance(v, str) and not len(v) else v
-            for k, v in fields.items()
-        }
-
-        # get authn fields
-        _, Provider = (
-            get_em_provider(model_uid, listing)
-            if key == "embeddings_provider_id"
-            else get_lm_provider(model_uid, listing)
-        )
-        authn_fields = {}
-        if Provider.auth_strategy and Provider.auth_strategy.type == "env":
-            keyword_param = (
-                Provider.auth_strategy.keyword_param
-                or Provider.auth_strategy.name.lower()
-            )
-            key_name = Provider.auth_strategy.name
-            authn_fields[keyword_param] = config.api_keys[key_name]
 
+    @property
+    def completion_model_params(self):
+        # TODO
+        return {}
 
-        return {
-            "model_id": model_id,
-            **fields,
-            **authn_fields,
-        }
+    def delete_api_key(self, key_name: str):
+        # TODO: store in .env files
+        pass
diff --git a/packages/jupyter-ai/jupyter_ai/extension.py b/packages/jupyter-ai/jupyter_ai/extension.py
index 07446746a..29afdd192 100644
--- a/packages/jupyter-ai/jupyter_ai/extension.py
+++ b/packages/jupyter-ai/jupyter_ai/extension.py
@@ -1,13 +1,10 @@
 import os
 import time
-import types
 from asyncio import get_event_loop_policy
 from functools import partial
 from typing import TYPE_CHECKING, Optional
 
 import traitlets
-from jupyter_ai_magics import BaseProvider
-from jupyter_ai_magics.utils import get_em_providers, get_lm_providers
 from jupyter_events import EventLogger
 from jupyter_server.extension.application import ExtensionApp
 from jupyter_server.serverapp import ServerApp
@@ -24,13 +21,12 @@
 from .completions.handlers import DefaultInlineCompletionHandler
 from .config_manager import ConfigManager
 from .handlers import (
-    ApiKeysHandler,
-    EmbeddingsModelProviderHandler,
     GlobalConfigHandler,
     InterruptStreamingHandler,
-    ModelProviderHandler,
 )
 from .personas import PersonaManager
+from .secrets.secrets_manager import EnvSecretsManager
+from .secrets.secrets_rest_api import SecretsRestAPI
 
 if TYPE_CHECKING:
     from asyncio import AbstractEventLoop
@@ -55,16 +51,17 @@
     JUPYTER_COLLABORATION_EVENTS_URI,
 )
 
+from .model_providers.model_handlers import ChatModelEndpoint
+
 
 class AiExtension(ExtensionApp):
     name = "jupyter_ai"
     handlers = [  # type:ignore[assignment]
-        (r"api/ai/api_keys/(?P<api_key_name>\w+)/?", ApiKeysHandler),
         (r"api/ai/config/?", GlobalConfigHandler),
         (r"api/ai/chats/stop_streaming/?", InterruptStreamingHandler),
-        (r"api/ai/providers/?", ModelProviderHandler),
-        (r"api/ai/providers/embeddings/?", EmbeddingsModelProviderHandler),
         (r"api/ai/completion/inline/?", DefaultInlineCompletionHandler),
+        (r"api/ai/models/chat/?", ChatModelEndpoint),
+        (r"api/ai/secrets/?", SecretsRestAPI),
         (
             r"api/ai/static/jupyternaut.svg()/?",
             StaticFileHandler,
@@ -143,7 +140,17 @@ class AiExtension(ExtensionApp):
         config=True,
     )
 
-    default_language_model = Unicode(
+    initial_chat_model = Unicode(
+        default_value=None,
+        allow_none=True,
+        help="""
+        Default language model to use, as a string in the format
+        <provider-id>:<model-id>, defaults to None. 
+ """, + config=True, + ) + + initial_language_model = Unicode( default_value=None, allow_none=True, help=""" @@ -297,22 +304,14 @@ def on_change( def initialize_settings(self): start = time.time() - # Read from allowlist and blocklist - restrictions = { - "allowed_providers": self.allowed_providers, - "blocked_providers": self.blocked_providers, - } - self.settings["allowed_models"] = self.allowed_models - self.settings["blocked_models"] = self.blocked_models + # Log traitlets configuration self.log.info(f"Configured provider allowlist: {self.allowed_providers}") self.log.info(f"Configured provider blocklist: {self.blocked_providers}") self.log.info(f"Configured model allowlist: {self.allowed_models}") self.log.info(f"Configured model blocklist: {self.blocked_models}") - self.settings["model_parameters"] = self.model_parameters self.log.info(f"Configured model parameters: {self.model_parameters}") - defaults = { - "model_provider_id": self.default_language_model, + "model_provider_id": self.initial_language_model, "embeddings_provider_id": self.default_embeddings_model, "completions_model_provider_id": self.default_completions_model, "api_keys": self.default_api_keys, @@ -321,20 +320,10 @@ def initialize_settings(self): "completions_fields": self.model_parameters, } - # Fetch LM & EM providers - self.settings["lm_providers"] = get_lm_providers( - log=self.log, restrictions=restrictions - ) - self.settings["em_providers"] = get_em_providers( - log=self.log, restrictions=restrictions - ) - + # Initialize ConfigManager self.settings["jai_config_manager"] = ConfigManager( - # traitlets configuration, not JAI configuration. config=self.config, log=self.log, - lm_providers=self.settings["lm_providers"], - em_providers=self.settings["em_providers"], allowed_providers=self.allowed_providers, blocked_providers=self.blocked_providers, allowed_models=self.allowed_models, @@ -342,23 +331,21 @@ def initialize_settings(self): defaults=defaults, ) - # Expose a subset of settings as read-only to the providers - BaseProvider.server_settings = types.MappingProxyType( - self.serverapp.web_app.settings - ) - - self.log.info("Registered providers.") - - self.log.info(f"Registered {self.name} server extension") + # Initialize SecretsManager + self.settings["jai_secrets_manager"] = EnvSecretsManager(parent=self) + # Bind event loop to settings dictionary self.settings["jai_event_loop"] = self.event_loop - # Create empty dictionary for events communicating that - # message generation/streaming got interrupted. + # Bind dictionary of interrupts to settings dictionary. + # Each key is a message ID, each value is an asyncio.Event. + # When a message's interrupt event is set, the response is halted. self.settings["jai_message_interrupted"] = {} - latency_ms = round((time.time() - start) * 1000) - self.log.info(f"Initialized Jupyter AI server extension in {latency_ms} ms.") + # Log server extension startup time + self.log.info(f"Registered {self.name} server extension") + startup_time = round((time.time() - start) * 1000) + self.log.info(f"Initialized Jupyter AI server extension in {startup_time} ms.") async def stop_extension(self): """ @@ -378,7 +365,10 @@ async def _stop_extension(self): Private method that defines the cleanup code to run when the server is stopping. 
""" - # TODO: explore if cleanup is necessary + secrets_manager = self.settings.get("jai_secrets_manager", None) + + if secrets_manager: + secrets_manager.stop() def _init_persona_manager( self, room_id: str, ychat: YChat @@ -447,7 +437,6 @@ def _link_jupyter_server_extension(self, server_app: ServerApp): ".git", # Git version control directory ".venv", # Python virtual environment directory "venv", # Python virtual environment directory - ".env", # Environment variable files "node_modules", # Node.js dependencies directory ".pytest_cache", # PyTest cache directory ".mypy_cache", # MyPy type checker cache directory diff --git a/packages/jupyter-ai/jupyter_ai/handlers.py b/packages/jupyter-ai/jupyter_ai/handlers.py index fefba68c5..db81d6b65 100644 --- a/packages/jupyter-ai/jupyter_ai/handlers.py +++ b/packages/jupyter-ai/jupyter_ai/handlers.py @@ -1,5 +1,3 @@ -from typing import TYPE_CHECKING, Optional, cast - from jupyter_ai.config_manager import ConfigManager, KeyEmptyError, WriteConflictError from jupyter_server.base.handlers import APIHandler as BaseAPIHandler from pydantic import ValidationError @@ -7,125 +5,6 @@ from tornado.web import HTTPError from .config import UpdateConfigRequest -from .models import ( - ListProvidersEntry, - ListProvidersResponse, -) - -if TYPE_CHECKING: - from jupyter_ai_magics.embedding_providers import BaseEmbeddingsProvider - from jupyter_ai_magics.providers import BaseProvider - - -class ProviderHandler(BaseAPIHandler): - """ - Helper base class used for HTTP handlers hosting endpoints relating to - providers. Wrapper around BaseAPIHandler. - """ - - @property - def lm_providers(self) -> dict[str, "BaseProvider"]: - return self.settings["lm_providers"] - - @property - def em_providers(self) -> dict[str, "BaseEmbeddingsProvider"]: - return self.settings["em_providers"] - - @property - def allowed_models(self) -> Optional[list[str]]: - return self.settings["allowed_models"] - - @property - def blocked_models(self) -> Optional[list[str]]: - return self.settings["blocked_models"] - - def _filter_blocked_models(self, providers: list[ListProvidersEntry]): - """ - Satisfy the model-level allow/blocklist by filtering models accordingly. - The provider-level allow/blocklist is already handled in - `AiExtension.initialize_settings()`. 
- """ - if self.blocked_models is None and self.allowed_models is None: - return providers - - def filter_predicate(local_model_id: str): - model_id = provider.id + ":" + local_model_id - if self.blocked_models: - return model_id not in self.blocked_models - else: - return model_id in cast(list, self.allowed_models) - - # filter out every model w/ model ID according to allow/blocklist - for provider in providers: - provider.models = list(filter(filter_predicate, provider.models or [])) - provider.chat_models = list( - filter(filter_predicate, provider.chat_models or []) - ) - provider.completion_models = list( - filter(filter_predicate, provider.completion_models or []) - ) - - # filter out every provider with no models which satisfy the allow/blocklist, then return - return filter((lambda p: len(p.models) > 0), providers) - - -class ModelProviderHandler(ProviderHandler): - @web.authenticated - def get(self): - providers = [] - - # Step 1: gather providers - for provider in self.lm_providers.values(): - optionals = {} - if provider.model_id_label: - optionals["model_id_label"] = provider.model_id_label - - providers.append( - ListProvidersEntry( - id=provider.id, - name=provider.name, - models=provider.models, - chat_models=provider.chat_models(), - completion_models=provider.completion_models(), - help=provider.help, - auth_strategy=provider.auth_strategy, - registry=provider.registry, - fields=provider.fields, - **optionals, - ) - ) - - # Step 2: sort & filter providers - providers = self._filter_blocked_models(providers) - providers = sorted(providers, key=lambda p: p.name) - - # Finally, yield response. - response = ListProvidersResponse(providers=providers) - self.finish(response.model_dump_json()) - - -class EmbeddingsModelProviderHandler(ProviderHandler): - @web.authenticated - def get(self): - providers = [] - for provider in self.em_providers.values(): - providers.append( - ListProvidersEntry( - id=provider.id, - name=provider.name, - models=provider.models, - help=provider.help, - auth_strategy=provider.auth_strategy, - registry=provider.registry, - fields=provider.fields, - ) - ) - - providers = self._filter_blocked_models(providers) - providers = sorted(providers, key=lambda p: p.name) - - response = ListProvidersResponse(providers=providers) - self.finish(response.model_dump_json()) class GlobalConfigHandler(BaseAPIHandler): @@ -165,19 +44,6 @@ def post(self): ) from e -class ApiKeysHandler(BaseAPIHandler): - @property - def config_manager(self) -> ConfigManager: # type:ignore[override] - return self.settings["jai_config_manager"] - - @web.authenticated - def delete(self, api_key_name: str): - try: - self.config_manager.delete_api_key(api_key_name) - except Exception as e: - raise HTTPError(500, str(e)) - - class InterruptStreamingHandler(BaseAPIHandler): """Interrupt a current message streaming""" diff --git a/packages/jupyter-ai/jupyter_ai/history.py b/packages/jupyter-ai/jupyter_ai/history.py deleted file mode 100644 index f2d6d1007..000000000 --- a/packages/jupyter-ai/jupyter_ai/history.py +++ /dev/null @@ -1,61 +0,0 @@ -from typing import Optional - -from jupyterlab_chat.models import Message as JChatMessage -from jupyterlab_chat.ychat import YChat -from langchain_core.chat_history import BaseChatMessageHistory -from langchain_core.messages import AIMessage, BaseMessage, HumanMessage - - -class YChatHistory(BaseChatMessageHistory): - """ - An implementation of `BaseChatMessageHistory` that returns the preceding `k` - exchanges (`k * 2` messages) from the given YChat 
model. - - If `k` is set to `None`, then this class returns all preceding messages. - - TODO: Consider just defining `k` as the number of messages and default to 4. - """ - - def __init__(self, ychat: YChat, k: Optional[int] = None): - self.ychat = ychat - self.k = k - - @property - def messages(self) -> list[BaseMessage]: # type:ignore[override] - """ - Returns the last `2 * k` messages preceding the latest message. If - `k` is set to `None`, return all preceding messages. - """ - # TODO: consider bounding history based on message size (e.g. total - # char/token count) instead of message count. - all_messages = self.ychat.get_messages() - - # gather last k * 2 messages and return - # we exclude the last message since that is the human message just - # submitted by a user. - start_idx = 0 if self.k is None else -2 * self.k - 1 - recent_messages: list[JChatMessage] = all_messages[start_idx:-1] - - return self._convert_to_langchain_messages(recent_messages) - - def _convert_to_langchain_messages(self, jchat_messages: list[JChatMessage]): - """ - Accepts a list of Jupyter Chat messages, and returns them as a list of - LangChain messages. - """ - messages: list[BaseMessage] = [] - for jchat_message in jchat_messages: - if jchat_message.sender.startswith("jupyter-ai-personas::"): - messages.append(AIMessage(content=jchat_message.body)) - else: - messages.append(HumanMessage(content=jchat_message.body)) - - return messages - - def add_message(self, message: BaseMessage) -> None: - # do nothing when other LangChain objects call this method, since - # message history is maintained by the `YChat` shared document. - return - - def clear(self): - raise NotImplementedError() diff --git a/packages/jupyter-ai/jupyter_ai/model_providers/model_handlers.py b/packages/jupyter-ai/jupyter_ai/model_providers/model_handlers.py new file mode 100644 index 000000000..da3f8f382 --- /dev/null +++ b/packages/jupyter-ai/jupyter_ai/model_providers/model_handlers.py @@ -0,0 +1,28 @@ +from jupyter_server.base.handlers import APIHandler as BaseAPIHandler +from pydantic import BaseModel +from tornado import web + +from .model_list import CHAT_MODELS + + +class ChatModelEndpoint(BaseAPIHandler): + """ + A handler class that defines the `/api/ai/models/chat` endpoint. + + - `GET /api/ai/models/chat`: returns list of all chat models. 
+
+    - `GET /api/ai/models/chat?id=<model-id>`: returns info on that model (TODO)
+    """
+
+    @web.authenticated
+    def get(self):
+        response = ListChatModelsResponse(chat_models=CHAT_MODELS)
+        self.finish(response.model_dump_json())
+
+
+class ListChatModelsResponse(BaseModel):
+    chat_models: list[str]
+
+
+class ListEmbeddingModelsResponse(BaseModel):
+    embedding_models: list[str]
diff --git a/packages/jupyter-ai/jupyter_ai/model_providers/model_list.py b/packages/jupyter-ai/jupyter_ai/model_providers/model_list.py
new file mode 100644
index 000000000..acd899bcb
--- /dev/null
+++ b/packages/jupyter-ai/jupyter_ai/model_providers/model_list.py
@@ -0,0 +1,36 @@
+from litellm import all_embedding_models, models_by_provider
+
+chat_model_ids = []
+embedding_model_ids = []
+embedding_model_set = set(all_embedding_models)
+
+for provider_name in models_by_provider:
+    for model_name in models_by_provider[provider_name]:
+        model_name: str = model_name
+
+        if model_name.startswith(f"{provider_name}/"):
+            model_id = model_name
+        else:
+            model_id = f"{provider_name}/{model_name}"
+
+        is_embedding = (
+            model_name in embedding_model_set
+            or model_id in embedding_model_set
+            or "embed" in model_id
+        )
+
+        if is_embedding:
+            embedding_model_ids.append(model_id)
+        else:
+            chat_model_ids.append(model_id)
+
+
+CHAT_MODELS = sorted(chat_model_ids)
+"""
+List of chat model IDs, following the `litellm` syntax.
+"""
+
+EMBEDDING_MODELS = sorted(embedding_model_ids)
+"""
+List of embedding model IDs, following the `litellm` syntax.
+"""
diff --git a/packages/jupyter-ai/jupyter_ai/models.py b/packages/jupyter-ai/jupyter_ai/models.py
index 1c54e0e0f..ff9dae735 100644
--- a/packages/jupyter-ai/jupyter_ai/models.py
+++ b/packages/jupyter-ai/jupyter_ai/models.py
@@ -1,12 +1,13 @@
 from typing import Optional
 
-from jupyter_ai_magics.providers import AuthStrategy, Field
 from pydantic import BaseModel
 
 DEFAULT_CHUNK_SIZE = 2000
 DEFAULT_CHUNK_OVERLAP = 100
 
 
+# TODO: Delete this once the new Models API can return these properties.
+# This is just being kept as a reference. 
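+#
+# A hypothetical sketch of how a client could fetch the replacement data from
+# the new Models API instead (response shape per `ListChatModelsResponse`
+# above; the base URL and token are illustrative):
+#
+#   import requests
+#   resp = requests.get(
+#       "http://localhost:8888/api/ai/models/chat",
+#       headers={"Authorization": "token <your-jupyter-token>"},
+#   )
+#   chat_models: list[str] = resp.json()["chat_models"]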
class ListProvidersEntry(BaseModel): """Model provider with supported models and provider's authentication strategy @@ -17,22 +18,8 @@ class ListProvidersEntry(BaseModel): model_id_label: Optional[str] = None models: list[str] help: Optional[str] = None - auth_strategy: AuthStrategy + # auth_strategy: AuthStrategy registry: bool - fields: list[Field] + # fields: list[Field] chat_models: Optional[list[str]] = None completion_models: Optional[list[str]] = None - - -class ListProvidersResponse(BaseModel): - providers: list[ListProvidersEntry] - - -class IndexedDir(BaseModel): - path: str - chunk_size: int = DEFAULT_CHUNK_SIZE - chunk_overlap: int = DEFAULT_CHUNK_OVERLAP - - -class IndexMetadata(BaseModel): - dirs: list[IndexedDir] diff --git a/packages/jupyter-ai/jupyter_ai/personas/base_persona.py b/packages/jupyter-ai/jupyter_ai/personas/base_persona.py index 985e40078..310901b1c 100644 --- a/packages/jupyter-ai/jupyter_ai/personas/base_persona.py +++ b/packages/jupyter-ai/jupyter_ai/personas/base_persona.py @@ -20,6 +20,8 @@ if TYPE_CHECKING: from collections.abc import AsyncIterator + from litellm import ModelResponseStream + from .persona_manager import PersonaManager @@ -233,10 +235,13 @@ def as_user_dict(self) -> dict[str, Any]: user = self.as_user() return asdict(user) - async def stream_message(self, reply_stream: "AsyncIterator") -> None: + async def stream_message( + self, reply_stream: "AsyncIterator[ModelResponseStream | str]" + ) -> None: """ Takes an async iterator, dubbed the 'reply stream', and streams it to a - new message by this persona in the YChat. + new message by this persona in the YChat. The async iterator may yield + either strings or `litellm.ModelResponseStream` objects. Details: - Creates a new message upon receiving the first chunk from the reply stream, then continuously updates it until the stream is closed. @@ -248,6 +253,15 @@ async def stream_message(self, reply_stream: "AsyncIterator") -> None: try: self.awareness.set_local_state_field("isWriting", True) async for chunk in reply_stream: + # Coerce LiteLLM stream chunk to a string delta + if not isinstance(chunk, str): + chunk = chunk.choices[0].delta.content + + # LiteLLM streams always terminate with an empty chunk, so we + # ignore and continue when this occurs. 
+ if not chunk: + continue + if ( stream_id and stream_id in self.message_interrupted.keys() diff --git a/packages/jupyter-ai/jupyter_ai/personas/jupyternaut/jupyternaut.py b/packages/jupyter-ai/jupyter_ai/personas/jupyternaut/jupyternaut.py index 495dd5c7b..633ca085e 100644 --- a/packages/jupyter-ai/jupyter_ai/personas/jupyternaut/jupyternaut.py +++ b/packages/jupyter-ai/jupyter_ai/personas/jupyternaut/jupyternaut.py @@ -1,12 +1,14 @@ -from typing import Any +from typing import Any, Optional from jupyterlab_chat.models import Message -from langchain_core.output_parsers import StrOutputParser -from langchain_core.runnables.history import RunnableWithMessageHistory +from litellm import acompletion -from ...history import YChatHistory from ..base_persona import BasePersona, PersonaDefaults -from .prompt_template import JUPYTERNAUT_PROMPT_TEMPLATE, JupyternautVariables +from ..persona_manager import SYSTEM_USERNAME +from .prompt_template import ( + JUPYTERNAUT_SYSTEM_PROMPT_TEMPLATE, + JupyternautSystemPromptArgs, +) class JupyternautPersona(BasePersona): @@ -27,40 +29,72 @@ def defaults(self): ) async def process_message(self, message: Message) -> None: - if not self.config_manager.lm_provider: + if not self.config_manager.chat_model: self.send_message( - "No language model provider configured. Please set one in the Jupyter AI settings." + "No chat model is configured.\n\n" + "You must set one first in the Jupyter AI settings, found in 'Settings > AI Settings' from the menu bar." ) return - provider_name = self.config_manager.lm_provider.name - model_id = self.config_manager.lm_provider_params["model_id"] + model_id = self.config_manager.chat_model + context_as_messages = self.get_context_as_messages(model_id, message) + response_aiter = await acompletion( + model=model_id, + messages=[ + *context_as_messages, + { + "role": "user", + "content": message.body, + }, + ], + stream=True, + ) - # Process file attachments and include their content in the context - context = self.process_attachments(message) + await self.stream_message(response_aiter) - runnable = self.build_runnable() - variables = JupyternautVariables( - input=message.body, + def get_context_as_messages( + self, model_id: str, message: Message + ) -> list[dict[str, Any]]: + """ + Returns the current context, including attachments and recent messages, + as a list of messages accepted by `litellm.acompletion()`. + """ + system_msg_args = JupyternautSystemPromptArgs( model_id=model_id, - provider_name=provider_name, persona_name=self.name, - context=context, - ) - variables_dict = variables.model_dump() - reply_stream = runnable.astream(variables_dict) - await self.stream_message(reply_stream) - - def build_runnable(self) -> Any: - # TODO: support model parameters. 
maybe we just add it to lm_provider_params in both 2.x and 3.x - llm = self.config_manager.lm_provider(**self.config_manager.lm_provider_params) - runnable = JUPYTERNAUT_PROMPT_TEMPLATE | llm | StrOutputParser() - - runnable = RunnableWithMessageHistory( - runnable=runnable, # type:ignore[arg-type] - get_session_history=lambda: YChatHistory(ychat=self.ychat, k=2), - input_messages_key="input", - history_messages_key="history", - ) + context=self.process_attachments(message), + ).model_dump() + + system_msg = { + "role": "system", + "content": JUPYTERNAUT_SYSTEM_PROMPT_TEMPLATE.render(**system_msg_args), + } + + context_as_messages = [system_msg, *self._get_history_as_messages()] + return context_as_messages + + def _get_history_as_messages(self, k: Optional[int] = 2) -> list[dict[str, Any]]: + """ + Returns the current history as a list of messages accepted by + `litellm.acompletion()`. + """ + # TODO: consider bounding history based on message size (e.g. total + # char/token count) instead of message count. + all_messages = self.ychat.get_messages() + + # gather last k * 2 messages and return + # we exclude the last message since that is the human message just + # submitted by a user. + start_idx = 0 if k is None else -2 * k - 1 + recent_messages: list[Message] = all_messages[start_idx:-1] + + history: list[dict[str, Any]] = [] + for msg in recent_messages: + role = ( + "assistant" + if msg.sender.startswith("jupyter-ai-personas::") + else "system" if msg.sender == SYSTEM_USERNAME else "user" + ) + history.append({"role": role, "content": msg.body}) - return runnable + return history diff --git a/packages/jupyter-ai/jupyter_ai/personas/jupyternaut/prompt_template.py b/packages/jupyter-ai/jupyter_ai/personas/jupyternaut/prompt_template.py index cd5a6ce9b..05cb7b956 100644 --- a/packages/jupyter-ai/jupyter_ai/personas/jupyternaut/prompt_template.py +++ b/packages/jupyter-ai/jupyter_ai/personas/jupyternaut/prompt_template.py @@ -1,11 +1,6 @@ from typing import Optional -from langchain.prompts import ( - ChatPromptTemplate, - HumanMessagePromptTemplate, - MessagesPlaceholder, - SystemMessagePromptTemplate, -) +from jinja2 import Template from pydantic import BaseModel _JUPYTERNAUT_SYSTEM_PROMPT_FORMAT = """ @@ -17,7 +12,7 @@ When installed, Jupyter AI adds a chat experience in JupyterLab that allows multiple users to collaborate with one or more agents like yourself. -You are not a language model, but rather an AI agent powered by a foundation model `{{model_id}}`, provided by '{{provider_name}}'. +You are not a language model, but rather an AI agent powered by a foundation model `{{model_id}}`. You are receiving a request from a user in JupyterLab. Your goal is to fulfill this request to the best of your ability. @@ -48,28 +43,13 @@ """.strip() -JUPYTERNAUT_PROMPT_TEMPLATE = ChatPromptTemplate.from_messages( - [ - SystemMessagePromptTemplate.from_template( - _JUPYTERNAUT_SYSTEM_PROMPT_FORMAT, template_format="jinja2" - ), - MessagesPlaceholder(variable_name="history"), - HumanMessagePromptTemplate.from_template("{input}"), - ] -) - -class JupyternautVariables(BaseModel): - """ - Variables expected by `JUPYTERNAUT_PROMPT_TEMPLATE`, defined as a Pydantic - data model for developer convenience. +JUPYTERNAUT_SYSTEM_PROMPT_TEMPLATE: Template = Template( + _JUPYTERNAUT_SYSTEM_PROMPT_FORMAT +) - Call the `.model_dump()` method on an instance to convert it to a Python - dictionary. 
- """ - input: str +class JupyternautSystemPromptArgs(BaseModel): persona_name: str - provider_name: str model_id: str context: Optional[str] = None diff --git a/packages/jupyter-ai/jupyter_ai/secrets/secrets_manager.py b/packages/jupyter-ai/jupyter_ai/secrets/secrets_manager.py new file mode 100644 index 000000000..782dd4533 --- /dev/null +++ b/packages/jupyter-ai/jupyter_ai/secrets/secrets_manager.py @@ -0,0 +1,379 @@ +from __future__ import annotations + +import asyncio +import os +from datetime import datetime +from io import StringIO +from typing import TYPE_CHECKING + +from dotenv import dotenv_values, load_dotenv +from tornado.web import HTTPError +from traitlets.config import LoggingConfigurable + +from .secrets_types import SecretsList +from .secrets_utils import build_updated_dotenv + +if TYPE_CHECKING: + import logging + from typing import Any + + from jupyter_server.services.contents.filemanager import AsyncFileContentsManager + + from ..extension import AiExtension + + +class EnvSecretsManager(LoggingConfigurable): + """ + The default secrets manager implementation. + + TODO: Create a `BaseSecretsManager` class and add an + `AiExtension.secrets_manager_class` configurable trait to allow custom + implementations. + + TODO: Add a `EnvSecretsManager.dotenv_path` configurable trait to allow + users to change the path of the `.env` file. + + TODO: Add a `EnvSecretsManager.envvar_secrets` configurable trait to allow + users to pass a list of glob expressions that define which environment + variables are listed as secrets in the UI. This should default to `*TOKEN*, + *SECRET*, *KEY*`. + """ + + parent: AiExtension + """ + The parent `AiExtension` class. + + NOTE: This attribute is automatically set by the `LoggingConfigurable` + parent class. This annotation exists only to help type checkers like `mypy`. + """ + + log: logging.Logger + """ + The logger used by by this instance. + + NOTE: This attribute is automatically set by the `LoggingConfigurable` + parent class. This annotation exists only to help type checkers like `mypy`. + """ + + _last_modified: datetime | None + """ + The 'last modified' timestamp on the '.env' file retrieved in the previous + tick of the `_watch_dotenv()` background task. + """ + + _initial_env: dict[str, str] + """ + Dictionary containing the initial environment variables passed to this + process. Set to `dict(os.environ)` exactly once on init. + + This attribute should not be set more than once, since secrets loaded from + the `.env` file are added to `os.environ` after this class initializes. + """ + + _dotenv_env: dict[str, str] + """ + Dictionary containing the environment variables defined in the `.env` file. + If no `.env` file exists, this will be an empty dictionary. This attribute + is continuously updated via the `_watch_dotenv()` background task. + """ + + _dotenv_lock: asyncio.Lock + """ + Lock which must be held while reading or writing to the `.env` file from the + `ContentsManager`. 
+ """ + + @property + def contents_manager(self) -> AsyncFileContentsManager: + return self.parent.serverapp.contents_manager + + @property + def event_loop(self) -> asyncio.AbstractEventLoop: + return self.parent.event_loop + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # Set instance attributes + self._last_modified = None + self._initial_env = dict(os.environ) + self._dotenv_env = {} + self._dotenv_lock = asyncio.Lock() + + # Start `_watch_dotenv()` task to automatically update the environment + # variables when `.env` is modified + self._watch_dotenv_task = self.event_loop.create_task(self._watch_dotenv()) + + async def _watch_dotenv(self) -> None: + """ + Watches the `.env` file and automatically responds to changes. + """ + while True: + await asyncio.sleep(2) + + # Fetch file content and its last modified timestamp + try: + async with self._dotenv_lock: + dotenv_file = await self.contents_manager.get(".env", content=True) + dotenv_content = dotenv_file.get("content") + assert isinstance(dotenv_content, str) + except HTTPError as e: + # Continue if file does not exist, otherwise re-raise + if e.status_code == 404: + self._handle_dotenv_notfound() + continue + except Exception: + self.log.exception("Unknown exception in `_watch_dotenv()`:") + continue + + # Continue if the `.env` file was already processed and its content + # is unchanged. + if self._last_modified == dotenv_file["last_modified"]: + continue + + # When this line is reached, the .env file needs to be applied. + # Log a statement accordingly, ensure the new `.env` file is listed + # in `.gitignore`, and store the latest last modified timestamp. + if self._last_modified: + # Statement when .env file was modified: + self.log.info( + "Detected changes to the '.env' file. Re-applying '.env' to the environment..." + ) + else: + # Statement when the .env file was just created, or when this is + # the first iteration and a .env file already exists: + self.log.info( + "Detected '.env' file at the workspace root. Applying '.env' to the environment..." + ) + self.event_loop.create_task(self._ensure_dotenv_gitignored()) + self._last_modified = dotenv_file["last_modified"] + + # Apply the latest `.env` file to the environment. + # See `self._apply_dotenv()` for more info. + self._apply_dotenv(dotenv_content) + + def _apply_dotenv(self, content: str) -> None: + """ + Applies a `.env` file to the environment given its content. This method: + + 1. Resets any variables removed from the `.env` file. See + `self._reset_envvars()` for more info. + + 2. Stores the parsed `.env` file as a dictionary in `self._dotenv_env`. + + 3. Sets the environment variables in `os.environ` as defined in the + `.env` file. + """ + # Parse the latest `.env` file and store it in `self._dotenv_env`, + # tracking deleted environment variables in `deleted_envvars`. + new_dotenv_env = dotenv_values(stream=StringIO(content)) + new_dotenv_env = {k: v for k, v in new_dotenv_env.items() if v != None} + deleted_envvars = [k for k in self._dotenv_env if k not in new_dotenv_env] + self._dotenv_env = new_dotenv_env + + # Apply the new `.env` file to the environment and reset all + # environment variables in `deleted_envvars`. + if deleted_envvars: + self._reset_envvars(deleted_envvars) + self.log.info( + f"Removed {len(deleted_envvars)} variables from the environment as they were removed from '.env'." 
+ ) + load_dotenv(stream=StringIO(content), override=True) + self.log.info("Applied '.env' to the environment.") + + async def _ensure_dotenv_gitignored(self) -> bool: + """ + Ensures the `.env` file is listed in the `.gitignore` file at the + workspace root, creating/updating the `.gitignore` file to list `.env` + if needed. + + This method is called by the `_watch_dotenv()` background task either on + the first iteration when the `.env` file already exists, or when the + `.env` file was created on a subsequent iteration. + """ + # Fetch `.gitignore` file. + gitignore_file: dict[str, Any] | None = None + try: + gitignore_file = await self.contents_manager.get(".gitignore", content=True) + except HTTPError as e: + # Continue if file does not exist, otherwise re-raise + if e.status_code == 404: + pass + else: + raise e + except Exception: + self.log.exception("Unknown exception raised when fetching `.gitignore`:") + + # Return early if the `.gitignore` file exists and already lists `.env`. + old_content: str = (gitignore_file or {}).get("content", "") + if ".env\n" in old_content: + return + + # Otherwise, log something and create/update the `.gitignore` file to + # list `.env`. + self.log.info("Updating `.gitignore` file to include `.env`...") + new_lines = "# Ignore secrets in '.env'\n.env\n" + new_content = old_content + "\n" + new_lines if old_content else new_lines + try: + gitignore_file = await self.contents_manager.save( + { + "type": "file", + "format": "text", + "mimetype": "text/plain", + "content": new_content, + }, + ".gitignore", + ) + except Exception: + self.log.exception("Unknown exception raised when updating `.gitignore`:") + self.log.info("Updated `.gitignore` file to include `.env`.") + + def _reset_envvars(self, names: list[str]) -> None: + """ + Resets each environment variable in the given list. Each variable is + restored to its initial value in `self._initial_env` if present, and + deleted from `os.environ` otherwise. + """ + for ev_name in names: + if ev_name in self._initial_env: + os.environ[ev_name] = self._initial_env.get(ev_name) + else: + del os.environ[ev_name] + + def _handle_dotenv_notfound(self) -> None: + """ + Method called by the `_watch_dotenv()` task when the `.env` file is + not found. + """ + if self._last_modified: + self._last_modified = None + if self._dotenv_env: + self._reset_envvars(list(self._dotenv_env.keys())) + self._dotenv_env = {} + + def list_secrets(self) -> SecretsList: + """ + Lists the names of each environment variable from the workspace `.env` + file and the environment variables passed to the Python process. Notes: + + 1. For envvars from the Python process (not set in `.env`), only + environment variables whose names contain "KEY" or "TOKEN" or "SECRET" + are included. + + 2. Each envvar listed in `.env` is included in the returned list. + """ + dotenv_secrets_names = set() + process_secrets_names = set() + + # Add secrets from the initial environment + for name in self._initial_env.keys(): + if "KEY" in name or "TOKEN" in name or "SECRET" in name: + process_secrets_names.add(name) + + # Add secrets from .env, if any + for name in self._dotenv_env: + dotenv_secrets_names.add(name) + + # Remove `TIKTOKEN_CACHE_DIR`, which is set in the initial environment + # by some other package and is not a secret. + # This gets included otherwise since it contains 'TOKEN' in its name. 
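+        # For example (illustrative values only): an initial environment of
+        # {"OPENAI_API_KEY": ..., "PATH": ..., "TIKTOKEN_CACHE_DIR": ...}
+        # yields process secrets of {"OPENAI_API_KEY"} after the discard below,
+        # since "PATH" matches no keyword and "TIKTOKEN_CACHE_DIR" is discarded.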
+        process_secrets_names.discard("TIKTOKEN_CACHE_DIR")
+
+        return SecretsList(
+            editable_secrets=sorted(list(dotenv_secrets_names)),
+            static_secrets=sorted(list(process_secrets_names)),
+        )
+
+    async def update_secrets(
+        self,
+        updated_secrets: dict[str, str | None],
+    ) -> None:
+        """
+        Accepts a dictionary of secrets to update, adds/updates/deletes them
+        from `.env` accordingly, and applies the updated `.env` file to the
+        environment. Notes:
+
+        - A new `.env` file is created if it does not exist.
+
+        - If the value of a secret in `updated_secrets` is `None`, then the
+        secret is deleted from `.env`.
+
+        - Otherwise, the secret is added/updated in `.env`.
+
+        - A best effort is made at preserving the formatting in the `.env`
+        file. However, inline comments following an environment variable
+        definition on the same line will be deleted.
+        """
+        # Return early if passed an empty dictionary
+        if not updated_secrets:
+            return
+
+        # Hold the lock during the entire duration of the update
+        async with self._dotenv_lock:
+            # Fetch `.env` file content, storing its raw content in
+            # `dotenv_content` and its parsed value as a dict in `dotenv_env`.
+            dotenv_content: str = ""
+            try:
+                dotenv_file = await self.contents_manager.get(".env", content=True)
+                if "content" in dotenv_file:
+                    dotenv_content = dotenv_file["content"]
+                    assert isinstance(dotenv_content, str)
+            except HTTPError as e:
+                # Continue if file does not exist, otherwise re-raise
+                if e.status_code == 404:
+                    pass
+                else:
+                    raise e
+            except Exception:
+                self.log.exception(
+                    "Unknown exception raised when reading `.env` in response to an update:"
+                )
+
+            # Build the new `.env` file using these variables.
+            # See `build_updated_dotenv()` for more info on how this is done.
+            new_dotenv_content = build_updated_dotenv(dotenv_content, updated_secrets)
+
+            # Return early if no changes are needed in `.env`.
+            if new_dotenv_content is None:
+                return
+
+            # Save new content
+            try:
+                dotenv_file = await self.contents_manager.save(
+                    {
+                        "type": "file",
+                        "format": "text",
+                        "mimetype": "text/plain",
+                        "content": new_dotenv_content,
+                    },
+                    ".env",
+                )
+                last_modified = dotenv_file.get("last_modified")
+                assert isinstance(last_modified, datetime)
+            except Exception:
+                self.log.exception("Unknown exception raised when updating `.env`:")
+                return
+
+            # If this is a new file, ensure the `.env` file is listed in `.gitignore`.
+            # `self._last_modified == None` should imply the `.env` file did not exist.
+            if not self._last_modified:
+                self.event_loop.create_task(self._ensure_dotenv_gitignored())
+
+            # Update last modified timestamp and apply the new environment.
+            self._last_modified = last_modified
+            # This automatically sets `self._dotenv_env`.
+            self._apply_dotenv(new_dotenv_content)
+            self.log.info("Updated secrets in `.env`.")
+
+    def get_secret(self, secret_name: str) -> str | None:
+        """
+        Returns the value of a secret given its name. The returned secret must
+        NEVER be shared with frontend clients!
+        """
+        # TODO
+
+    def stop(self) -> None:
+        """
+        Stops this instance and any background tasks spawned by this instance.
+        This method should be called if and only if the server is shutting down.
+        """
+        self._watch_dotenv_task.cancel()
diff --git a/packages/jupyter-ai/jupyter_ai/secrets/secrets_rest_api.py b/packages/jupyter-ai/jupyter_ai/secrets/secrets_rest_api.py
new file mode 100644
index 000000000..c2492f315
--- /dev/null
+++ b/packages/jupyter-ai/jupyter_ai/secrets/secrets_rest_api.py
@@ -0,0 +1,48 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from jupyter_server.base.handlers import APIHandler as BaseAPIHandler
+from tornado.web import HTTPError, authenticated
+
+from .secrets_types import UpdateSecretsRequest
+
+if TYPE_CHECKING:
+    from .secrets_manager import EnvSecretsManager
+
+
+class SecretsRestAPI(BaseAPIHandler):
+    """
+    Defines the REST API served at the `/api/ai/secrets` endpoint.
+
+    Methods supported:
+
+    - `GET secrets/`: Returns a list of secrets.
+    - `PUT secrets/`: Add/update/delete a set of secrets.
+    """
+
+    @property
+    def secrets_manager(self) -> EnvSecretsManager:  # type:ignore[override]
+        return self.settings["jai_secrets_manager"]
+
+    @authenticated
+    def get(self):
+        response = self.secrets_manager.list_secrets()
+        self.set_status(200)
+        self.finish(response.model_dump_json())
+
+    @authenticated
+    async def put(self):
+        try:
+            # Validate the request body matches the `UpdateSecretsRequest`
+            # expected type
+            request = UpdateSecretsRequest(**self.get_json_body())
+
+            # Dispatch the request to the secrets manager
+            await self.secrets_manager.update_secrets(request.updated_secrets)
+        except Exception as e:
+            self.log.exception("Exception raised when handling PUT /api/ai/secrets/:")
+            raise HTTPError(500, str(e))
+
+        self.set_status(204)
+        self.finish()
diff --git a/packages/jupyter-ai/jupyter_ai/secrets/secrets_types.py b/packages/jupyter-ai/jupyter_ai/secrets/secrets_types.py
new file mode 100644
index 000000000..8959ee8dc
--- /dev/null
+++ b/packages/jupyter-ai/jupyter_ai/secrets/secrets_types.py
@@ -0,0 +1,35 @@
+from typing import Optional
+
+from pydantic import BaseModel
+
+
+class SecretsList(BaseModel):
+    """
+    The response type returned by `GET /api/ai/secrets`.
+
+    The response fields only include the names of each secret, and must never
+    include the value of any secret.
+    """
+
+    editable_secrets: list[str] = []
+    """
+    List of secrets set in the `.env` file. These secrets can be edited.
+    """
+
+    static_secrets: list[str] = []
+    """
+    List of secrets passed as environment variables to the Python process or
+    passed as traitlets configuration to JupyterLab. These secrets cannot be
+    edited.
+
+    Environment variables passed to the Python process are only included if
+    their name contains 'KEY' or 'TOKEN' or 'SECRET'.
+    """
+
+
+class UpdateSecretsRequest(BaseModel):
+    """
+    The request body expected by `PUT /api/ai/secrets`.
+    """
+
+    updated_secrets: dict[str, Optional[str]]
diff --git a/packages/jupyter-ai/jupyter_ai/secrets/secrets_utils.py b/packages/jupyter-ai/jupyter_ai/secrets/secrets_utils.py
new file mode 100644
index 000000000..f4887e708
--- /dev/null
+++ b/packages/jupyter-ai/jupyter_ai/secrets/secrets_utils.py
@@ -0,0 +1,176 @@
+from __future__ import annotations
+
+from io import StringIO
+from typing import TYPE_CHECKING
+
+from dotenv import dotenv_values
+from dotenv.parser import parse_stream
+
+if TYPE_CHECKING:
+    import logging
+
+ENVIRONMENT_VAR_REGEX = "^([a-zA-Z_][a-zA-Z_0-9]*)=?(['\"])?"
+"""
+Regex that matches an environment variable definition.
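+
+        Currently, this only cancels the `_watch_dotenv()` background task.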
+""" + + +def build_updated_dotenv( + dotenv_content: str, + updated_secrets: dict[str, str | None], + log: logging.Logger | None = None, +) -> str | None: + """ + Accepts the existing `.env` file as a parsed dictionary of environment + variables, along with a dictionary of secrets to update. `None` values + indicate the secret should be deleted. Otherwise, the secret will be added + to or updated in `.env`. + + This function returns the content of the updated `.env` file as a string, + and returns `None` if no updates to `.env` are required. + + NOTE: This function currently deletes inline comments on environment + variable definitions. This may be fixed in a future update. + """ + # Return early if no updates were given. + if not updated_secrets: + return None + + # Parse content of `.env` into a dictionary of environment variables + dotenv_env = dotenv_values(stream=StringIO(dotenv_content)) + + # Define `secrets_to_add`, `secrets_to_update`, and + # `secrets_to_remove`. + secrets_to_add: dict[str, str] = {} + secrets_to_update: dict[str, str] = {} + secrets_to_remove: set[str] = set() + if dotenv_env: + for name, value in updated_secrets.items(): + # Case 1: secret should be added to `.env` + if value is not None and dotenv_env.get(name, None) == None: + secrets_to_add[name] = value + continue + # Case 2: secret should be updated in `.env` + if value is not None and dotenv_env.get(name, None) != None: + secrets_to_update[name] = value + continue + # Case 3: secret should be removed from `.env` + if value is None and dotenv_env.get(name, None) != None: + secrets_to_remove.add(name) + continue + else: + # Case 4: keys can only be added when a `.env` file is not + # present. + secrets_to_add = {k: v for k, v in updated_secrets.items() if v is not None} + + # Return early if update has effect. + if not (secrets_to_add or secrets_to_update or secrets_to_remove): + return None + + # First, handle the case of adding secrets to a new `.env` file. + if not dotenv_env: + new_content = "" + max_i = len(secrets_to_add) - 1 + for i, (name, value) in enumerate(secrets_to_add.items()): + new_content += f'{name}="{value}"\n' + if i != max_i: + new_content += "\n" + + return new_content + + # Now handle the case of updating an existing `.env` file. + # To preserve formatting, multiline variables, and inline comments on + # variable defintions, we re-use the parser used by `python_dotenv`. + # It is not trivial to re-implement their parser. + # + # Algorithm overview: + # + # 1. The `parse_stream()` function returns an Iterator that yields 'Binding' + # objects that represent 'parsed chunks' of a `.env` file. Each chunk may + # contain: + # + # - An environment variable definition (`Binding.key is not None`), + # - An invalid line (`Binding.error == True`), + # - A standalone comment (if neither condition applies). + # + # 2. (Case 1) Invalid lines and environment variable bindings listed in + # `secrets_to_remove` are ignored. + # + # 3. (Case 2) Environment variable definitions listed in `secrets_to_update` + # are appended to `new_content` with the new value. + # + # 4. (Case 3) All other `Binding` objects are appended to `new_content` as-is. + # + # 5. Finally, new environment variables listed in `secrets_to_add` are + # appended at the end after the `.env` file is fully parsed. 
+ new_content = "" + for binding in parse_stream(StringIO(dotenv_content)): + # Case 1 + if binding.error or binding.key in secrets_to_remove: + continue + # Case 2 + if binding.key in secrets_to_update: + name = binding.key + # extra logic to preserve formatting as best as we can + whitespace_before, whitespace_after = get_whitespace_around( + binding.original.string + ) + value = secrets_to_update[name] + new_content += whitespace_before + new_content += f'{name}="{value}"' + new_content += whitespace_after + continue + # Case 3 + new_content += binding.original.string + + if secrets_to_add: + # Ensure new secrets get put at least 2 lines below the rest + if not new_content.endswith("\n"): + new_content += "\n\n" + elif not new_content.endswith("\n\n"): + new_content += "\n" + + max_i = len(secrets_to_add) - 1 + for i, (name, value) in enumerate(secrets_to_add.items()): + new_content += f'{name}="{value}"\n' + if i != max_i: + new_content += "\n" + + return new_content + + +def get_whitespace_around(text: str) -> tuple[str, str]: + """ + Extract whitespace prefix and suffix from a string. + + Args: + text: The input string + + Returns: + A tuple of (prefix, suffix) where prefix is the leading whitespace + and suffix is the trailing whitespace + """ + if not text: + return ("", "") + + # Find prefix (leading whitespace) + prefix_end = 0 + for i, char in enumerate(text): + if not char.isspace(): + prefix_end = i + break + else: + # String is all whitespace + return (text, "") + + # Find suffix (trailing whitespace) + suffix_start = len(text) + for i in range(len(text) - 1, -1, -1): + if not text[i].isspace(): + suffix_start = i + 1 + break + + prefix = text[:prefix_end] + suffix = text[suffix_start:] + + return (prefix, suffix) diff --git a/packages/jupyter-ai/jupyter_ai/secrets/test_secrets_utils.py b/packages/jupyter-ai/jupyter_ai/secrets/test_secrets_utils.py new file mode 100644 index 000000000..2fb4ef2e9 --- /dev/null +++ b/packages/jupyter-ai/jupyter_ai/secrets/test_secrets_utils.py @@ -0,0 +1,127 @@ +from jupyter_ai.secrets.secrets_utils import build_updated_dotenv, get_whitespace_around + + +class TestGetWhitespaceAround: + """Test cases for get_whitespace_around function.""" + + def test_empty_string(self): + """Test with empty string.""" + prefix, suffix = get_whitespace_around("") + assert prefix == "" + assert suffix == "" + + def test_no_whitespace(self): + """Test with string containing no whitespace.""" + prefix, suffix = get_whitespace_around("hello") + assert prefix == "" + assert suffix == "" + + def test_only_prefix_whitespace(self): + """Test with string containing only leading whitespace.""" + prefix, suffix = get_whitespace_around(" hello") + assert prefix == " " + assert suffix == "" + + def test_only_suffix_whitespace(self): + """Test with string containing only trailing whitespace.""" + prefix, suffix = get_whitespace_around("hello ") + assert prefix == "" + assert suffix == " " + + def test_both_prefix_and_suffix_whitespace(self): + """Test with string containing both leading and trailing whitespace.""" + prefix, suffix = get_whitespace_around(" hello ") + assert prefix == " " + assert suffix == " " + + def test_mixed_whitespace_types(self): + """Test with mixed whitespace types (spaces, tabs, newlines).""" + prefix, suffix = get_whitespace_around(" \t\nhello\n\t ") + assert prefix == " \t\n" + assert suffix == "\n\t " + + def test_all_whitespace(self): + """Test with string containing only whitespace.""" + prefix, suffix = get_whitespace_around(" ") + assert 
prefix == " " + assert suffix == "" + + def test_single_character(self): + """Test with single non-whitespace character.""" + prefix, suffix = get_whitespace_around("x") + assert prefix == "" + assert suffix == "" + + def test_single_whitespace_character(self): + """Test with single whitespace character.""" + prefix, suffix = get_whitespace_around(" ") + assert prefix == " " + assert suffix == "" + + +class TestBuildUpdatedDotenv: + """Test cases for build_updated_dotenv function.""" + + def test_empty_updates(self): + """Test with no updates to make.""" + result = build_updated_dotenv("KEY=value", {}) + assert result is None + + def test_add_to_empty_dotenv(self): + """Test adding secrets to empty dotenv content.""" + result = build_updated_dotenv("", {"NEW_KEY": "new_value"}) + assert result == 'NEW_KEY="new_value"\n' + + def test_add_multiple_to_empty_dotenv(self): + """Test adding multiple secrets to empty dotenv content.""" + result = build_updated_dotenv("", {"KEY1": "value1", "KEY2": "value2"}) + expected_lines = result.strip().split("\n") + assert len(expected_lines) == 3 # Two keys plus one empty line + assert 'KEY1="value1"' in expected_lines + assert 'KEY2="value2"' in expected_lines + + def test_update_existing_key(self): + """Test updating an existing key.""" + dotenv_content = 'EXISTING_KEY="old_value"\n' + result = build_updated_dotenv(dotenv_content, {"EXISTING_KEY": "new_value"}) + assert 'EXISTING_KEY="new_value"' in result + + def test_add_new_key_to_existing_dotenv(self): + """Test adding a new key to existing dotenv content.""" + dotenv_content = 'EXISTING_KEY="existing_value"\n' + result = build_updated_dotenv(dotenv_content, {"NEW_KEY": "new_value"}) + assert 'EXISTING_KEY="existing_value"' in result + assert 'NEW_KEY="new_value"' in result + + def test_remove_existing_key(self): + """Test removing an existing key.""" + dotenv_content = 'KEY_TO_REMOVE="value"\nKEY_TO_KEEP="value"\n' + result = build_updated_dotenv(dotenv_content, {"KEY_TO_REMOVE": None}) + assert "KEY_TO_REMOVE" not in result + assert 'KEY_TO_KEEP="value"' in result + + def test_mixed_operations(self): + """Test adding, updating, and removing keys in one operation.""" + dotenv_content = 'UPDATE_ME="old"\nREMOVE_ME="gone"\nKEEP_ME="same"\n' + updates = {"UPDATE_ME": "new", "REMOVE_ME": None, "ADD_ME": "added"} + result = build_updated_dotenv(dotenv_content, updates) + + assert 'UPDATE_ME="new"' in result + assert "REMOVE_ME" not in result + assert 'KEEP_ME="same"' in result + assert 'ADD_ME="added"' in result + + def test_preserve_comments_and_empty_lines(self): + """Test that comments and empty lines are preserved.""" + dotenv_content = '# This is a comment\nKEY="value"\n\n# Another comment\n' + result = build_updated_dotenv(dotenv_content, {"NEW_KEY": "new_value"}) + + assert "# This is a comment" in result + assert "# Another comment" in result + assert 'KEY="value"' in result + assert 'NEW_KEY="new_value"' in result + + def test_delete_last_secret(self): + dotenv_content = "KEY='value'" + result = build_updated_dotenv(dotenv_content, {"KEY": None}) + assert isinstance(result, str) and result.strip() == "" diff --git a/packages/jupyter-ai/jupyter_ai/tests/completions/test_handlers.py b/packages/jupyter-ai/jupyter_ai/tests/completions/test_handlers.py index b3d00159e..b8b33045d 100644 --- a/packages/jupyter-ai/jupyter_ai/tests/completions/test_handlers.py +++ b/packages/jupyter-ai/jupyter_ai/tests/completions/test_handlers.py @@ -1,53 +1,33 @@ import json -from types import SimpleNamespace + +# from 
types import SimpleNamespace from typing import Union import pytest -from jupyter_ai.completions.handlers.default import DefaultInlineCompletionHandler -from jupyter_ai.completions.models import ( +from jupyter_ai.completions.completion_types import ( InlineCompletionReply, InlineCompletionRequest, InlineCompletionStreamChunk, ) -from jupyter_ai_magics import BaseProvider -from langchain_community.llms import FakeListLLM +from jupyter_ai.completions.handlers.default import DefaultInlineCompletionHandler from pytest import fixture from tornado.httputil import HTTPServerRequest from tornado.web import Application -class MockProvider(BaseProvider, FakeListLLM): - id = "my_provider" - name = "My Provider" - model_id_key = "model" - models = ["model"] - raise_exc: bool = False - - def __init__(self, **kwargs): - if "responses" not in kwargs: - kwargs["responses"] = ["Test response"] - super().__init__(**kwargs) - - async def _acall(self, *args, **kwargs): - if self.raise_exc: - raise Exception("Test exception") - else: - return super()._call(*args, **kwargs) - - class MockCompletionHandler(DefaultInlineCompletionHandler): def __init__(self, lm_provider=None, lm_provider_params=None, raise_exc=False): self.request = HTTPServerRequest() self.application = Application() self.messages = [] self.tasks = [] - self.settings["jai_config_manager"] = SimpleNamespace( - completions_lm_provider=lm_provider or MockProvider, - completions_lm_provider_params=lm_provider_params or {"model_id": "model"}, - ) - self.settings["jai_event_loop"] = SimpleNamespace( - create_task=lambda x: self.tasks.append(x) - ) + # self.settings["jai_config_manager"] = SimpleNamespace( + # completions_lm_provider=lm_provider or MockProvider, + # completions_lm_provider_params=lm_provider_params or {"model_id": "model"}, + # ) + # self.settings["jai_event_loop"] = SimpleNamespace( + # create_task=lambda x: self.tasks.append(x) + # ) self.settings["model_parameters"] = {} self._llm_params = {} self._llm = None @@ -116,11 +96,11 @@ async def test_handle_request(inline_handler): ) async def test_handle_request_with_spurious_fragments(response, expected_suggestion): inline_handler = MockCompletionHandler( - lm_provider=MockProvider, - lm_provider_params={ - "model_id": "model", - "responses": [response], - }, + # lm_provider=MockProvider, + # lm_provider_params={ + # "model_id": "model", + # "responses": [response], + # }, ) dummy_request = InlineCompletionRequest( number=1, prefix="", suffix="", mime="", stream=False @@ -144,11 +124,11 @@ async def test_handle_request_with_spurious_fragments_stream( response, expected_suggestion ): inline_handler = MockCompletionHandler( - lm_provider=MockProvider, - lm_provider_params={ - "model_id": "model", - "responses": [response], - }, + # lm_provider=MockProvider, + # lm_provider_params={ + # "model_id": "model", + # "responses": [response], + # }, ) dummy_request = InlineCompletionRequest( number=1, prefix="", suffix="", mime="", stream=True @@ -164,11 +144,11 @@ async def test_handle_request_with_spurious_fragments_stream( async def test_handle_stream_request(): inline_handler = MockCompletionHandler( - lm_provider=MockProvider, - lm_provider_params={ - "model_id": "model", - "responses": ["test"], - }, + # lm_provider=MockProvider, + # lm_provider_params={ + # "model_id": "model", + # "responses": ["test"], + # }, ) dummy_request = InlineCompletionRequest( number=1, prefix="", suffix="", mime="", stream=True @@ -198,12 +178,12 @@ async def test_handle_stream_request(): async def 
test_handle_request_with_error(inline_handler): inline_handler = MockCompletionHandler( - lm_provider=MockProvider, - lm_provider_params={ - "model_id": "model", - "responses": ["test"], - "raise_exc": True, - }, + # lm_provider=MockProvider, + # lm_provider_params={ + # "model_id": "model", + # "responses": ["test"], + # "raise_exc": True, + # }, ) dummy_request = InlineCompletionRequest( number=1, prefix="", suffix="", mime="", stream=True diff --git a/packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py b/packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py index 66c48f789..c142fa500 100644 --- a/packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py +++ b/packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py @@ -11,7 +11,6 @@ KeyInUseError, WriteConflictError, ) -from jupyter_ai_magics.utils import get_em_providers, get_lm_providers from pydantic import ValidationError @@ -43,12 +42,8 @@ def config_file_with_model_fields(jp_data_dir): def common_cm_kwargs(config_path): """Kwargs that are commonly used when initializing the CM.""" log = logging.getLogger() - lm_providers = get_lm_providers() - em_providers = get_em_providers() return { "log": log, - "lm_providers": lm_providers, - "em_providers": em_providers, "config_path": config_path, "allowed_providers": None, "blocked_providers": None, @@ -467,8 +462,6 @@ def test_config_manager_does_not_write_to_defaults(config_file_with_model_fields config_path = config_file_with_model_fields log = logging.getLogger() - lm_providers = get_lm_providers() - em_providers = get_em_providers() defaults = { "model_provider_id": None, @@ -480,8 +473,6 @@ def test_config_manager_does_not_write_to_defaults(config_file_with_model_fields cm = ConfigManager( log=log, - lm_providers=lm_providers, - em_providers=em_providers, config_path=config_path, defaults=defaults, ) diff --git a/packages/jupyter-ai/jupyter_ai/tests/test_extension.py b/packages/jupyter-ai/jupyter_ai/tests/test_extension.py index 3d731fb5a..16948b663 100644 --- a/packages/jupyter-ai/jupyter_ai/tests/test_extension.py +++ b/packages/jupyter-ai/jupyter_ai/tests/test_extension.py @@ -1,11 +1,8 @@ # Copyright (c) Jupyter Development Team. # Distributed under the terms of the Modified BSD License. 
-from unittest import mock import pytest from jupyter_ai.extension import AiExtension -from jupyter_ai_magics import BaseProvider -from langchain_core.messages import BaseMessage pytest_plugins = ["pytest_jupyter.jupyter_server"] @@ -57,36 +54,4 @@ def jp_server_config(jp_server_config): @pytest.fixture def ai_extension(jp_serverapp): - ai = AiExtension() - # `BaseProvider.server_settings` can be only initialized once; however, the tests - # may run in parallel setting it with race condition; because we are not testing - # the `BaseProvider.server_settings` here, we can just mock the setter - settings_mock = mock.PropertyMock() - with mock.patch.object(BaseProvider, "server_settings", settings_mock): - yield ai - - -@pytest.mark.parametrize( - "max_history,messages_to_add,expected_size", - [ - # for max_history = 1 we expect to see up to 2 messages (1 human and 1 AI message) - (1, 4, 2), - # if there is less than `max_history` messages, all should be returned - (1, 1, 1), - # if no limit is set, all messages should be returned - (None, 9, 9), - ], -) -@pytest.mark.skip("TODO v3: replace this with a unit test for YChatHistory") -def test_max_chat_history(ai_extension, max_history, messages_to_add, expected_size): - ai = ai_extension - ai.default_max_chat_history = max_history - ai.initialize_settings() - for i in range(messages_to_add): - message = BaseMessage( - content=f"Test message {i}", - type="test", - ) - ai.settings["llm_chat_memory"].add_message(message) - - assert len(ai.settings["llm_chat_memory"].messages) == expected_size + AiExtension() diff --git a/packages/jupyter-ai/pyproject.toml b/packages/jupyter-ai/pyproject.toml index 075b68003..57c037b71 100644 --- a/packages/jupyter-ai/pyproject.toml +++ b/packages/jupyter-ai/pyproject.toml @@ -36,6 +36,9 @@ dependencies = [ # NOTE: Make sure to update the corresponding dependency in # `packages/jupyter-ai/package.json` to match the version range below "jupyterlab-chat>=0.16.0,<0.17.0", + "litellm>=1.73,<2", + "jinja2>=3.0,<4", + "python_dotenv>=1,<2", ] dynamic = ["version", "description", "authors", "urls", "keywords"] @@ -58,7 +61,7 @@ test = [ dev = ["jupyter_ai_magics[dev]"] -all = ["jupyter_ai_magics[all]", "pypdf", "arxiv"] +all = ["jupyter_ai_magics[all]"] [tool.hatch.version] source = "nodejs" diff --git a/packages/jupyter-ai/src/components/chat-settings.tsx b/packages/jupyter-ai/src/components/chat-settings.tsx index 1f3337f85..ea4018d3e 100644 --- a/packages/jupyter-ai/src/components/chat-settings.tsx +++ b/packages/jupyter-ai/src/components/chat-settings.tsx @@ -1,109 +1,35 @@ -import React, { useEffect, useState, useMemo } from 'react'; +import React, { useEffect, useState } from 'react'; import { Box } from '@mui/system'; -import { - Alert, - Button, - IconButton, - FormControl, - FormControlLabel, - FormLabel, - MenuItem, - Radio, - RadioGroup, - TextField, - Tooltip, - CircularProgress -} from '@mui/material'; +import { IconButton, Tooltip } from '@mui/material'; import SettingsIcon from '@mui/icons-material/Settings'; import WarningAmberIcon from '@mui/icons-material/WarningAmber'; -import { Select } from './select'; -import { AiService } from '../handler'; -import { ModelFields } from './settings/model-fields'; -import { ServerInfoState, useServerInfo } from './settings/use-server-info'; -import { ExistingApiKeys } from './settings/existing-api-keys'; import { IRenderMimeRegistry } from '@jupyterlab/rendermime'; -import { minifyUpdate } from './settings/minify'; -import { useStackingAlert } from 
'./mui-extras/stacking-alert'; -import { RendermimeMarkdown } from './settings/rendermime-markdown'; import { IJaiCompletionProvider } from '../tokens'; -import { getProviderId, getModelLocalId } from '../utils'; +import { ModelIdInput } from './settings/model-id-input'; +// import { ModelParametersInput } from './settings/model-parameters-input'; +import { SecretsSection } from './settings/secrets-section'; type ChatSettingsProps = { rmRegistry: IRenderMimeRegistry; completionProvider: IJaiCompletionProvider | null; openInlineCompleterSettings: () => void; - // The temporary input options, should be removed when jupyterlab chat is - // the only chat. - inputOptions?: boolean; }; /** * Component that returns the settings view in the chat panel. */ export function ChatSettings(props: ChatSettingsProps): JSX.Element { - // state fetched on initial render - const server = useServerInfo(); - - // initialize alert helper - const alert = useStackingAlert(); - const apiKeysAlert = useStackingAlert(); - - // user inputs - const [lmProvider, setLmProvider] = - useState(null); - const [emProvider, setEmProvider] = - useState(null); - const [clmProvider, setClmProvider] = - useState(null); - const [showLmLocalId, setShowLmLocalId] = useState(false); - const [showEmLocalId, setShowEmLocalId] = useState(false); - const [showClmLocalId, setShowClmLocalId] = useState(false); - const [chatHelpMarkdown, setChatHelpMarkdown] = useState(null); - const [embeddingHelpMarkdown, setEmbeddingHelpMarkdown] = useState< - string | null - >(null); - const [completionHelpMarkdown, setCompletionHelpMarkdown] = useState< - string | null - >(null); - const [lmLocalId, setLmLocalId] = useState(''); - const [emLocalId, setEmLocalId] = useState(''); - const [clmLocalId, setClmLocalId] = useState(''); - - const lmGlobalId = useMemo(() => { - if (!lmProvider) { - return null; - } - - return lmProvider.id + ':' + lmLocalId; - }, [lmProvider, lmLocalId]); - - const emGlobalId = useMemo(() => { - if (!emProvider) { - return null; - } - - return emProvider.id + ':' + emLocalId; - }, [emProvider, emLocalId]); - - const clmGlobalId = useMemo(() => { - if (!clmProvider) { - return null; - } - - return clmProvider.id + ':' + clmLocalId; - }, [clmProvider, clmLocalId]); - - const [apiKeys, setApiKeys] = useState>({}); - const [sendWse, setSendWse] = useState(false); - const [lmFields, setLmFields] = useState>({}); - const [emFields, setEmFields] = useState>({}); - const [clmFields, setClmFields] = useState>({}); + const [completionModel, setCompletionModel] = useState(null); const [isCompleterEnabled, setIsCompleterEnabled] = useState( props.completionProvider && props.completionProvider.isEnabled() ); + /** + * Effect: Listen to JupyterLab completer settings updates on initial render + * and update the `isCompleterEnabled` state variable accordingly. + */ useEffect(() => { const refreshCompleterState = () => { setIsCompleterEnabled( @@ -118,551 +44,76 @@ export function ChatSettings(props: ChatSettingsProps): JSX.Element { }; }, [props.completionProvider]); - // whether the form is currently saving - const [saving, setSaving] = useState(false); - - /** - * Effect: initialize inputs after fetching server info. - */ - useEffect(() => { - if (server.state !== ServerInfoState.Ready) { - return; - } - - setLmLocalId(server.chat.lmLocalId); - setEmLocalId(server.chat.emLocalId); - setClmLocalId(server.completions.lmLocalId); - setSendWse(server.config.send_with_shift_enter); - setChatHelpMarkdown(server.chat.lmProvider?.help ?? 
null); - setEmbeddingHelpMarkdown(server.chat.emProvider?.help ?? null); - setCompletionHelpMarkdown(server.completions.lmProvider?.help ?? null); - if (server.chat.lmProvider?.registry) { - setShowLmLocalId(true); - } - if (server.chat.emProvider?.registry) { - setShowEmLocalId(true); - } - if (server.completions.lmProvider?.registry) { - setShowClmLocalId(true); - } - setLmProvider(server.chat.lmProvider); - setClmProvider(server.completions.lmProvider); - setEmProvider(server.chat.emProvider); - }, [server]); - - /** - * Effect: re-initialize apiKeys object whenever the selected LM/EM changes. - * Properties with a value of '' indicate necessary user input. - */ - useEffect(() => { - if (server.state !== ServerInfoState.Ready) { - return; - } - - const newApiKeys: Record = {}; - const lmAuth = lmProvider?.auth_strategy; - const emAuth = emProvider?.auth_strategy; - if ( - lmAuth?.type === 'env' && - !server.config.api_keys.includes(lmAuth.name) - ) { - newApiKeys[lmAuth.name] = ''; - } - if (lmAuth?.type === 'multienv') { - lmAuth.names.forEach(apiKey => { - if (!server.config.api_keys.includes(apiKey)) { - newApiKeys[apiKey] = ''; - } - }); - } - - if ( - emAuth?.type === 'env' && - !server.config.api_keys.includes(emAuth.name) - ) { - newApiKeys[emAuth.name] = ''; - } - if (emAuth?.type === 'multienv') { - emAuth.names.forEach(apiKey => { - if (!server.config.api_keys.includes(apiKey)) { - newApiKeys[apiKey] = ''; - } - }); - } - - setApiKeys(newApiKeys); - }, [lmProvider, emProvider, server]); - - /** - * Effect: re-initialize fields object whenever the selected LM changes. - */ - useEffect(() => { - if (server.state !== ServerInfoState.Ready || !lmGlobalId) { - return; - } - - const currFields: Record = - server.config.fields?.[lmGlobalId] ?? {}; - setLmFields(currFields); - - if (!emGlobalId) { - return; - } - - const initEmbeddingModelFields: Record = - server.config.embeddings_fields?.[emGlobalId] ?? {}; - setEmFields(initEmbeddingModelFields); - - if (!clmGlobalId) { - return; - } - - const initCompleterModelFields: Record = - server.config.completions_fields?.[clmGlobalId] ?? 
{}; - setClmFields(initCompleterModelFields); - }, [server, lmGlobalId, emGlobalId, clmGlobalId]); - - const handleSave = async () => { - // compress fields with JSON values - if (server.state !== ServerInfoState.Ready) { - return; - } - - for (const fieldKey in lmFields) { - const fieldVal = lmFields[fieldKey]; - if (typeof fieldVal !== 'string' || !fieldVal.trim().startsWith('{')) { - continue; - } - - try { - const parsedFieldVal = JSON.parse(fieldVal); - const compressedFieldVal = JSON.stringify(parsedFieldVal); - lmFields[fieldKey] = compressedFieldVal; - } catch (e) { - continue; - } - } - - for (const fieldKey in emFields) { - const fieldVal = emFields[fieldKey]; - if (typeof fieldVal !== 'string' || !fieldVal.trim().startsWith('{')) { - continue; - } - - try { - const parsedFieldVal = JSON.parse(fieldVal); - const compressedFieldVal = JSON.stringify(parsedFieldVal); - emFields[fieldKey] = compressedFieldVal; - } catch (e) { - continue; - } - } - - for (const fieldKey in clmFields) { - const fieldVal = clmFields[fieldKey]; - if (typeof fieldVal !== 'string' || !fieldVal.trim().startsWith('{')) { - continue; - } - - try { - const parsedFieldVal = JSON.parse(fieldVal); - const compressedFieldVal = JSON.stringify(parsedFieldVal); - clmFields[fieldKey] = compressedFieldVal; - } catch (e) { - continue; - } - } - - let updateRequest: AiService.UpdateConfigRequest = { - model_provider_id: lmGlobalId, - embeddings_provider_id: emGlobalId, - api_keys: apiKeys, - fields: lmGlobalId ? { [lmGlobalId]: lmFields } : {}, - completions_fields: clmGlobalId ? { [clmGlobalId]: clmFields } : {}, - embeddings_fields: emGlobalId ? { [emGlobalId]: emFields } : {}, - completions_model_provider_id: clmGlobalId, - send_with_shift_enter: sendWse - }; - updateRequest = minifyUpdate(server.config, updateRequest); - updateRequest.last_read = server.config.last_read; - - setSaving(true); - try { - await apiKeysAlert.clear(); - await AiService.updateConfig(updateRequest); - } catch (e) { - console.error(e); - const msg = - e instanceof Error || typeof e === 'string' - ? e.toString() - : 'An unknown error occurred. Check the console for more details.'; - alert.show('error', msg); - return; - } finally { - setSaving(false); - } - await server.refetchAll(); - alert.show('success', 'Settings saved successfully.'); - }; - - if (server.state === ServerInfoState.Loading) { - return ( - - - - ); - } - - if (server.state === ServerInfoState.Error) { - return ( - - - {server.error || - 'An unknown error occurred. Check the console for more details.'} - - - ); - } - return ( - {/* Chat language model section */} -

- Language model -

- - {server.lmProviders.providers - .map(lmp => lmp.chat_models.length) - .reduce((partialSum, num) => partialSum + num, 0) > 0 ? ( - - - {showLmLocalId && ( - setLmLocalId(e.target.value)} - fullWidth - /> - )} - {chatHelpMarkdown && ( - - )} - {lmGlobalId && ( - - )} - - ) : ( -

No language models available.

- )} - - {/* Embedding model section */} -

Embedding model

- {server.emProviders.providers.length > 0 ? ( - - - {showEmLocalId && ( - setEmLocalId(e.target.value)} - fullWidth - /> - )} - {embeddingHelpMarkdown && ( - - )} - {emGlobalId && ( - - )} - - ) : ( -

No embedding models available.

- )} + {/* SECTION: Embedding model */} + {/* TODO */} - {/* Completer language model section */} + {/* SECTION: Completion model */}

- Inline completions model + Completion model

- {server.lmProviders.providers - .map(lmp => lmp.completion_models.length) - .reduce((partialSum, num) => partialSum + num, 0) > 0 ? ( - - - {showClmLocalId && ( - setClmLocalId(e.target.value)} - fullWidth - /> - )} - {completionHelpMarkdown && ( - - )} - {clmGlobalId && ( - - )} - - ) : ( -

No Inline Completion models.

- )} - - {/* API Keys section */} -

API Keys

- - {Object.entries(apiKeys).length === 0 && - server.config.api_keys.length === 0 ? ( -

No API keys are required by the selected models.

- ) : null} - - {/* API key inputs for newly-used providers */} - {Object.entries(apiKeys).map(([apiKeyName, apiKeyValue], idx) => ( - - setApiKeys(apiKeys => ({ - ...apiKeys, - [apiKeyName]: e.target.value - })) - } - /> - ))} - {/* Pre-existing API keys */} - + Configure the language model used to generate inline completions when + editing documents in JupyterLab. +

+ { + setCompletionModel(latestChatModelId); + }} /> - {/* Input - to remove when jupyterlab chat is the only chat */} - {(props.inputOptions ?? true) && ( - <> -

Input

- - - When writing a message, press Enter to: - - { - setSendWse(e.target.value === 'newline'); - }} - > - } - label="Send the message" - /> - } - label={ - <> - Start a new line (use Shift+Enter to - send) - - } - /> - - - - )} + {/* Model parameters section */} + {/*

Model parameters

+

Configure additional parameters for the language model.

+ */} - - - - {alert.jsx} + {/* SECTION: Secrets (and API keys) */} +

Secrets and API keys

+
); } function CompleterSettingsButton(props: { - selection: AiService.ListProvidersEntry | null; + hasCompletionModel: boolean; provider: IJaiCompletionProvider | null; isCompleterEnabled: boolean | null; openSettings: () => void; }): JSX.Element { - if (props.selection && !props.isCompleterEnabled) { + if (props.hasCompletionModel && !props.isCompleterEnabled) { return ( @@ -672,19 +123,10 @@ function CompleterSettingsButton(props: { ); } return ( - + ); } - -function getProvider( - globalModelId: string, - providers: AiService.ListProvidersResponse -): AiService.ListProvidersEntry | null { - const providerId = getProviderId(globalModelId); - const provider = providers.providers.find(p => p.id === providerId); - return provider ?? null; -} diff --git a/packages/jupyter-ai/src/components/mui-extras/async-icon-button.tsx b/packages/jupyter-ai/src/components/mui-extras/async-icon-button.tsx index 5de6b3913..14207f681 100644 --- a/packages/jupyter-ai/src/components/mui-extras/async-icon-button.tsx +++ b/packages/jupyter-ai/src/components/mui-extras/async-icon-button.tsx @@ -1,9 +1,9 @@ -import React, { useMemo, useState } from 'react'; +import React, { useState } from 'react'; import { Box, CircularProgress, IconButton } from '@mui/material'; import { ContrastingTooltip } from './contrasting-tooltip'; type AsyncIconButtonProps = { - onClick: () => Promise; + onClick: () => 'canceled' | Promise; onError: (emsg: string) => unknown; onSuccess: () => unknown; children: JSX.Element; @@ -23,18 +23,21 @@ type AsyncIconButtonProps = { export function AsyncIconButton(props: AsyncIconButtonProps): JSX.Element { const [loading, setLoading] = useState(false); const [showConfirm, setShowConfirm] = useState(false); - const shouldConfirm = useMemo(() => !!props.confirm, []); async function handleClick() { - if (shouldConfirm && !showConfirm) { + if (props.confirm && !showConfirm) { setShowConfirm(true); return; } - setLoading(true); let thrown = false; try { - await props.onClick(); + const promise = props.onClick(); + if (promise === 'canceled') { + return; + } + setLoading(true); + await promise; } catch (e: unknown) { thrown = true; if (e instanceof Error) { @@ -42,10 +45,14 @@ export function AsyncIconButton(props: AsyncIconButtonProps): JSX.Element { } else { // this should never happen. // if this happens, it means the thrown value was not of type `Error`. - props.onError('Unknown error occurred.'); + console.error(e); + props.onError( + 'Unknown error occurred. Check the browser console logs.' 
+ ); } + } finally { + setLoading(false); } - setLoading(false); if (!thrown) { props.onSuccess(); } diff --git a/packages/jupyter-ai/src/components/settings/existing-api-keys.tsx b/packages/jupyter-ai/src/components/settings/existing-api-keys.tsx deleted file mode 100644 index ea0ee8430..000000000 --- a/packages/jupyter-ai/src/components/settings/existing-api-keys.tsx +++ /dev/null @@ -1,232 +0,0 @@ -import React, { useEffect, useCallback, useRef, useState } from 'react'; -import { - Box, - IconButton, - Typography, - TextField, - InputAdornment -} from '@mui/material'; -import Edit from '@mui/icons-material/Edit'; -import DeleteOutline from '@mui/icons-material/DeleteOutline'; -import Cancel from '@mui/icons-material/Cancel'; -import Check from '@mui/icons-material/Check'; -import Visibility from '@mui/icons-material/Visibility'; -import VisibilityOff from '@mui/icons-material/VisibilityOff'; -import { AsyncIconButton } from '../mui-extras/async-icon-button'; - -import { AiService } from '../../handler'; -import { StackingAlert } from '../mui-extras/stacking-alert'; - -export type ExistingApiKeysProps = { - alert: StackingAlert; - apiKeys: string[]; - onSuccess: () => unknown; -}; - -/** - * Component that renders a list of existing API keys. Each API key is rendered - * by a unique `ExistingApiKey` component. - */ -export function ExistingApiKeys(props: ExistingApiKeysProps): JSX.Element { - // current editable API key name, if any. - const [editableApiKey, setEditableApiKey] = useState(null); - - return ( - .MuiBox-root:not(:first-child)': { - marginTop: -2 - } - }} - > - {props.apiKeys.map(apiKey => ( - - ))} - {props.alert.jsx} - - ); -} - -type ExistingApiKeyProps = { - alert: StackingAlert; - apiKey: string; - editable: boolean; - setEditable: React.Dispatch>; - onSuccess: () => unknown; -}; - -/** - * Component that renders a single existing API key specified by `props.apiKey`. - * Includes actions for editing and deleting the API key. - */ -function ExistingApiKey(props: ExistingApiKeyProps) { - const [input, setInput] = useState(''); - const [inputVisible, setInputVisible] = useState(false); - const [error, setError] = useState(false); - const inputRef = useRef(); - - /** - * Effect: Select the input after `editable` is set to `true`. This needs to - * be done in an effect because the TextField needs to be rendered with - * `disabled=false` first. When `editable` is set to `false`, reset any - * input-related state. 
- */ - useEffect(() => { - if (props.editable) { - inputRef.current?.focus(); - } else { - setInput(''); - setInputVisible(false); - setError(false); - } - }, [props.editable]); - - const onEditIntent = useCallback(() => { - props.setEditable(props.apiKey); - }, []); - - const onDelete = useCallback(() => { - return AiService.deleteApiKey(props.apiKey); - }, []); - - const toggleInputVisibility = useCallback(() => { - setInputVisible(visible => !visible); - }, []); - - const onEditCancel = useCallback(() => { - props.setEditable(null); - }, []); - - const onEditSubmit = useCallback(() => { - return AiService.updateConfig({ - api_keys: { [props.apiKey]: input } - }); - }, [input]); - - const onError = useCallback( - (emsg: string) => { - props.alert.show('error', emsg); - }, - [props.alert] - ); - - const validateInput = useCallback(() => { - if (!props.editable) { - return; - } - - setError(!input); - }, [props.editable, input]); - - const onEditSuccess = useCallback(() => { - props.setEditable(null); - props.alert.show('success', 'API key updated successfully.'); - props.onSuccess(); - }, [props.alert, props.onSuccess]); - - const onDeleteSuccess = useCallback(() => { - props.alert.show('success', 'API key deleted successfully.'); - props.onSuccess(); - }, [props.alert, props.onSuccess]); - - return ( - - setInput(e.target.value)} - disabled={!props.editable} - inputRef={inputRef} - // validation props - onBlur={validateInput} - error={error} - helperText={'API key value must not be empty'} - FormHelperTextProps={{ - sx: { - visibility: error ? 'unset' : 'hidden', - margin: 0, - whiteSpace: 'nowrap' - } - }} - // style props - size="small" - variant="standard" - type={inputVisible ? 'text' : 'password'} - label={ - -
{props.apiKey}
-
- } - InputProps={{ - endAdornment: props.editable && ( - - e.preventDefault()} - > - {inputVisible ? : } - - - ) - }} - sx={{ - flexGrow: 1, - margin: 0, - '& .MuiInputBase-input': { - padding: 0, - paddingBottom: 1 - } - }} - /> - - {props.editable ? ( - // 16px margin top - 5px padding - <> - e.preventDefault()} - > - - - e.preventDefault()} - confirm={true} - > - - - - ) : ( - <> - - - - - - - - )} - -
- ); -} diff --git a/packages/jupyter-ai/src/components/settings/model-id-input.tsx b/packages/jupyter-ai/src/components/settings/model-id-input.tsx new file mode 100644 index 000000000..888788d0b --- /dev/null +++ b/packages/jupyter-ai/src/components/settings/model-id-input.tsx @@ -0,0 +1,168 @@ +import React, { useState, useEffect } from 'react'; +import { Autocomplete, TextField, Button, Box } from '@mui/material'; +import { AiService } from '../../handler'; +import { useStackingAlert } from '../mui-extras/stacking-alert'; +import Save from '@mui/icons-material/Save'; + +export type ModelIdInputProps = { + /** + * The label of the model ID input field. + */ + label: string; + + /** + * The "type" of the model being configured. This prop should control the API + * endpoints used to get the current model, set the current model, and + * retrieve model ID suggestions. + */ + modality: 'chat' | 'completion'; + + /** + * (optional) The placeholder text shown within the model ID input field. + */ + placeholder?: string; + + /** + * (optional) Whether to render in full width. Defaults to `true`. + */ + fullWidth?: boolean; + + /** + * (optional) Callback that is run when the component retrieves the current + * model ID _or_ successfully updates the model ID. Details: + * + * - This callback is run once when the current model ID is retrieved from the + * backend, with `initial=true`. Any model ID updates made through this + * component run this callback with `initial=false`. + * + * - This callback will not run if an exception was raised while updating the + * model ID. + */ + onModelIdFetch?: (modelId: string | null, initial: boolean) => unknown; +}; + +/** + * A model ID input. + */ +export function ModelIdInput(props: ModelIdInputProps): JSX.Element { + const [models, setModels] = useState([]); + const [prevModel, setPrevModel] = useState(null); + const [loading, setLoading] = useState(true); + const [updating, setUpdating] = useState(false); + + const [input, setInput] = useState(''); + const alert = useStackingAlert(); + + /** + * Effect: Fetch list of models and current model on initial render, based on + * the modality. + */ + useEffect(() => { + async function loadData() { + try { + let modelsResponse: string[]; + let currModelResponse: string | null; + + if (props.modality === 'chat') { + [modelsResponse, currModelResponse] = await Promise.all([ + AiService.listChatModels(), + AiService.getChatModel() + ]); + } else if (props.modality === 'completion') { + [modelsResponse, currModelResponse] = await Promise.all([ + AiService.listChatModels(), + AiService.getCompletionModel() + ]); + } else { + throw new Error(`Unrecognized model modality '${props.modality}'.`); + } + + setModels(modelsResponse); + setPrevModel(currModelResponse); + setInput(currModelResponse ?? 
''); + } catch (error) { + console.error('Failed to load chat models:', error); + setModels([]); + } finally { + setLoading(false); + } + } + + loadData(); + }, []); + + const handleUpdateChatModel = async () => { + setUpdating(true); + try { + // perform correct REST API call based on model modality + const newModelId = input.trim() || null; + if (props.modality === 'chat') { + await AiService.updateChatModel(newModelId); + } else if (props.modality === 'completion') { + await AiService.updateCompletionModel(newModelId); + } else { + throw new Error(`Unrecognized model modality '${props.modality}'.`); + } + + // update local state and run parent callback + setPrevModel(newModelId); + props.onModelIdFetch?.(newModelId, true); + + // show success alert + // TODO: maybe just use the JL Notifications API + alert.show( + 'success', + newModelId + ? `Successfully updated ${props.modality} model to '${input.trim()}'.` + : `Successfully cleared ${props.modality} model.` + ); + } catch (error) { + console.error(`Failed to update ${props.modality} model:`, error); + const msg = + error instanceof Error ? error.message : 'An unknown error occurred'; + alert.show('error', `Failed to update ${props.modality} model: ${msg}`); + } finally { + setUpdating(false); + } + }; + + return ( + + { + // This condition prevents whitespace from being inserted in the model + // ID by accident. + if (newValue !== null && !newValue.includes(' ')) { + setInput(newValue); + } + }} + renderInput={params => ( + + )} + /> + + {alert.jsx} + + ); +} diff --git a/packages/jupyter-ai/src/components/settings/model-parameters-input.tsx b/packages/jupyter-ai/src/components/settings/model-parameters-input.tsx new file mode 100644 index 000000000..c6cefbd11 --- /dev/null +++ b/packages/jupyter-ai/src/components/settings/model-parameters-input.tsx @@ -0,0 +1,220 @@ +import React, { useState } from 'react'; +import { Button, TextField, Box, Alert, IconButton } from '@mui/material'; +import DeleteIcon from '@mui/icons-material/Delete'; + +interface ModelParameter { + id: string; + name: string; + type: string; + value: string; + isStatic?: boolean; +} + +interface StaticParameterDef { + name: string; + type: string; + label: string; +} + +// Add some common fields as static parameters here +const STATIC_PARAMETERS: StaticParameterDef[] = [ + { name: 'temperature', type: 'float', label: 'Temperature' }, + { name: 'api_url', type: 'string', label: 'API URL' }, + { name: 'max_tokens', type: 'integer', label: 'Max Tokens' } +]; + +export function ModelParametersInput(): JSX.Element { + const [parameters, setParameters] = useState([]); + const [validationError, setValidationError] = useState(''); + + const handleAddParameter = () => { + const newParameter: ModelParameter = { + id: Date.now().toString(), + name: '', + type: '', + value: '', + isStatic: false + }; + setParameters([...parameters, newParameter]); + setValidationError(''); + }; + + const handleAddStaticParameter = (staticParam: StaticParameterDef) => { + // Check if static parameter already exists + const exists = parameters.some( + param => param.name === staticParam.name && param.isStatic + ); + if (exists) { + setValidationError(`Parameter "${staticParam.label}" is already added`); + return; + } + const newParameter: ModelParameter = { + id: Date.now().toString(), + name: staticParam.name, + type: staticParam.type, + value: '', + isStatic: true + }; + setParameters([...parameters, newParameter]); + setValidationError(''); + }; + // For when user changes their parameter + 
const handleParameterChange = ( + id: string, + field: keyof ModelParameter, + value: string + ) => { + setParameters(prev => + prev.map(param => + param.id === id ? { ...param, [field]: value } : param + ) + ); + setValidationError(''); + }; + + // For when user deletes parameter + const handleDeleteParameter = (id: string) => { + setParameters(prev => prev.filter(param => param.id !== id)); + setValidationError(''); + }; + + const handleSaveParameters = () => { + // Validation: Check if any parameter has a value but missing name or type (only for custom parameters) + const invalidParams = parameters.filter( + param => + param.value.trim() !== '' && + !param.isStatic && + (param.name.trim() === '' || param.type.trim() === '') + ); + + if (invalidParams.length > 0) { + setValidationError( + 'Parameter value specified but name or type is missing' + ); + return; + } + + // Filter out parameters with empty values + const validParams = parameters.filter(param => param.value.trim() !== ''); + + // Creates JSON object of valid parameters ONLY if all 3 fields are given valid inputs + const paramsObject = validParams.reduce((acc, param) => { + acc[param.name] = param.value; + return acc; + }, {} as Record); + + // Logs the JSON object of its input state to the browser console + console.log('Model Parameters:', paramsObject); + }; + + const showSaveButton = parameters.length > 0; + const availableStaticParams = STATIC_PARAMETERS.filter( + staticParam => + !parameters.some( + param => param.name === staticParam.name && param.isStatic + ) + ); + + return ( + + + {/* Static parameter buttons */} + {availableStaticParams.length > 0 && ( + + + Common parameters: + + + {availableStaticParams.map(staticParam => ( + + ))} + + + )} + + {parameters.map(param => ( + + + handleParameterChange(param.id, 'name', e.target.value) + } + size="small" + sx={{ flex: 1 }} + disabled={param.isStatic} + InputProps={{ + readOnly: param.isStatic + }} + /> + + handleParameterChange(param.id, 'type', e.target.value) + } + size="small" + sx={{ flex: 1 }} + disabled={param.isStatic} + InputProps={{ + readOnly: param.isStatic + }} + /> + + handleParameterChange(param.id, 'value', e.target.value) + } + size="small" + sx={{ flex: 1 }} + /> + handleDeleteParameter(param.id)} + color="error" + size="small" + sx={{ ml: 1 }} + > + + + + ))} + + {validationError && ( + + {validationError} + + )} + + {showSaveButton && ( + + )} + + ); +} diff --git a/packages/jupyter-ai/src/components/settings/secrets-input.tsx b/packages/jupyter-ai/src/components/settings/secrets-input.tsx new file mode 100644 index 000000000..65decc910 --- /dev/null +++ b/packages/jupyter-ai/src/components/settings/secrets-input.tsx @@ -0,0 +1,463 @@ +import React, { useEffect, useCallback, useRef, useState } from 'react'; +import { + Box, + IconButton, + Typography, + TextField, + InputAdornment, + Button +} from '@mui/material'; +import Edit from '@mui/icons-material/Edit'; +import DeleteOutline from '@mui/icons-material/DeleteOutline'; +import Cancel from '@mui/icons-material/Cancel'; +import Check from '@mui/icons-material/Check'; +import Visibility from '@mui/icons-material/Visibility'; +import VisibilityOff from '@mui/icons-material/VisibilityOff'; +import Add from '@mui/icons-material/Add'; +import { AsyncIconButton } from '../mui-extras/async-icon-button'; + +import { AiService } from '../../handler'; +import { StackingAlert, useStackingAlert } from '../mui-extras/stacking-alert'; + +export type SecretsInputProps = { + editableSecrets: string[]; + 
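+
+  /**
+   * (Reviewer note.) Callback invoked after a secret is added, updated, or
+   * deleted. In `SecretsSection` this is bound to `reloadSecrets()`, which
+   * re-fetches the secrets list without toggling the parent `loading` state.
+   */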
reloadSecrets: () => unknown;
+};
+
+/**
+ * Component that renders a list of editable secrets. Each secret is
+ * rendered by a unique `EditableSecret` component.
+ */
+export function SecretsInput(props: SecretsInputProps): JSX.Element | null {
+  const alert = useStackingAlert();
+  const [isAddingSecret, setIsAddingSecret] = useState(false);
+
+  const onAddSecretClick = useCallback(() => {
+    setIsAddingSecret(true);
+  }, []);
+
+  const onAddSecretCancel = useCallback(() => {
+    setIsAddingSecret(false);
+  }, []);
+
+  const onAddSecretSuccess = useCallback(() => {
+    setIsAddingSecret(false);
+    alert.show('success', 'Secret added successfully.');
+    props.reloadSecrets();
+  }, [alert, props.reloadSecrets]);
+
+  const onAddSecretError = useCallback(
+    (emsg: string) => {
+      alert.show('error', emsg);
+    },
+    [alert]
+  );
+
+  // NOTE: this guard must come after every hook call above; returning early
+  // before the `useCallback()` calls would violate the Rules of Hooks.
+  if (!props.editableSecrets) {
+    return null;
+  }
+
+  return (
+
+      {/* SUBSECTION: Editable secrets */}
+      {props.editableSecrets.length > 0 ? (
+          .MuiBox-root:not(:first-child)': {
+            marginTop: -2
+          }
+        }}
+        >
+          {props.editableSecrets.map(secret => (
+
+          ))}
+
+      ) : (
+
+
+            No secrets configured
+
+
+            Click "Add secret" to add an API key and start using Jupyternaut
+            with your preferred model provider.
+
+
+      )}
+
+      {/* Add secret button */}
+      {isAddingSecret ? (
+
+      ) : (
+
+      )}
+
+      {/* Info shown to the user after adding/updating a secret */}
+      {alert.jsx}
+
+  );
+}
+
+export type EditableSecretProps = {
+  alert: StackingAlert;
+  secret: string;
+  reloadSecrets: () => unknown;
+};
+
+/**
+ * Component that renders a single editable secret specified by `props.secret`.
+ * Includes actions for editing and deleting the secret.
+ */
+export function EditableSecret(props: EditableSecretProps) {
+  const [input, setInput] = useState('');
+  const [inputVisible, setInputVisible] = useState(false);
+  const [error, setError] = useState(false);
+  const [editable, setEditable] = useState(false);
+  const inputRef = useRef<HTMLInputElement>();
+
+  /**
+   * Effect: Focus the input after `editable` is set to `true` and clear the
+   * input after `editable` is set to `false`.
+   */
+  useEffect(() => {
+    if (editable) {
+      inputRef.current?.focus();
+    } else {
+      setInput('');
+      setInputVisible(false);
+      setError(false);
+    }
+  }, [editable]);
+
+  const onEditIntent = useCallback(() => {
+    setEditable(true);
+  }, []);
+
+  const onDelete = useCallback(() => {
+    return AiService.deleteSecret(props.secret);
+  }, [props.secret]);
+
+  const toggleInputVisibility = useCallback(() => {
+    setInputVisible(visible => !visible);
+  }, []);
+
+  const onEditCancel = useCallback(() => {
+    setEditable(false);
+  }, []);
+
+  const onEditSubmit = useCallback(() => {
+    // If input is empty, defocus the input to show a validation error and
+    // return early.
+    if (input.length === 0) {
+      inputRef.current?.blur();
+      return 'canceled';
+    }
+
+    // Otherwise dispatch the request to the backend.
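+    // (Reviewer note.) Per `AiService.updateSecrets` in handler.ts later in
+    // this diff, the call below issues `PUT secrets/` with the body
+    // `{"updated_secrets": {"<SECRET_NAME>": "<value>"}}`; passing `null` as
+    // a value deletes the secret instead (see `AiService.deleteSecret`).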
+ return AiService.updateSecrets({ + [props.secret]: input + }); + }, [input]); + + const onEditError = useCallback( + (emsg: string) => { + props.alert.show('error', emsg); + }, + [props.alert] + ); + + const validateInput = useCallback(() => { + if (!editable) { + return; + } + + setError(!input); + }, [editable, input]); + + const onEditSuccess = useCallback(() => { + setEditable(false); + props.alert.show('success', 'API key updated successfully.'); + props.reloadSecrets(); + }, [props.alert, props.reloadSecrets]); + + const onDeleteSuccess = useCallback(() => { + props.alert.show('success', 'API key deleted successfully.'); + props.reloadSecrets(); + }, [props.alert, props.reloadSecrets]); + + return ( + + setInput(e.target.value)} + disabled={!editable} + inputRef={inputRef} + // validation props + onBlur={validateInput} + error={error} + helperText={'Secret value must not be empty'} + placeholder="Secret value" + FormHelperTextProps={{ + sx: { + visibility: error ? 'unset' : 'hidden', + margin: 0, + whiteSpace: 'nowrap' + } + }} + // style props + size="small" + variant="standard" + type={inputVisible ? 'text' : 'password'} + label={ + +
{props.secret}
+
+ } + InputProps={{ + endAdornment: editable && ( + + e.preventDefault()} + > + {inputVisible ? : } + + + ) + }} + sx={{ + flexGrow: 1, + margin: 0, + '& .MuiInputBase-input': { + padding: 0, + paddingBottom: 1 + } + }} + /> + + {editable ? ( + // If this secret is being edited, show the "cancel edit" and "apply + // edit" buttons. + <> + e.preventDefault()} + > + + + e.preventDefault()} + confirm={false} + > + + + + ) : ( + // Otherwise, show the "edit secret" and "delete secret" buttons. + <> + + + + + + + + )} + +
+  );
+}
+
+export type NewSecretInputProps = {
+  alert: StackingAlert;
+  onCancel: () => void;
+  onSuccess: () => void;
+  onError: (emsg: string) => void;
+};
+
+export function NewSecretInput(props: NewSecretInputProps) {
+  const [secretName, setSecretName] = useState('');
+  const [secretValue, setSecretValue] = useState('');
+  const [secretValueVisible, setSecretValueVisible] = useState(false);
+  const [nameError, setNameError] = useState(false);
+  const [valueError, setValueError] = useState(false);
+  const nameInputRef = useRef<HTMLInputElement>();
+  const valueInputRef = useRef<HTMLInputElement>();
+
+  useEffect(() => {
+    nameInputRef.current?.focus();
+  }, []);
+
+  const toggleSecretValueVisibility = useCallback(() => {
+    setSecretValueVisible(visible => !visible);
+  }, []);
+
+  const validateInputs = useCallback(() => {
+    const nameEmpty = !secretName.trim();
+    const valueEmpty = !secretValue.trim();
+    setNameError(nameEmpty);
+    setValueError(valueEmpty);
+    return !nameEmpty && !valueEmpty;
+  }, [secretName, secretValue]);
+
+  const onSubmit = useCallback(() => {
+    if (!validateInputs()) {
+      return 'canceled';
+    }
+
+    return AiService.updateSecrets({
+      [secretName.trim()]: secretValue.trim()
+    });
+  }, [secretName, secretValue, validateInputs]);
+
+  return (
+
+      setSecretName(e.target.value)}
+      inputRef={nameInputRef}
+      error={nameError}
+      helperText={'Secret name must not be empty'}
+      placeholder="Secret name"
+      FormHelperTextProps={{
+        sx: {
+          visibility: nameError ? 'unset' : 'hidden',
+          margin: 0,
+          whiteSpace: 'nowrap'
+        }
+      }}
+      size="small"
+      variant="standard"
+      label="Secret name"
+      onBlur={() => !secretName.trim() && setNameError(true)}
+      sx={{
+        flexGrow: 1,
+        margin: 0,
+        '& .MuiInputBase-input': {
+          padding: 0,
+          paddingBottom: 1
+        }
+      }}
+    />
+      setSecretValue(e.target.value)}
+      inputRef={valueInputRef}
+      error={valueError}
+      helperText={'Secret value must not be empty'}
+      placeholder="Secret value"
+      FormHelperTextProps={{
+        sx: {
+          visibility: valueError ? 'unset' : 'hidden',
+          margin: 0,
+          whiteSpace: 'nowrap'
+        }
+      }}
+      size="small"
+      variant="standard"
+      type={secretValueVisible ? 'text' : 'password'}
+      label="Secret value"
+      onBlur={() => !secretValue.trim() && setValueError(true)}
+      InputProps={{
+        endAdornment: (
+
+            e.preventDefault()}
+          >
+            {secretValueVisible ?  : }
+
+
+        )
+      }}
+      sx={{
+        flexGrow: 1,
+        margin: 0,
+        '& .MuiInputBase-input': {
+          padding: 0,
+          paddingBottom: 1
+        }
+      }}
+    />
+
+      e.preventDefault()}
+    >
+
+
+      e.preventDefault()}
+      confirm={false}
+    >
+
+
+
+  );
+}
diff --git a/packages/jupyter-ai/src/components/settings/secrets-list.tsx b/packages/jupyter-ai/src/components/settings/secrets-list.tsx
new file mode 100644
index 000000000..de9a199f9
--- /dev/null
+++ b/packages/jupyter-ai/src/components/settings/secrets-list.tsx
@@ -0,0 +1,69 @@
+import React from 'react';
+import {
+  Box,
+  List,
+  ListItem,
+  ListItemIcon,
+  ListItemText,
+  Typography
+} from '@mui/material';
+import LockIcon from '@mui/icons-material/Lock';
+
+export type SecretsListProps = {
+  secrets: string[];
+};
+
+/**
+ * Component that renders a list of secrets. This should be used to render the
+ * "static secrets" set by the traitlets configuration / environment variables
+ * passed directly to the `jupyter-lab` process.
+ *
+ * Editable secrets should be rendered using the `<SecretsInput>` component.
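+ *
+ * (Reviewer addition, not part of the original diff.) A hypothetical usage
+ * example:
+ *
+ *   <SecretsList secrets={['OPENAI_API_KEY', 'ANTHROPIC_API_KEY']} />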
+ */
+export function SecretsList(props: SecretsListProps): JSX.Element | null {
+  if (!props.secrets || props.secrets.length === 0) {
+    return null;
+  }
+
+  return (
+
+      {props.secrets.map((secret, index) => (
+
+
+
+
+
+              {secret}
+
+
+          }
+        />
+
+      ))}
+
+  );
+}
diff --git a/packages/jupyter-ai/src/components/settings/secrets-section.tsx b/packages/jupyter-ai/src/components/settings/secrets-section.tsx
new file mode 100644
index 000000000..db527d19d
--- /dev/null
+++ b/packages/jupyter-ai/src/components/settings/secrets-section.tsx
@@ -0,0 +1,119 @@
+import React, { useEffect, useState } from 'react';
+import { Alert, Box, CircularProgress, Link } from '@mui/material';
+
+import { AiService } from '../../handler';
+import { useStackingAlert } from '../mui-extras/stacking-alert';
+import { SecretsInput } from './secrets-input';
+import { SecretsList } from './secrets-list';
+
+/**
+ * Renders the "Secrets" section in the Jupyter AI settings.
+ *
+ * - Editable secrets (stored in `.env` by default) are rendered by the
+ *   `<SecretsInput>` component.
+ *
+ * - Static secrets are rendered by the `<SecretsList>` component.
+ */
+export function SecretsSection(): JSX.Element {
+  const [editableSecrets, setEditableSecrets] = useState<string[]>([]);
+  const [staticSecrets, setStaticSecrets] = useState<string[]>([]);
+  const [loading, setLoading] = useState(true);
+  const [error, setError] = useState(false);
+  const errorAlert = useStackingAlert();
+
+  /**
+   * Function that loads secrets from the Secrets REST API, setting the
+   * `loading` state accordingly.
+   */
+  const loadSecrets = async () => {
+    try {
+      setLoading(true);
+      const secrets = await AiService.listSecrets();
+      setEditableSecrets(secrets.editable_secrets);
+      setStaticSecrets(secrets.static_secrets);
+      setError(false);
+    } catch (error) {
+      setError(true);
+      // Coerce the unknown error into a string; passing the raw error object
+      // through would not render as a valid React child.
+      errorAlert.show(
+        'error',
+        error instanceof Error ? error.message : String(error)
+      );
+    } finally {
+      setLoading(false);
+    }
+  };
+
+  /**
+   * Function that is like `loadSecrets`, but does not affect the `loading`
+   * state. This prevents the child components from being remounted.
+   */
+  const reloadSecrets = async () => {
+    try {
+      const secrets = await AiService.listSecrets();
+      setEditableSecrets(secrets.editable_secrets);
+      setStaticSecrets(secrets.static_secrets);
+      setError(false);
+    } catch (error) {
+      setError(true);
+      errorAlert.show(
+        'error',
+        error instanceof Error ? error.message : String(error)
+      );
+    }
+  };
+
+  /**
+   * Effect: Fetch the secrets via the Secrets REST API on initial render.
+   */
+  useEffect(() => {
+    loadSecrets();
+  }, []);
+
+  if (loading) {
+    return (
+
+
+    );
+  }
+
+  if (error) {
+    return {errorAlert.jsx};
+  }
+
+  return (
+
+      {/* Editable secrets subsection */}
+

+ This section shows the secrets set in the .env file at the + workspace root. For most chat models, an API key secret in{' '} + .env is required for Jupyternaut to reply in the chat. See + the{' '} + + documentation + {' '} + for information on which API key is required for your model provider. +

+

+ Click "Add secret" to add a secret to the .env file. + Secrets can also be updated by editing the .env file + directly in JupyterLab. +

+ + + {/* Static secrets subsection */} + {staticSecrets.length ? ( + +

+            The secrets below are set by environment variables and the
+            traitlets configuration passed to the server process. These
+            secrets can only be changed by restarting the server with a new
+            configuration or by contacting your server administrator.

+ +
+ ) : null} +
+ ); +} diff --git a/packages/jupyter-ai/src/components/settings/use-server-info.ts b/packages/jupyter-ai/src/components/settings/use-server-info.ts deleted file mode 100644 index e7ba03bdb..000000000 --- a/packages/jupyter-ai/src/components/settings/use-server-info.ts +++ /dev/null @@ -1,158 +0,0 @@ -import { useState, useEffect, useMemo, useCallback } from 'react'; -import { AiService } from '../../handler'; -import { getProviderId, getModelLocalId } from '../../utils'; - -type ProvidersInfo = { - lmProvider: AiService.ListProvidersEntry | null; - emProvider: AiService.ListProvidersEntry | null; - lmLocalId: string; - emLocalId: string; -}; - -type ServerInfoProperties = { - lmProviders: AiService.ListProvidersResponse; - emProviders: AiService.ListProvidersResponse; - config: AiService.DescribeConfigResponse; - chat: ProvidersInfo; - completions: Omit; -}; - -type ServerInfoMethods = { - refetchAll: () => Promise; - refetchApiKeys: () => Promise; -}; - -export enum ServerInfoState { - /** - * Server info is being fetched. - */ - Loading, - /** - * Unable to retrieve server info. - */ - Error, - /** - * Server info was loaded successfully. - */ - Ready -} - -type ServerInfoLoading = { state: ServerInfoState.Loading }; -type ServerInfoError = { - state: ServerInfoState.Error; - error: string; -}; -type ServerInfoReady = { state: ServerInfoState.Ready } & ServerInfoProperties & - ServerInfoMethods; - -type ServerInfo = ServerInfoLoading | ServerInfoError | ServerInfoReady; - -/** - * A hook that fetches the current configuration and provider lists from the - * server. Returns a `ServerInfo` object that includes methods. - */ -export function useServerInfo(): ServerInfo { - const [state, setState] = useState(ServerInfoState.Loading); - const [serverInfoProps, setServerInfoProps] = - useState(); - const [error, setError] = useState(''); - - const fetchServerInfo = useCallback(async () => { - try { - const [config, lmProviders, emProviders] = await Promise.all([ - AiService.getConfig(), - AiService.listLmProviders(), - AiService.listEmProviders() - ]); - const lmGid = config.model_provider_id; - const emGid = config.embeddings_provider_id; - const lmProvider = - lmGid === null ? null : getProvider(lmGid, lmProviders); - const emProvider = - emGid === null ? null : getProvider(emGid, emProviders); - const lmLocalId = (lmGid && getModelLocalId(lmGid)) ?? ''; - const emLocalId = (emGid && getModelLocalId(emGid)) ?? ''; - - const clmGid = config.completions_model_provider_id; - const clmProvider = - clmGid === null ? null : getProvider(clmGid, lmProviders); - const clmLocalId = (clmGid && getModelLocalId(clmGid)) ?? ''; - - setServerInfoProps({ - config, - lmProviders, - emProviders, - chat: { - lmProvider, - emProvider, - lmLocalId, - emLocalId - }, - completions: { - lmProvider: clmProvider, - lmLocalId: clmLocalId - } - }); - - setState(ServerInfoState.Ready); - } catch (e) { - console.error(e); - if (e instanceof Error) { - setError(e.toString()); - } else { - setError('An unknown error occurred.'); - } - setState(ServerInfoState.Error); - } - }, []); - - const refetchApiKeys = useCallback(async () => { - if (!serverInfoProps) { - // this should never happen. 
-      return;
-    }
-
-    const config = await AiService.getConfig();
-    setServerInfoProps({
-      ...serverInfoProps,
-      config: {
-        ...serverInfoProps.config,
-        api_keys: config.api_keys,
-        last_read: config.last_read
-      }
-    });
-  }, [serverInfoProps]);
-
-  /**
-   * Effect: fetch server info on initial render
-   */
-  useEffect(() => {
-    fetchServerInfo();
-  }, []);
-
-  return useMemo(() => {
-    if (state === ServerInfoState.Loading) {
-      return { state };
-    }
-
-    if (state === ServerInfoState.Error || !serverInfoProps) {
-      return { state: ServerInfoState.Error, error };
-    }
-
-    return {
-      state,
-      ...serverInfoProps,
-      refetchAll: fetchServerInfo,
-      refetchApiKeys
-    };
-  }, [state, serverInfoProps, error, refetchApiKeys]);
-}
-
-function getProvider(
-  gid: string,
-  providers: AiService.ListProvidersResponse
-): AiService.ListProvidersEntry | null {
-  const providerId = getProviderId(gid);
-  const provider = providers.providers.find(p => p.id === providerId);
-  return provider ?? null;
-}
diff --git a/packages/jupyter-ai/src/handler.ts b/packages/jupyter-ai/src/handler.ts
index 9ea68c9bc..87dc95736 100644
--- a/packages/jupyter-ai/src/handler.ts
+++ b/packages/jupyter-ai/src/handler.ts
@@ -138,6 +138,10 @@ export namespace AiService {
     providers: ListProvidersEntry[];
   };
 
+  export type ListChatModelsResponse = {
+    chat_models: string[];
+  };
+
   export async function listLmProviders(): Promise<ListProvidersResponse> {
     return requestAPI<ListProvidersResponse>('providers');
   }
@@ -155,9 +159,64 @@ export namespace AiService {
     });
   }
 
-  export async function deleteApiKey(keyName: string): Promise<void> {
-    return requestAPI(`api_keys/${keyName}`, {
-      method: 'DELETE'
+  export type SecretsList = {
+    editable_secrets: string[];
+    static_secrets: string[];
+  };
+
+  export async function listSecrets(): Promise<SecretsList> {
+    return requestAPI<SecretsList>('secrets/', {
+      method: 'GET'
+    });
+  }
+
+  export type UpdateSecretsRequest = {
+    updated_secrets: Record<string, string | null>;
+  };
+
+  export async function updateSecrets(
+    updatedSecrets: Record<string, string | null>
+  ): Promise<void> {
+    return requestAPI<void>('secrets/', {
+      method: 'PUT',
+      body: JSON.stringify({
+        updated_secrets: updatedSecrets
+      })
+    });
+  }
+
+  export async function deleteSecret(secretName: string): Promise<void> {
+    return updateSecrets({ [secretName]: null });
+  }
+
+  export async function listChatModels(): Promise<string[]> {
+    const response = await requestAPI<ListChatModelsResponse>('models/chat/', {
+      method: 'GET'
+    });
+    return response.chat_models;
+  }
+
+  export async function getChatModel(): Promise<string | null> {
+    const response = await requestAPI<DescribeConfigResponse>('config/');
+    return response.model_provider_id;
+  }
+
+  export async function updateChatModel(
+    modelId: string | null
+  ): Promise<void> {
+    return await updateConfig({
+      model_provider_id: modelId
+    });
+  }
+
+  export async function getCompletionModel(): Promise<string | null> {
+    const response = await requestAPI<DescribeConfigResponse>('config/');
+    return response.completions_model_provider_id;
+  }
+
+  export async function updateCompletionModel(
+    modelId: string | null
+  ): Promise<void> {
+    return await updateConfig({
+      completions_model_provider_id: modelId
     });
   }
 }
diff --git a/packages/jupyter-ai/src/widgets/settings-widget.tsx b/packages/jupyter-ai/src/widgets/settings-widget.tsx
index fa3a820f6..f8821efd3 100644
--- a/packages/jupyter-ai/src/widgets/settings-widget.tsx
+++ b/packages/jupyter-ai/src/widgets/settings-widget.tsx
@@ -19,7 +19,6 @@ export function buildAiSettings(
       rmRegistry={rmRegistry}
       completionProvider={completionProvider}
       openInlineCompleterSettings={openInlineCompleterSettings}
-      inputOptions={false}
     />
   );
diff --git a/packages/jupyter-ai/style/chat-settings.css
b/packages/jupyter-ai/style/chat-settings.css index 9cb96539c..ae9cbf366 100644 --- a/packages/jupyter-ai/style/chat-settings.css +++ b/packages/jupyter-ai/style/chat-settings.css @@ -1,9 +1,36 @@ -.jp-ai-ChatSettings-header { +/* + * + * Selectors must be nested in `.jp-ThemedContainer` to have a higher + * specificity than selectors in rules provided by JupyterLab. + * + * See: https://jupyterlab.readthedocs.io/en/latest/extension/extension_migration.html#css-styling + * See also: https://github.com/jupyterlab/jupyter-ai/issues/1090 + */ + +.jp-ThemedContainer .jp-ai-ChatSettings { + padding: 1.5rem; + box-sizing: border-box; + height: 100%; + overflow: scroll; +} + +.jp-ThemedContainer .jp-ai-ChatSettings a { + color: var(--jp-content-link-color); + text-decoration: underline; +} + +.jp-ThemedContainer .jp-ai-ChatSettings-header { font-size: var(--jp-ui-font-size3); font-weight: 400; color: var(--jp-ui-font-color1); } -.jp-ai-ChatSettings-welcome { +.jp-ThemedContainer .jp-ai-ChatSettings-h3 { + font-size: var(--jp-ui-font-size2); + font-weight: 400; + color: var(--jp-ui-font-color1); +} + +.jp-ThemedContainer .jp-ai-ChatSettings-welcome { color: var(--jp-ui-font-color1); }
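
For reference, the sketch below shows how the new `AiService` methods introduced in `packages/jupyter-ai/src/handler.ts` compose. This is a reviewer-added illustration, not part of the diff: it assumes the `requestAPI` helper and endpoints behave exactly as defined above, and the import path, secret name/value, and model ID are placeholders.

```typescript
// Reviewer sketch (not part of this PR). Names not defined in handler.ts —
// the import path, the example secret, and the model ID — are placeholders.
import { AiService } from './handler';

async function configureJupyternaut(): Promise<void> {
  // Store (or overwrite) an editable secret. Under the hood this issues
  // `PUT secrets/` with `{"updated_secrets": {"OPENAI_API_KEY": "..."}}`.
  await AiService.updateSecrets({ OPENAI_API_KEY: 'example-key' });

  // Editable secrets live in `.env`; static secrets come from environment
  // variables / traitlets configuration passed to the server process.
  const { editable_secrets, static_secrets } = await AiService.listSecrets();
  console.log('editable:', editable_secrets, 'static:', static_secrets);

  // The suggestions offered by <ModelIdInput> come from this endpoint.
  const chatModels = await AiService.listChatModels();
  console.log('available chat models:', chatModels);

  // Read the current chat model, then update it; `null` clears the selection.
  const current = await AiService.getChatModel();
  console.log('current chat model:', current);
  await AiService.updateChatModel('anthropic:claude-v1.2');

  // Deleting a secret is `updateSecrets` with a `null` value under the hood.
  await AiService.deleteSecret('OPENAI_API_KEY');
}

void configureJupyternaut();
```

In the settings UI these endpoints are reached through `<ModelIdInput>` and `<SecretsSection>` rather than called directly.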