diff --git a/.gitignore b/.gitignore
index a270947..8f9f5d2 100644
--- a/.gitignore
+++ b/.gitignore
@@ -123,3 +123,6 @@ dmypy.json
# Yarn cache
.yarn/
+
+# For local testing
+playground/
\ No newline at end of file
diff --git a/jupyter_ai_jupyternaut/extension_app.py b/jupyter_ai_jupyternaut/extension_app.py
index 07952d1..d7b8f7a 100644
--- a/jupyter_ai_jupyternaut/extension_app.py
+++ b/jupyter_ai_jupyternaut/extension_app.py
@@ -2,22 +2,30 @@
from asyncio import get_event_loop_policy
from jupyter_server.extension.application import ExtensionApp
from jupyter_server.serverapp import ServerApp
+import os
+from tornado.web import StaticFileHandler
from traitlets import List, Unicode, Dict
from traitlets.config import Config
from typing import TYPE_CHECKING
from .config import ConfigManager, ConfigRestAPI
from .handlers import RouteHandler
+from .jupyternaut import JupyternautPersona
from .models import ChatModelsRestAPI, ModelParametersRestAPI
from .secrets import EnvSecretsManager, SecretsRestAPI
if TYPE_CHECKING:
from asyncio import AbstractEventLoop
+JUPYTERNAUT_AVATAR_PATH = str(
+ os.path.join(os.path.dirname(__file__), "static", "jupyternaut.svg")
+)
+
+
class JupyternautExtension(ExtensionApp):
"""
The Jupyternaut server extension.
-
+
This serves several REST APIs under the `/api/jupyternaut` route. Currently,
for the sake of simplicity, they are hard-coded into the Jupyternaut server
extension to allow users to configure the chat model & add API keys.
@@ -33,6 +41,11 @@ class JupyternautExtension(ExtensionApp):
(r"api/jupyternaut/models/chat/?", ChatModelsRestAPI),
(r"api/jupyternaut/model-parameters/?", ModelParametersRestAPI),
(r"api/jupyternaut/secrets/?", SecretsRestAPI),
+ (
+ r"api/jupyternaut/static/jupyternaut.svg()/?",
+ StaticFileHandler,
+ {"path": JUPYTERNAUT_AVATAR_PATH},
+ ),
]
allowed_providers = List(
@@ -176,7 +189,7 @@ def initialize_settings(self):
}
# Initialize ConfigManager
- self.settings["jupyternaut.config_manager"] = ConfigManager(
+ config_manager = ConfigManager(
config=self.config,
log=self.log,
allowed_providers=self.allowed_providers,
@@ -186,9 +199,16 @@ def initialize_settings(self):
defaults=defaults,
)
- # Initialize SecretsManager
+ # Bind ConfigManager instance to global settings dictionary
+ self.settings["jupyternaut.config_manager"] = config_manager
+
+ # Bind ConfigManager instance to Jupyternaut as a class variable
+ JupyternautPersona.config_manager = config_manager
+
+ # Initialize SecretsManager and bind it to global settings dictionary
self.settings["jupyternaut.secrets_manager"] = EnvSecretsManager(parent=self)
-
+
+
def _link_jupyter_server_extension(self, server_app: ServerApp):
"""Setup custom config needed by this extension."""
c = Config()
@@ -210,4 +230,3 @@ def _link_jupyter_server_extension(self, server_app: ServerApp):
]
server_app.update_config(c)
super()._link_jupyter_server_extension(server_app)
-
diff --git a/jupyter_ai_jupyternaut/jupyternaut/__init__.py b/jupyter_ai_jupyternaut/jupyternaut/__init__.py
new file mode 100644
index 0000000..1e0f40f
--- /dev/null
+++ b/jupyter_ai_jupyternaut/jupyternaut/__init__.py
@@ -0,0 +1 @@
+from .jupyternaut import JupyternautPersona
diff --git a/jupyter_ai_jupyternaut/jupyternaut/jupyternaut.py b/jupyter_ai_jupyternaut/jupyternaut/jupyternaut.py
new file mode 100644
index 0000000..7a49a89
--- /dev/null
+++ b/jupyter_ai_jupyternaut/jupyternaut/jupyternaut.py
@@ -0,0 +1,107 @@
+from typing import Any, Optional
+
+from jupyterlab_chat.models import Message
+from litellm import acompletion
+
+from jupyter_ai_persona_manager import BasePersona, PersonaDefaults
+from jupyter_ai_persona_manager.persona_manager import SYSTEM_USERNAME
+from .prompt_template import (
+ JUPYTERNAUT_SYSTEM_PROMPT_TEMPLATE,
+ JupyternautSystemPromptArgs,
+)
+
+
+class JupyternautPersona(BasePersona):
+ """
+ The Jupyternaut persona, the main persona provided by Jupyter AI.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ @property
+ def defaults(self):
+ return PersonaDefaults(
+ name="Jupyternaut",
+ avatar_path="/api/jupyternaut/static/jupyternaut.svg",
+ description="The standard agent provided by JupyterLab. Currently has no tools.",
+ system_prompt="...",
+ )
+
+ async def process_message(self, message: Message) -> None:
+ if not hasattr(self, 'config_manager'):
+ self.send_message(
+ "Jupyternaut requires the `jupyter_ai_jupyternaut` server extension package.\n\n",
+ "Please make sure to first install that package in your environment & restart the server."
+ )
+ if not self.config_manager.chat_model:
+ self.send_message(
+ "No chat model is configured.\n\n"
+ "You must set one first in the Jupyter AI settings, found in 'Settings > AI Settings' from the menu bar."
+ )
+ return
+
+ model_id = self.config_manager.chat_model
+ model_args = self.config_manager.chat_model_args
+ context_as_messages = self.get_context_as_messages(model_id, message)
+ response_aiter = await acompletion(
+ **model_args,
+ model=model_id,
+ messages=[
+ *context_as_messages,
+ {
+ "role": "user",
+ "content": message.body,
+ },
+ ],
+ stream=True,
+ )
+
+ await self.stream_message(response_aiter)
+
+ def get_context_as_messages(
+ self, model_id: str, message: Message
+ ) -> list[dict[str, Any]]:
+ """
+ Returns the current context, including attachments and recent messages,
+ as a list of messages accepted by `litellm.acompletion()`.
+ """
+ system_msg_args = JupyternautSystemPromptArgs(
+ model_id=model_id,
+ persona_name=self.name,
+ context=self.process_attachments(message),
+ ).model_dump()
+
+ system_msg = {
+ "role": "system",
+ "content": JUPYTERNAUT_SYSTEM_PROMPT_TEMPLATE.render(**system_msg_args),
+ }
+
+ context_as_messages = [system_msg, *self._get_history_as_messages()]
+ return context_as_messages
+
+ def _get_history_as_messages(self, k: Optional[int] = 2) -> list[dict[str, Any]]:
+ """
+ Returns the current history as a list of messages accepted by
+ `litellm.acompletion()`.
+ """
+ # TODO: consider bounding history based on message size (e.g. total
+ # char/token count) instead of message count.
+ all_messages = self.ychat.get_messages()
+
+ # gather last k * 2 messages and return
+ # we exclude the last message since that is the human message just
+ # submitted by a user.
+ start_idx = 0 if k is None else -2 * k - 1
+ recent_messages: list[Message] = all_messages[start_idx:-1]
+
+ history: list[dict[str, Any]] = []
+ for msg in recent_messages:
+ role = (
+ "assistant"
+ if msg.sender.startswith("jupyter-ai-personas::")
+ else "system" if msg.sender == SYSTEM_USERNAME else "user"
+ )
+ history.append({"role": role, "content": msg.body})
+
+ return history
diff --git a/jupyter_ai_jupyternaut/jupyternaut/prompt_template.py b/jupyter_ai_jupyternaut/jupyternaut/prompt_template.py
new file mode 100644
index 0000000..05cb7b9
--- /dev/null
+++ b/jupyter_ai_jupyternaut/jupyternaut/prompt_template.py
@@ -0,0 +1,55 @@
+from typing import Optional
+
+from jinja2 import Template
+from pydantic import BaseModel
+
+_JUPYTERNAUT_SYSTEM_PROMPT_FORMAT = """
+
+
+You are {{persona_name}}, an AI agent provided in JupyterLab through the 'Jupyter AI' extension.
+
+Jupyter AI is an installable software package listed on PyPI and Conda Forge as `jupyter-ai`.
+
+When installed, Jupyter AI adds a chat experience in JupyterLab that allows multiple users to collaborate with one or more agents like yourself.
+
+You are not a language model, but rather an AI agent powered by a foundation model `{{model_id}}`.
+
+You are receiving a request from a user in JupyterLab. Your goal is to fulfill this request to the best of your ability.
+
+If you do not know the answer to a question, answer truthfully by responding that you do not know.
+
+You should use Markdown to format your response.
+
+Any code in your response must be enclosed in Markdown fenced code blocks (with triple backticks before and after).
+
+Any mathematical notation in your response must be expressed in LaTeX markup and enclosed in LaTeX delimiters.
+
+- Example of a correct response: The area of a circle is \\(\\pi * r^2\\).
+
+All dollar quantities (of USD) must be formatted in LaTeX, with the `$` symbol escaped by a single backslash `\\`.
+
+- Example of a correct response: `You have \\(\\$80\\) remaining.`
+
+You will receive any provided context and a relevant portion of the chat history.
+
+The user's request is located at the last message. Please fulfill the user's request to the best of your ability.
+
+
+
+{% if context %}The user has shared the following context:
+
+{{context}}
+{% else %}The user did not share any additional context.{% endif %}
+
+""".strip()
+
+
+JUPYTERNAUT_SYSTEM_PROMPT_TEMPLATE: Template = Template(
+ _JUPYTERNAUT_SYSTEM_PROMPT_FORMAT
+)
+
+
+class JupyternautSystemPromptArgs(BaseModel):
+ persona_name: str
+ model_id: str
+ context: Optional[str] = None
diff --git a/jupyter_ai_jupyternaut/static/jupyternaut.svg b/jupyter_ai_jupyternaut/static/jupyternaut.svg
new file mode 100644
index 0000000..dd800d5
--- /dev/null
+++ b/jupyter_ai_jupyternaut/static/jupyternaut.svg
@@ -0,0 +1,9 @@
+
diff --git a/pyproject.toml b/pyproject.toml
index fa66fcd..c4898d2 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -41,6 +41,7 @@ dependencies = [
"litellm>=1.73,<2",
"jinja2>=3.0,<4",
"python_dotenv>=1,<2",
+ "jupyter_ai_persona_manager>=0.0.1",
]
dynamic = ["version", "description", "authors", "urls", "keywords"]
@@ -103,3 +104,13 @@ before-build-python = ["jlpm clean:all"]
[tool.check-wheel-contents]
ignore = ["W002"]
+
+###############################################################################
+# Provide Jupyternaut on the personas entry point to `jupyter_ai_persona_manager`.
+# This adds Jupyternaut to JupyterLab.
+# See: https://jupyter-ai.readthedocs.io/en/v3/developers/entry_points_api/personas_group.html
+# See also: https://packaging.python.org/en/latest/specifications/entry-points/
+
+[project.entry-points."jupyter_ai.personas"]
+jupyternaut = "jupyter_ai_jupyternaut.jupyternaut.jupyternaut:JupyternautPersona"
+###############################################################################
diff --git a/src/index.ts b/src/index.ts
index 66fbe80..0732f77 100644
--- a/src/index.ts
+++ b/src/index.ts
@@ -16,7 +16,7 @@ import { IRenderMimeRegistry } from '@jupyterlab/rendermime';
import { SingletonLayout, Widget } from '@lumino/widgets';
import { StopButton } from './components/message-footer/stop-button';
-import { completionPlugin } from './completions';
+//import { completionPlugin } from './completions';
import { buildErrorWidget } from './widgets/chat-error';
import { buildAiSettings } from './widgets/settings-widget';
import { statusItemPlugin } from './status';
@@ -145,6 +145,6 @@ export default [
jupyternautSettingsPlugin,
// webComponentsPlugin,
stopButtonPlugin,
- completionPlugin,
+ // completionPlugin,
statusItemPlugin
];