11from __future__ import annotations
22
3+ import copy
34from dataclasses import dataclass
45from pathlib import Path
56from typing import TYPE_CHECKING , Callable , Union
@@ -47,6 +48,7 @@ class ModServerResult:
4748 df : Callable [[], pd .DataFrame ]
4849 sql : ReactiveString
4950 title : ReactiveStringOrNone
51+ client : chatlas .Chat
5052
5153
5254@module .server
@@ -65,15 +67,18 @@ def mod_server(
6567 title = ReactiveStringOrNone (None )
6668 has_greeted = reactive .value [bool ](False ) # noqa: FBT003
6769
70+ # Set up the chat object for this session
71+ chat = copy .deepcopy (client )
72+
6873 # Create the tool functions
6974 update_dashboard_tool = tool_update_dashboard (data_source , sql , title )
7075 reset_dashboard_tool = tool_reset_dashboard (sql , title )
7176 query_tool = tool_query (data_source )
7277
7378 # Register tools with annotations for the UI
74- client .register_tool (update_dashboard_tool )
75- client .register_tool (query_tool )
76- client .register_tool (reset_dashboard_tool )
79+ chat .register_tool (update_dashboard_tool )
80+ chat .register_tool (query_tool )
81+ chat .register_tool (reset_dashboard_tool )
7782
7883 # Execute query when SQL changes
7984 @reactive .calc
@@ -89,7 +94,7 @@ def filtered_df():
8994 # Handle user input
9095 @chat_ui .on_user_submit
9196 async def _ (user_input : str ):
92- stream = await client .stream_async (user_input , echo = "none" , content = "all" )
97+ stream = await chat .stream_async (user_input , echo = "none" , content = "all" )
9398 await chat_ui .append_message_stream (stream )
9499
95100 @reactive .effect
@@ -100,7 +105,7 @@ async def greet_on_startup():
100105 if greeting :
101106 await chat_ui .append_message (greeting )
102107 elif greeting is None :
103- stream = await client .stream_async (
108+ stream = await chat .stream_async (
104109 "Please give me a friendly greeting. Include a few sample prompts in a two-level bulleted list." ,
105110 echo = "none" ,
106111 )
@@ -145,4 +150,4 @@ def _on_restore(x: RestoreState) -> None:
145150 if "querychat_has_greeted" in vals :
146151 has_greeted .set (vals ["querychat_has_greeted" ])
147152
148- return ModServerResult (df = filtered_df , sql = sql , title = title )
153+ return ModServerResult (df = filtered_df , sql = sql , title = title , client = chat )
0 commit comments