Skip to content

Commit 1311989

Browse files
committed
add jai-tool-call web component to show tool calls & outputs
1 parent 056676e commit 1311989

File tree

8 files changed

+278
-18
lines changed

8 files changed

+278
-18
lines changed

packages/jupyter-ai/jupyter_ai/personas/base_persona.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -518,7 +518,7 @@ def get_tools(self, model_id: str) -> list[dict]:
518518
return tool_descriptions
519519

520520

521-
async def run_tools(self, tool_call_list: ToolCallList) -> list[ToolCallOutput]:
521+
async def run_tools(self, tool_call_list: ToolCallList) -> list[dict]:
522522
"""
523523
Runs the tools specified in a given tool call list using the default
524524
toolkit.

packages/jupyter-ai/jupyter_ai/personas/jupyternaut/jupyternaut.py

Lines changed: 66 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from jupyterlab_chat.models import Message
66
from litellm import acompletion
77

8+
from ...litellm_utils import StreamResult, ToolCallOutput
89
from ..base_persona import BasePersona, PersonaDefaults
910
from ..persona_manager import SYSTEM_USERNAME
1011
from .prompt_template import (
@@ -69,30 +70,78 @@ async def process_message(self, message: Message) -> None:
6970
"tool_calls": tool_calls_json
7071
})
7172

72-
# Show tool call requests to YChat (not synced with `messages`)
73-
if len(tool_calls_json):
74-
self.ychat.update_message(Message(
75-
id=result.id,
76-
body=f"\n\n```\n{json.dumps(tool_calls_json, indent=2)}\n```\n",
77-
sender=self.id,
78-
time=time.time(),
79-
raw_time=False
80-
), append=True)
73+
# Render tool calls in new message
74+
if len(result.tool_call_list):
75+
self.render_tool_calls(result)
8176

8277
# Run tools and append outputs to `messages`
8378
tool_call_outputs = await self.run_tools(result.tool_call_list)
8479
messages.extend(tool_call_outputs)
8580

86-
# Add tool call outputs to YChat (not synced with `messages`)
81+
# Render tool call outputs in new message
8782
if tool_call_outputs:
88-
self.ychat.update_message(Message(
89-
id=result.id,
90-
body=f"\n\n```\n{json.dumps(tool_call_outputs, indent=2)}\n```\n",
91-
sender=self.id,
92-
time=time.time(),
93-
raw_time=False
94-
), append=True)
83+
self.render_tool_call_outputs(
84+
message_id=result.id,
85+
tool_call_outputs=tool_call_outputs
86+
)
9587

88+
def render_tool_calls(self, stream_result: StreamResult):
89+
"""
90+
Renders tool calls by appending the tool calls to a message.
91+
"""
92+
message_id = stream_result.id
93+
tool_call_list = stream_result.tool_call_list
94+
95+
for tool_call in tool_call_list.resolve():
96+
id = tool_call.id
97+
index = tool_call.index
98+
type_val = tool_call.type
99+
function = tool_call.function.model_dump_json()
100+
# We have to HTML-escape double quotes in the JSON string.
101+
function = function.replace('"', """)
102+
103+
self.ychat.update_message(Message(
104+
id=message_id,
105+
body=f'\n\n<jai-tool-call id="{id}" type="{type_val}" index={index} function="{function}"></jai-tool-call>\n',
106+
sender=self.id,
107+
time=time.time(),
108+
raw_time=False
109+
), append=True)
110+
111+
112+
def render_tool_call_outputs(self, message_id: str, tool_call_outputs: list[dict]):
113+
# TODO
114+
# self.ychat.update_message(Message(
115+
# id=message_id,
116+
# body=f"\n\n```\n{json.dumps(tool_call_outputs, indent=2)}\n```\n",
117+
# sender=self.id,
118+
# time=time.time(),
119+
# raw_time=False
120+
# ), append=True)
121+
122+
# Updates the content of the last message directly
123+
message = self.ychat.get_message(message_id)
124+
body = message.body
125+
for output in tool_call_outputs:
126+
if not output['content']:
127+
output['content'] = ""
128+
output = ToolCallOutput(**output)
129+
tool_id = output.tool_call_id
130+
tool_output = output.model_dump_json()
131+
tool_output = tool_output.replace('"', '&quot;')
132+
body = body.replace(
133+
f'<jai-tool-call id="{tool_id}"',
134+
f'<jai-tool-call id="{tool_id}" output="{tool_output}"',
135+
)
136+
137+
self.log.error(body)
138+
self.ychat.update_message(Message(
139+
id=message.id,
140+
time=time.time(),
141+
body=body,
142+
sender=self.id,
143+
raw_time=False
144+
))
96145

97146

98147
def get_context_as_messages(

packages/jupyter-ai/package.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@
7979
"@lumino/widgets": "^2.3.2",
8080
"@mui/icons-material": "^5.11.0",
8181
"@mui/material": "^5.11.0",
82+
"@r2wc/react-to-web-component": "^2.0.4",
8283
"react": "^18.2.0",
8384
"react-dom": "^18.2.0"
8485
},

packages/jupyter-ai/src/index.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import { completionPlugin } from './completions';
2020
import { StopButton } from './components/message-footer/stop-button';
2121
import { statusItemPlugin } from './status';
2222
import { IJaiCompletionProvider } from './tokens';
23+
import { webComponentsPlugin } from './web-components';
2324
import { buildErrorWidget } from './widgets/chat-error';
2425
import { buildAiSettings } from './widgets/settings-widget';
2526

@@ -125,6 +126,7 @@ export default [
125126
plugin,
126127
statusItemPlugin,
127128
completionPlugin,
129+
webComponentsPlugin,
128130
stopStreaming,
129131
...chatCommandPlugins
130132
];
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
export * from './plugin';
2+
export * from './jai-tool-call';
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
import React, { useState, useMemo } from 'react';
2+
import {
3+
Box,
4+
Typography,
5+
Collapse,
6+
IconButton,
7+
CircularProgress
8+
} from '@mui/material';
9+
import ExpandMore from '@mui/icons-material/ExpandMore';
10+
import CheckCircle from '@mui/icons-material/CheckCircle';
11+
12+
type JaiToolCallProps = {
13+
id: string;
14+
type: string;
15+
function: {
16+
name: string;
17+
arguments: Record<string, any>;
18+
};
19+
index: number;
20+
output?: {
21+
tool_call_id: string;
22+
role: string;
23+
name: string;
24+
content: string | null;
25+
};
26+
};
27+
28+
export function JaiToolCall(props: JaiToolCallProps): JSX.Element | null {
29+
const [expanded, setExpanded] = useState(false);
30+
console.log({
31+
output: props.output
32+
});
33+
const toolComplete = !!(props.output && Object.keys(props.output).length > 0);
34+
const hasOutput = !!(toolComplete && props.output?.content?.length);
35+
36+
const handleExpandClick = () => {
37+
setExpanded(!expanded);
38+
};
39+
40+
const statusIcon: JSX.Element = toolComplete ? (
41+
<CheckCircle sx={{ color: 'green', fontSize: 16 }} />
42+
) : (
43+
<CircularProgress size={16} />
44+
);
45+
46+
const statusText: JSX.Element = (
47+
<Typography variant="caption">
48+
{toolComplete ? 'Ran' : 'Running'}{' '}
49+
<Typography variant="caption" sx={{ fontWeight: 'bold' }}>
50+
{props.function.name}
51+
</Typography>{' '}
52+
tool
53+
{toolComplete ? '.' : '...'}
54+
</Typography>
55+
);
56+
57+
const toolArgsJson = useMemo(
58+
() => JSON.stringify(props.function.arguments, null, 2),
59+
[props.function.arguments]
60+
);
61+
62+
const toolArgsSection: JSX.Element | null =
63+
toolArgsJson === '{}' ? null : (
64+
<Box>
65+
<Typography variant="caption" sx={{ fontWeight: 'bold' }}>
66+
Tool arguments
67+
</Typography>
68+
<pre style={{ marginBottom: toolComplete ? 8 : 'unset' }}>
69+
{toolArgsJson}
70+
</pre>
71+
</Box>
72+
);
73+
74+
const toolOutputSection: JSX.Element | null = hasOutput ? (
75+
<Box>
76+
<Typography variant="caption" sx={{ fontWeight: 'bold' }}>
77+
Tool output
78+
</Typography>
79+
<pre>{props.output?.content}</pre>
80+
</Box>
81+
) : null;
82+
83+
if (!props.id || !props.type || !props.function) {
84+
return null;
85+
}
86+
87+
return (
88+
<Box
89+
sx={{
90+
border: '1px solid #e0e0e0',
91+
borderRadius: 1,
92+
p: 1,
93+
mb: 1
94+
}}
95+
>
96+
<Box sx={{ display: 'flex', alignItems: 'center', gap: 1 }}>
97+
{statusIcon}
98+
{statusText}
99+
100+
<IconButton
101+
onClick={handleExpandClick}
102+
size="small"
103+
sx={{
104+
transform: expanded ? 'rotate(180deg)' : 'rotate(0deg)',
105+
transition: 'transform 0.3s',
106+
borderRadius: 'unset'
107+
}}
108+
>
109+
<ExpandMore />
110+
</IconButton>
111+
</Box>
112+
113+
<Collapse in={expanded}>
114+
<Box sx={{ mt: 1, pt: 1, borderTop: '1px solid #f0f0f0' }}>
115+
{toolArgsSection}
116+
{toolOutputSection}
117+
</Box>
118+
</Collapse>
119+
</Box>
120+
);
121+
}
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
import {
2+
JupyterFrontEnd,
3+
JupyterFrontEndPlugin
4+
} from '@jupyterlab/application';
5+
import r2wc from '@r2wc/react-to-web-component';
6+
7+
import { JaiToolCall } from './jai-tool-call';
8+
import { ISanitizer, Sanitizer } from '@jupyterlab/apputils';
9+
import { IRenderMime } from '@jupyterlab/rendermime';
10+
11+
/**
12+
* Plugin that registers custom web components for usage in AI responses.
13+
*/
14+
export const webComponentsPlugin: JupyterFrontEndPlugin<IRenderMime.ISanitizer> =
15+
{
16+
id: '@jupyter-ai/core:web-components',
17+
autoStart: true,
18+
provides: ISanitizer,
19+
activate: (app: JupyterFrontEnd) => {
20+
// Define the JaiToolCall web component
21+
// ['id', 'type', 'function', 'index', 'output']
22+
const JaiToolCallWebComponent = r2wc(JaiToolCall, {
23+
props: {
24+
id: 'string',
25+
type: 'string',
26+
function: 'json',
27+
index: 'number',
28+
output: 'json'
29+
}
30+
});
31+
32+
// Register the web component
33+
customElements.define('jai-tool-call', JaiToolCallWebComponent);
34+
console.log("Registered custom 'jai-tool-call' web component.");
35+
36+
// Finally, override the default Rendermime sanitizer to allow custom web
37+
// components in the output.
38+
class CustomSanitizer
39+
extends Sanitizer
40+
implements IRenderMime.ISanitizer
41+
{
42+
sanitize(
43+
dirty: string,
44+
customOptions: IRenderMime.ISanitizerOptions
45+
): string {
46+
const options: IRenderMime.ISanitizerOptions = {
47+
// default sanitizer options
48+
...(this as any)._options,
49+
// custom sanitizer options (variable per call)
50+
...customOptions
51+
};
52+
53+
return super.sanitize(dirty, {
54+
...options,
55+
allowedTags: [...(options?.allowedTags ?? []), 'jai-tool-call'],
56+
allowedAttributes: {
57+
...options?.allowedAttributes,
58+
'jai-tool-call': ['id', 'type', 'function', 'index', 'output']
59+
}
60+
});
61+
}
62+
}
63+
return new CustomSanitizer();
64+
}
65+
};

yarn.lock

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2256,6 +2256,7 @@ __metadata:
22562256
"@lumino/widgets": ^2.3.2
22572257
"@mui/icons-material": ^5.11.0
22582258
"@mui/material": ^5.11.0
2259+
"@r2wc/react-to-web-component": ^2.0.4
22592260
"@stylistic/eslint-plugin": ^3.0.1
22602261
"@types/jest": ^29
22612262
"@types/react-dom": ^18.2.0
@@ -4826,6 +4827,25 @@ __metadata:
48264827
languageName: node
48274828
linkType: hard
48284829

4830+
"@r2wc/core@npm:^1.0.0":
4831+
version: 1.2.0
4832+
resolution: "@r2wc/core@npm:1.2.0"
4833+
checksum: e0dc23e8fd1f0d96193b67f5eb04b74b25b9f4609778e6ea2427c565eb590f458553cad307a2fdb3fc4614f6a576d7701b9bacf11775958bc560cc3b3b5aaae7
4834+
languageName: node
4835+
linkType: hard
4836+
4837+
"@r2wc/react-to-web-component@npm:^2.0.4":
4838+
version: 2.0.4
4839+
resolution: "@r2wc/react-to-web-component@npm:2.0.4"
4840+
dependencies:
4841+
"@r2wc/core": ^1.0.0
4842+
peerDependencies:
4843+
react: ^18.0.0 || ^19.0.0
4844+
react-dom: ^18.0.0 || ^19.0.0
4845+
checksum: 7b140ffd612173a30d74717d18efcf554774ef0ed0fe72f207ec21df707685ef5f4c34521e6840041665550c6461171dc32f12835f35beb1788ccac0c66c0e5c
4846+
languageName: node
4847+
linkType: hard
4848+
48294849
"@rjsf/core@npm:^5.13.4":
48304850
version: 5.17.0
48314851
resolution: "@rjsf/core@npm:5.17.0"

0 commit comments

Comments
 (0)