diff --git a/.gitignore b/.gitignore index a270947..8f9f5d2 100644 --- a/.gitignore +++ b/.gitignore @@ -123,3 +123,6 @@ dmypy.json # Yarn cache .yarn/ + +# For local testing +playground/ \ No newline at end of file diff --git a/jupyter_ai_jupyternaut/extension_app.py b/jupyter_ai_jupyternaut/extension_app.py index 07952d1..d7b8f7a 100644 --- a/jupyter_ai_jupyternaut/extension_app.py +++ b/jupyter_ai_jupyternaut/extension_app.py @@ -2,22 +2,30 @@ from asyncio import get_event_loop_policy from jupyter_server.extension.application import ExtensionApp from jupyter_server.serverapp import ServerApp +import os +from tornado.web import StaticFileHandler from traitlets import List, Unicode, Dict from traitlets.config import Config from typing import TYPE_CHECKING from .config import ConfigManager, ConfigRestAPI from .handlers import RouteHandler +from .jupyternaut import JupyternautPersona from .models import ChatModelsRestAPI, ModelParametersRestAPI from .secrets import EnvSecretsManager, SecretsRestAPI if TYPE_CHECKING: from asyncio import AbstractEventLoop +JUPYTERNAUT_AVATAR_PATH = str( + os.path.join(os.path.dirname(__file__), "static", "jupyternaut.svg") +) + + class JupyternautExtension(ExtensionApp): """ The Jupyternaut server extension. - + This serves several REST APIs under the `/api/jupyternaut` route. Currently, for the sake of simplicity, they are hard-coded into the Jupyternaut server extension to allow users to configure the chat model & add API keys. @@ -33,6 +41,11 @@ class JupyternautExtension(ExtensionApp): (r"api/jupyternaut/models/chat/?", ChatModelsRestAPI), (r"api/jupyternaut/model-parameters/?", ModelParametersRestAPI), (r"api/jupyternaut/secrets/?", SecretsRestAPI), + ( + r"api/jupyternaut/static/jupyternaut.svg()/?", + StaticFileHandler, + {"path": JUPYTERNAUT_AVATAR_PATH}, + ), ] allowed_providers = List( @@ -176,7 +189,7 @@ def initialize_settings(self): } # Initialize ConfigManager - self.settings["jupyternaut.config_manager"] = ConfigManager( + config_manager = ConfigManager( config=self.config, log=self.log, allowed_providers=self.allowed_providers, @@ -186,9 +199,16 @@ def initialize_settings(self): defaults=defaults, ) - # Initialize SecretsManager + # Bind ConfigManager instance to global settings dictionary + self.settings["jupyternaut.config_manager"] = config_manager + + # Bind ConfigManager instance to Jupyternaut as a class variable + JupyternautPersona.config_manager = config_manager + + # Initialize SecretsManager and bind it to global settings dictionary self.settings["jupyternaut.secrets_manager"] = EnvSecretsManager(parent=self) - + + def _link_jupyter_server_extension(self, server_app: ServerApp): """Setup custom config needed by this extension.""" c = Config() @@ -210,4 +230,3 @@ def _link_jupyter_server_extension(self, server_app: ServerApp): ] server_app.update_config(c) super()._link_jupyter_server_extension(server_app) - diff --git a/jupyter_ai_jupyternaut/jupyternaut/__init__.py b/jupyter_ai_jupyternaut/jupyternaut/__init__.py new file mode 100644 index 0000000..1e0f40f --- /dev/null +++ b/jupyter_ai_jupyternaut/jupyternaut/__init__.py @@ -0,0 +1 @@ +from .jupyternaut import JupyternautPersona diff --git a/jupyter_ai_jupyternaut/jupyternaut/jupyternaut.py b/jupyter_ai_jupyternaut/jupyternaut/jupyternaut.py new file mode 100644 index 0000000..7a49a89 --- /dev/null +++ b/jupyter_ai_jupyternaut/jupyternaut/jupyternaut.py @@ -0,0 +1,107 @@ +from typing import Any, Optional + +from jupyterlab_chat.models import Message +from litellm import acompletion + +from jupyter_ai_persona_manager import BasePersona, PersonaDefaults +from jupyter_ai_persona_manager.persona_manager import SYSTEM_USERNAME +from .prompt_template import ( + JUPYTERNAUT_SYSTEM_PROMPT_TEMPLATE, + JupyternautSystemPromptArgs, +) + + +class JupyternautPersona(BasePersona): + """ + The Jupyternaut persona, the main persona provided by Jupyter AI. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + @property + def defaults(self): + return PersonaDefaults( + name="Jupyternaut", + avatar_path="/api/jupyternaut/static/jupyternaut.svg", + description="The standard agent provided by JupyterLab. Currently has no tools.", + system_prompt="...", + ) + + async def process_message(self, message: Message) -> None: + if not hasattr(self, 'config_manager'): + self.send_message( + "Jupyternaut requires the `jupyter_ai_jupyternaut` server extension package.\n\n", + "Please make sure to first install that package in your environment & restart the server." + ) + if not self.config_manager.chat_model: + self.send_message( + "No chat model is configured.\n\n" + "You must set one first in the Jupyter AI settings, found in 'Settings > AI Settings' from the menu bar." + ) + return + + model_id = self.config_manager.chat_model + model_args = self.config_manager.chat_model_args + context_as_messages = self.get_context_as_messages(model_id, message) + response_aiter = await acompletion( + **model_args, + model=model_id, + messages=[ + *context_as_messages, + { + "role": "user", + "content": message.body, + }, + ], + stream=True, + ) + + await self.stream_message(response_aiter) + + def get_context_as_messages( + self, model_id: str, message: Message + ) -> list[dict[str, Any]]: + """ + Returns the current context, including attachments and recent messages, + as a list of messages accepted by `litellm.acompletion()`. + """ + system_msg_args = JupyternautSystemPromptArgs( + model_id=model_id, + persona_name=self.name, + context=self.process_attachments(message), + ).model_dump() + + system_msg = { + "role": "system", + "content": JUPYTERNAUT_SYSTEM_PROMPT_TEMPLATE.render(**system_msg_args), + } + + context_as_messages = [system_msg, *self._get_history_as_messages()] + return context_as_messages + + def _get_history_as_messages(self, k: Optional[int] = 2) -> list[dict[str, Any]]: + """ + Returns the current history as a list of messages accepted by + `litellm.acompletion()`. + """ + # TODO: consider bounding history based on message size (e.g. total + # char/token count) instead of message count. + all_messages = self.ychat.get_messages() + + # gather last k * 2 messages and return + # we exclude the last message since that is the human message just + # submitted by a user. + start_idx = 0 if k is None else -2 * k - 1 + recent_messages: list[Message] = all_messages[start_idx:-1] + + history: list[dict[str, Any]] = [] + for msg in recent_messages: + role = ( + "assistant" + if msg.sender.startswith("jupyter-ai-personas::") + else "system" if msg.sender == SYSTEM_USERNAME else "user" + ) + history.append({"role": role, "content": msg.body}) + + return history diff --git a/jupyter_ai_jupyternaut/jupyternaut/prompt_template.py b/jupyter_ai_jupyternaut/jupyternaut/prompt_template.py new file mode 100644 index 0000000..05cb7b9 --- /dev/null +++ b/jupyter_ai_jupyternaut/jupyternaut/prompt_template.py @@ -0,0 +1,55 @@ +from typing import Optional + +from jinja2 import Template +from pydantic import BaseModel + +_JUPYTERNAUT_SYSTEM_PROMPT_FORMAT = """ + + +You are {{persona_name}}, an AI agent provided in JupyterLab through the 'Jupyter AI' extension. + +Jupyter AI is an installable software package listed on PyPI and Conda Forge as `jupyter-ai`. + +When installed, Jupyter AI adds a chat experience in JupyterLab that allows multiple users to collaborate with one or more agents like yourself. + +You are not a language model, but rather an AI agent powered by a foundation model `{{model_id}}`. + +You are receiving a request from a user in JupyterLab. Your goal is to fulfill this request to the best of your ability. + +If you do not know the answer to a question, answer truthfully by responding that you do not know. + +You should use Markdown to format your response. + +Any code in your response must be enclosed in Markdown fenced code blocks (with triple backticks before and after). + +Any mathematical notation in your response must be expressed in LaTeX markup and enclosed in LaTeX delimiters. + +- Example of a correct response: The area of a circle is \\(\\pi * r^2\\). + +All dollar quantities (of USD) must be formatted in LaTeX, with the `$` symbol escaped by a single backslash `\\`. + +- Example of a correct response: `You have \\(\\$80\\) remaining.` + +You will receive any provided context and a relevant portion of the chat history. + +The user's request is located at the last message. Please fulfill the user's request to the best of your ability. + + + +{% if context %}The user has shared the following context: + +{{context}} +{% else %}The user did not share any additional context.{% endif %} + +""".strip() + + +JUPYTERNAUT_SYSTEM_PROMPT_TEMPLATE: Template = Template( + _JUPYTERNAUT_SYSTEM_PROMPT_FORMAT +) + + +class JupyternautSystemPromptArgs(BaseModel): + persona_name: str + model_id: str + context: Optional[str] = None diff --git a/jupyter_ai_jupyternaut/static/jupyternaut.svg b/jupyter_ai_jupyternaut/static/jupyternaut.svg new file mode 100644 index 0000000..dd800d5 --- /dev/null +++ b/jupyter_ai_jupyternaut/static/jupyternaut.svg @@ -0,0 +1,9 @@ + + + + + + diff --git a/pyproject.toml b/pyproject.toml index fa66fcd..c4898d2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,6 +41,7 @@ dependencies = [ "litellm>=1.73,<2", "jinja2>=3.0,<4", "python_dotenv>=1,<2", + "jupyter_ai_persona_manager>=0.0.1", ] dynamic = ["version", "description", "authors", "urls", "keywords"] @@ -103,3 +104,13 @@ before-build-python = ["jlpm clean:all"] [tool.check-wheel-contents] ignore = ["W002"] + +############################################################################### +# Provide Jupyternaut on the personas entry point to `jupyter_ai_persona_manager`. +# This adds Jupyternaut to JupyterLab. +# See: https://jupyter-ai.readthedocs.io/en/v3/developers/entry_points_api/personas_group.html +# See also: https://packaging.python.org/en/latest/specifications/entry-points/ + +[project.entry-points."jupyter_ai.personas"] +jupyternaut = "jupyter_ai_jupyternaut.jupyternaut.jupyternaut:JupyternautPersona" +############################################################################### diff --git a/src/index.ts b/src/index.ts index 66fbe80..0732f77 100644 --- a/src/index.ts +++ b/src/index.ts @@ -16,7 +16,7 @@ import { IRenderMimeRegistry } from '@jupyterlab/rendermime'; import { SingletonLayout, Widget } from '@lumino/widgets'; import { StopButton } from './components/message-footer/stop-button'; -import { completionPlugin } from './completions'; +//import { completionPlugin } from './completions'; import { buildErrorWidget } from './widgets/chat-error'; import { buildAiSettings } from './widgets/settings-widget'; import { statusItemPlugin } from './status'; @@ -145,6 +145,6 @@ export default [ jupyternautSettingsPlugin, // webComponentsPlugin, stopButtonPlugin, - completionPlugin, + // completionPlugin, statusItemPlugin ];