+ from __future__ import annotations
import asyncio
import os
from abc import ABC, ABCMeta, abstractmethod
from dataclasses import asdict
from logging import Logger
from time import time
- from typing import TYPE_CHECKING, Any, Optional
+ from typing import TYPE_CHECKING, Any, Optional, Tuple

from jupyter_ai.config_manager import ConfigManager
from jupyterlab_chat.models import Message, NewMessage, User
from jupyterlab_chat.ychat import YChat
+ from litellm import ModelResponseStream, supports_function_calling
+ from litellm.utils import function_to_dict
from pydantic import BaseModel
from traitlets import MetaHasTraits
from traitlets.config import LoggingConfigurable

from .persona_awareness import PersonaAwareness
+ from ..litellm_utils import ToolCallList, ResolvedToolCall
+
+ # Import toolkits
+ from jupyter_ai_tools.toolkits.file_system import toolkit as fs_toolkit
+ from jupyter_ai_tools.toolkits.code_execution import toolkit as codeexec_toolkit
+ from jupyter_ai_tools.toolkits.git import toolkit as git_toolkit

- # prevents a circular import
- # types imported under this block have to be surrounded in single quotes on use
if TYPE_CHECKING:
    from collections.abc import AsyncIterator
-
-     from litellm import ModelResponseStream
-
    from .persona_manager import PersonaManager
+     from ..tools import Toolkit

+ DEFAULT_TOOLKITS: dict[str, Toolkit] = {
+     "fs": fs_toolkit,
+     "codeexec": codeexec_toolkit,
+     "git": git_toolkit,
+ }


class PersonaDefaults(BaseModel):
    """
@@ -237,7 +247,7 @@ def as_user_dict(self) -> dict[str, Any]:
    async def stream_message(
        self, reply_stream: "AsyncIterator[ModelResponseStream | str]"
-     ) -> None:
+     ) -> Optional[Tuple[list[ResolvedToolCall], ToolCallList]]:
        """
        Takes an async iterator, dubbed the 'reply stream', and streams it to a
        new message by this persona in the YChat. The async iterator may yield
@@ -247,21 +257,36 @@ async def stream_message(
        stream, then continuously updates it until the stream is closed.

-         Automatically manages its awareness state to show writing status.
+
+         Returns a tuple `(resolved_toolcalls, toolcall_list)`. If
+         `resolved_toolcalls` is non-empty, the persona should run these
+         tools. Returns `None` if the stream was interrupted by the user.
        """
        stream_id: Optional[str] = None
        stream_interrupted = False
        try:
            self.awareness.set_local_state_field("isWriting", True)
-             async for chunk in reply_stream:
-                 # Coerce LiteLLM stream chunk to a string delta
-                 if not isinstance(chunk, str):
-                     chunk = chunk.choices[0].delta.content
+             toolcall_list = ToolCallList()
+             resolved_toolcalls: list[ResolvedToolCall] = []

-                 # LiteLLM streams always terminate with an empty chunk, so we
-                 # ignore and continue when this occurs.
-                 if not chunk:
+             async for chunk in reply_stream:
+                 # Compute `content_delta` and `toolcalls_delta` based on the
+                 # type of object yielded by `reply_stream`.
+                 if isinstance(chunk, ModelResponseStream):
+                     delta = chunk.choices[0].delta
+                     content_delta = delta.content
+                     toolcalls_delta = delta.tool_calls
+                 elif isinstance(chunk, str):
+                     content_delta = chunk
+                     toolcalls_delta = None
+                 else:
+                     raise TypeError(
+                         f"Unrecognized type in stream_message(): {type(chunk)}"
+                     )
+
+                 # LiteLLM streams always terminate with an empty chunk, so
+                 # continue in this case.
+                 if not (content_delta or toolcalls_delta):
                    continue

+                 # Terminate the stream if the user requested it.
                if (
                    stream_id
                    and stream_id in self.message_interrupted.keys()
@@ -280,34 +305,46 @@ async def stream_message(
                    stream_interrupted = True
                    break

-                 if not stream_id:
-                     stream_id = self.ychat.add_message(
-                         NewMessage(body="", sender=self.id)
+                 # Append `content_delta` to the existing message.
+                 if content_delta:
+                     # Start the stream with an empty message on the initial
+                     # reply. Bind the new message ID to `stream_id`.
+                     if not stream_id:
+                         stream_id = self.ychat.add_message(
+                             NewMessage(body="", sender=self.id)
+                         )
+                         self.message_interrupted[stream_id] = asyncio.Event()
+                         self.awareness.set_local_state_field("isWriting", stream_id)
+                     assert stream_id
+
+                     self.ychat.update_message(
+                         Message(
+                             id=stream_id,
+                             body=content_delta,
+                             time=time(),
+                             sender=self.id,
+                             raw_time=False,
+                         ),
+                         append=True,
                    )
-                     self.message_interrupted[stream_id] = asyncio.Event()
-                     self.awareness.set_local_state_field("isWriting", stream_id)
-
-                 assert stream_id
-                 self.ychat.update_message(
-                     Message(
-                         id=stream_id,
-                         body=chunk,
-                         time=time(),
-                         sender=self.id,
-                         raw_time=False,
-                     ),
-                     append=True,
-                 )
+                 if toolcalls_delta:
+                     toolcall_list += toolcalls_delta
+
+             # After the reply stream is complete, resolve the list of tool calls.
+             resolved_toolcalls = toolcall_list.resolve()
        except Exception as e:
            self.log.error(
                f"Persona '{self.name}' encountered an exception printed below when attempting to stream output."
            )
            self.log.exception(e)
        finally:
+             # Reset local state
            self.awareness.set_local_state_field("isWriting", False)
-             if stream_id:
-                 # if stream was interrupted, add a tombstone
-                 if stream_interrupted:
+             self.message_interrupted.pop(stream_id, None)
+
+             # If the stream was interrupted, add a tombstone and return
+             # `None`, indicating that no tools should be run afterwards.
+             if stream_id and stream_interrupted:
                stream_tombstone = "\n\n(AI response stopped by user)"
                self.ychat.update_message(
                    Message(
@@ -319,8 +356,15 @@ async def stream_message(
                    ),
                    append=True,
                )
-                 if stream_id in self.message_interrupted.keys():
-                     del self.message_interrupted[stream_id]
+                 return None
+
+         # Otherwise, return the resolved list.
+         if len(resolved_toolcalls):
+             count = len(resolved_toolcalls)
+             names = sorted([tc.function.name for tc in resolved_toolcalls])
+             self.log.info(f"AI response triggered {count} tool calls: {names}")
+         return resolved_toolcalls, toolcall_list

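A note on the two chunk types accepted above: a reply stream may interleave plain strings with LiteLLM chunks. The following is a hypothetical illustration of constructing such a stream; the generator name, model ID, and prompt are illustrative and not part of this diff:

```python
from collections.abc import AsyncIterator

import litellm
from litellm import ModelResponseStream

async def example_reply_stream() -> AsyncIterator[ModelResponseStream | str]:
    # Plain strings are appended verbatim as content deltas by
    # `stream_message()`...
    yield "Searching the repository...\n\n"

    # ...while `ModelResponseStream` chunks may carry content and/or
    # tool-call deltas, e.g. when produced by `litellm.acompletion()`
    # with `stream=True`.
    response = await litellm.acompletion(
        model="openai/gpt-4o",  # illustrative model ID
        messages=[{"role": "user", "content": "List the staged files."}],
        stream=True,
    )
    async for chunk in response:
        yield chunk
```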
    def send_message(self, body: str) -> None:
        """
@@ -361,7 +405,7 @@ def get_mcp_config(self) -> dict[str, Any]:
        Returns the MCP config for the current chat.
        """
        return self.parent.get_mcp_config()
-
+
    def process_attachments(self, message: Message) -> Optional[str]:
        """
        Process file attachments in the message and return their content as a string.
@@ -431,6 +475,99 @@ def resolve_attachment_to_path(self, attachment_id: str) -> Optional[str]:
            self.log.error(f"Failed to resolve attachment {attachment_id}: {e}")
            return None

+     def get_tools(self, model_id: str) -> list[dict]:
+         """
+         Returns the `tools` parameter which should be passed to
+         `litellm.acompletion()` for a given LiteLLM model ID.
+
+         If the model does not support tool-calling, this method returns an
+         empty list. Otherwise, it returns the list of tools available in the
+         current environment. These may include:
+
+         - The default set of tool functions in Jupyter AI, defined in the
+           `jupyter_ai_tools` package.
+
+         - (TODO) Tools provided by MCP server configuration, if any.
+
+         - (TODO) Web search.
+
+         - (TODO) File search using vector store IDs.
+
+         TODO: cache this
+
+         TODO: Implement some permissions system so users can control what
+         tools are allowable.
+
+         NOTE: The returned list is expected by LiteLLM to conform to the
+         `tools` parameter definition defined by the OpenAI API:
+         https://platform.openai.com/docs/guides/tools#available-tools
+
+         NOTE: This API is a WIP and is very likely to change.
+         """
+         # Return early if the model does not support tool calling
+         if not supports_function_calling(model=model_id):
+             return []
+
+         tool_descriptions = []
+
+         # Get all tools from `jupyter_ai_tools` and store their descriptions
+         for toolkit_name, toolkit in DEFAULT_TOOLKITS.items():
+             # TODO: make these tool permissions configurable.
+             for tool in toolkit.get_tools():
+                 # Use a util function from LiteLLM to coerce each `Tool`
+                 # struct into the tool description dictionary expected by
+                 # LiteLLM.
+                 desc = {
+                     "type": "function",
+                     "function": function_to_dict(tool.callable),
+                 }
+
+                 # Prepend the toolkit name to each function name, hopefully
+                 # ensuring every tool function has a unique name.
+                 # e.g. 'git_add' => 'git__git_add'
+                 #
+                 # TODO: Actually ensure this instead of hoping.
+                 desc["function"]["name"] = f"{toolkit_name}__{desc['function']['name']}"
+                 tool_descriptions.append(desc)
+
+         # Finally, return the tool descriptions
+         return tool_descriptions
+
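For reference, each entry returned by `get_tools()` is a function descriptor in the OpenAI `tools` format. A hypothetical example for a `git_add` tool is sketched below; the actual `description` and `parameters` depend entirely on what `function_to_dict()` derives from the tool's signature and docstring:

```python
# Illustrative only: one entry in the list returned by `get_tools()`.
{
    "type": "function",
    "function": {
        # '<toolkit>__<tool>' naming convention applied by `get_tools()`
        "name": "git__git_add",
        "description": "Stage file contents for the next commit.",
        "parameters": {
            "type": "object",
            "properties": {
                "path": {"type": "string", "description": "Path to stage."},
            },
            "required": ["path"],
        },
    },
}
```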
+     async def run_tools(self, tools: list[ResolvedToolCall]) -> list[dict]:
+         """
+         Runs the tools specified in the list of tool calls returned by
+         `self.stream_message()`. Returns a list of dictionaries
+         `toolcall_outputs: list[dict]`, which should be appended directly to
+         the message history on the next invocation of the LLM.
+         """
+         if not len(tools):
+             return []
+
+         tool_outputs: list[dict] = []
+         for tool_call in tools:
+             # Get the tool definition from the correct toolkit. Split on the
+             # first '__' only, since the tool name may itself contain '__'.
+             toolkit_name, tool_name = tool_call.function.name.split("__", 1)
+             assert toolkit_name in DEFAULT_TOOLKITS
+             tool_defn = DEFAULT_TOOLKITS[toolkit_name].get_tool_unsafe(tool_name)
+
+             # Run the tool and store its output
+             output = await tool_defn.callable(**tool_call.function.arguments)
+
+             # Store the tool output in a dictionary accepted by LiteLLM
+             output_dict = {
+                 "tool_call_id": tool_call.id,
+                 "role": "tool",
+                 "name": tool_call.function.name,
+                 "content": output,
+             }
+             tool_outputs.append(output_dict)
+
+         self.log.info(f"Ran {len(tools)} tool functions.")
+         return tool_outputs
+
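Together, `get_tools()`, `stream_message()`, and `run_tools()` enable a standard tool-calling loop. A minimal sketch under stated assumptions: the reply stream comes from `litellm.acompletion()`, `messages` is the chat history in OpenAI format, and the method name `process_message` and the assistant-message construction are illustrative, not part of this diff. The `role: "tool"` dictionaries returned by `run_tools()` slot directly into this history.

```python
import json
import litellm

async def process_message(self, messages: list[dict], model_id: str) -> None:
    # Illustrative sketch of the tool-calling loop; not part of this diff.
    while True:
        response = await litellm.acompletion(
            model=model_id,
            messages=messages,
            tools=self.get_tools(model_id) or None,
            stream=True,
        )
        result = await self.stream_message(response)
        if not result:
            break  # stream was interrupted; run no tools
        resolved_toolcalls, _toolcall_list = result
        if not resolved_toolcalls:
            break  # no tool calls; the reply was plain text

        # Record the assistant's tool calls in the history (OpenAI format),
        # then run the tools and append their outputs for the next turn.
        messages.append({
            "role": "assistant",
            "content": None,
            "tool_calls": [
                {
                    "id": tc.id,
                    "type": "function",
                    "function": {
                        "name": tc.function.name,
                        "arguments": json.dumps(tc.function.arguments),
                    },
                }
                for tc in resolved_toolcalls
            ],
        })
        messages.extend(await self.run_tools(resolved_toolcalls))
```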
    def shutdown(self) -> None:
        """
        Shuts the persona down. This method should: