Skip to content

Commit 6406f9b

Browse files
committed
Client needs to be session-specific
1 parent 5b89762 commit 6406f9b

File tree

2 files changed

+33
-11
lines changed

2 files changed

+33
-11
lines changed

pkg-py/src/querychat/_querychat.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -136,10 +136,11 @@ def __init__(
136136
prompt_template=prompt_template,
137137
)
138138

139+
# Fork and empty chat now so the per-session forks are fast
139140
client = normalize_client(client)
140-
self.client = copy.deepcopy(client)
141-
self.client.set_turns([])
142-
self.client.system_prompt = prompt
141+
self._client = copy.deepcopy(client)
142+
self._client.set_turns([])
143+
self._client.system_prompt = prompt
143144

144145
# Populated when ._server() gets called (in an active session)
145146
self._server_values: ModServerResult | None = None
@@ -313,7 +314,7 @@ def _server(self, *, enable_bookmarking: bool = False) -> None:
313314
self.id,
314315
data_source=self._data_source,
315316
greeting=self.greeting,
316-
client=self.client,
317+
client=self._client,
317318
enable_bookmarking=enable_bookmarking,
318319
)
319320

@@ -444,11 +445,27 @@ def generate_greeting(self, *, echo: Literal["none", "output"] = "none"):
444445
The greeting string (in Markdown format).
445446
446447
"""
447-
client = copy.deepcopy(self.client)
448+
client = copy.deepcopy(self._client)
448449
client.set_turns([])
449450
prompt = "Please give me a friendly greeting. Include a few sample prompts in a two-level bulleted list."
450451
return str(client.chat(prompt, echo=echo))
451452

453+
@property
454+
def client(self):
455+
"""
456+
Get the (session-specific) chat client.
457+
458+
Returns
459+
-------
460+
:
461+
The current chat client.
462+
463+
"""
464+
vals = self._server_values
465+
if vals is None:
466+
raise RuntimeError("Must call .server() before accessing .client")
467+
return vals.client
468+
452469
@property
453470
def data_source(self):
454471
"""

pkg-py/src/querychat/_querychat_module.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import copy
34
from dataclasses import dataclass
45
from pathlib import Path
56
from typing import TYPE_CHECKING, Callable, Union
@@ -47,6 +48,7 @@ class ModServerResult:
4748
df: Callable[[], pd.DataFrame]
4849
sql: ReactiveString
4950
title: ReactiveStringOrNone
51+
client: chatlas.Chat
5052

5153

5254
@module.server
@@ -65,15 +67,18 @@ def mod_server(
6567
title = ReactiveStringOrNone(None)
6668
has_greeted = reactive.value[bool](False) # noqa: FBT003
6769

70+
# Set up the chat object for this session
71+
chat = copy.deepcopy(client)
72+
6873
# Create the tool functions
6974
update_dashboard_tool = tool_update_dashboard(data_source, sql, title)
7075
reset_dashboard_tool = tool_reset_dashboard(sql, title)
7176
query_tool = tool_query(data_source)
7277

7378
# Register tools with annotations for the UI
74-
client.register_tool(update_dashboard_tool)
75-
client.register_tool(query_tool)
76-
client.register_tool(reset_dashboard_tool)
79+
chat.register_tool(update_dashboard_tool)
80+
chat.register_tool(query_tool)
81+
chat.register_tool(reset_dashboard_tool)
7782

7883
# Execute query when SQL changes
7984
@reactive.calc
@@ -89,7 +94,7 @@ def filtered_df():
8994
# Handle user input
9095
@chat_ui.on_user_submit
9196
async def _(user_input: str):
92-
stream = await client.stream_async(user_input, echo="none", content="all")
97+
stream = await chat.stream_async(user_input, echo="none", content="all")
9398
await chat_ui.append_message_stream(stream)
9499

95100
@reactive.effect
@@ -100,7 +105,7 @@ async def greet_on_startup():
100105
if greeting:
101106
await chat_ui.append_message(greeting)
102107
elif greeting is None:
103-
stream = await client.stream_async(
108+
stream = await chat.stream_async(
104109
"Please give me a friendly greeting. Include a few sample prompts in a two-level bulleted list.",
105110
echo="none",
106111
)
@@ -145,4 +150,4 @@ def _on_restore(x: RestoreState) -> None:
145150
if "querychat_has_greeted" in vals:
146151
has_greeted.set(vals["querychat_has_greeted"])
147152

148-
return ModServerResult(df=filtered_df, sql=sql, title=title)
153+
return ModServerResult(df=filtered_df, sql=sql, title=title, client=chat)

0 commit comments

Comments
 (0)