Skip to content

Commit a632159

Browse files
wchcpsievert
andauthored
Allow @chat.transform_assistant_response function to return None (#1641)
Co-authored-by: Carson <[email protected]>
1 parent 0f7765b commit a632159

File tree

7 files changed

+83
-17
lines changed

7 files changed

+83
-17
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
2222
* A handful of fixes for `ui.Chat()`, including:
2323
* A fix for use inside Shiny modules. (#1582)
2424
* `.messages(format="google")` now returns the correct role. (#1622)
25+
* `transform_assistant_response` can now return `None` and correctly handles change of content on the last chunk. (#1641)
2526

2627
* An empty `ui.input_date()` value no longer crashes Shiny. (#1528)
2728

js/chat/chat.ts

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,10 @@ class ChatContainer extends LightElement {
275275
this.#onAppendChunk
276276
);
277277
this.addEventListener("shiny-chat-clear-messages", this.#onClear);
278-
this.addEventListener("shiny-chat-update-user-input", this.#onUpdateUserInput);
278+
this.addEventListener(
279+
"shiny-chat-update-user-input",
280+
this.#onUpdateUserInput
281+
);
279282
this.addEventListener(
280283
"shiny-chat-remove-loading-message",
281284
this.#onRemoveLoadingMessage
@@ -369,6 +372,7 @@ class ChatContainer extends LightElement {
369372

370373
if (message.chunk_type === "message_end") {
371374
lastMessage.removeAttribute("is_streaming");
375+
lastMessage.setAttribute("content", message.content);
372376
this.#finalizeMessage();
373377
return;
374378
}

shiny/ui/_chat.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -52,11 +52,11 @@
5252
# user input content types
5353
TransformUserInput = Callable[[str], Union[str, None]]
5454
TransformUserInputAsync = Callable[[str], Awaitable[Union[str, None]]]
55-
TransformAssistantResponse = Callable[[str], Union[str, HTML]]
56-
TransformAssistantResponseAsync = Callable[[str], Awaitable[Union[str, HTML]]]
57-
TransformAssistantResponseChunk = Callable[[str, str, bool], Union[str, HTML]]
55+
TransformAssistantResponse = Callable[[str], Union[str, HTML, None]]
56+
TransformAssistantResponseAsync = Callable[[str], Awaitable[Union[str, HTML, None]]]
57+
TransformAssistantResponseChunk = Callable[[str, str, bool], Union[str, HTML, None]]
5858
TransformAssistantResponseChunkAsync = Callable[
59-
[str, str, bool], Awaitable[Union[str, HTML]]
59+
[str, str, bool], Awaitable[Union[str, HTML, None]]
6060
]
6161
TransformAssistantResponseFunction = Union[
6262
TransformAssistantResponse,
@@ -711,11 +711,11 @@ def transform_assistant_response(
711711
Parameters
712712
----------
713713
fn
714-
A function that takes a string and returns a string or
715-
:class:`shiny.ui.HTML`. If `fn` returns a string, it gets interpreted and
716-
parsed as a markdown on the client (and the resulting HTML is then
717-
sanitized). If `fn` returns :class:`shiny.ui.HTML`, it will be displayed
718-
as-is.
714+
A function that takes a string and returns either a string,
715+
:class:`shiny.ui.HTML`, or `None`. If `fn` returns a string, it gets
716+
interpreted and parsed as a markdown on the client (and the resulting HTML
717+
is then sanitized). If `fn` returns :class:`shiny.ui.HTML`, it will be
718+
displayed as-is. If `fn` returns `None`, the response is effectively ignored.
719719
720720
Note
721721
----
@@ -774,16 +774,20 @@ async def _transform_message(
774774

775775
if message["role"] == "user" and self._transform_user is not None:
776776
content = await self._transform_user(message["content"])
777-
if content is None:
778-
return None
779-
res[key] = content
780777

781778
elif message["role"] == "assistant" and self._transform_assistant is not None:
782-
res[key] = await self._transform_assistant(
779+
content = await self._transform_assistant(
783780
message["content"],
784781
chunk_content or "",
785782
chunk == "end" or chunk is False,
786783
)
784+
else:
785+
return res
786+
787+
if content is None:
788+
return None
789+
790+
res[key] = content
787791

788792
return res
789793

shiny/www/py-shiny/chat/chat.js

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

shiny/www/py-shiny/chat/chat.js.map

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from shiny import render
2+
from shiny.express import ui
3+
4+
chat = ui.Chat(id="chat")
5+
chat.ui()
6+
7+
8+
@chat.transform_assistant_response
9+
def transform(content: str, chunk: str, done: bool):
10+
if done:
11+
return content + "...DONE!"
12+
else:
13+
return content
14+
15+
16+
@chat.on_user_submit
17+
async def _():
18+
await chat.append_message_stream(("Simple ", "response"))
19+
20+
21+
"Message state:"
22+
23+
24+
@render.code
25+
def message_state():
26+
return str(chat.messages())
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
from playwright.sync_api import Page, expect
2+
from utils.deploy_utils import skip_on_webkit
3+
4+
from shiny.playwright import controller
5+
from shiny.run import ShinyAppProc
6+
7+
8+
@skip_on_webkit
9+
def test_validate_chat_transform_assistant(page: Page, local_app: ShinyAppProc) -> None:
10+
page.goto(local_app.url)
11+
12+
chat = controller.Chat(page, "chat")
13+
message_state = controller.OutputCode(page, "message_state")
14+
15+
# Wait for app to load
16+
message_state.expect_value("()", timeout=30 * 1000)
17+
18+
expect(chat.loc).to_be_visible(timeout=30 * 1000)
19+
expect(chat.loc_input_button).to_be_disabled()
20+
21+
chat.set_user_input("foo")
22+
chat.send_user_input()
23+
chat.expect_latest_message("Simple response...DONE!", timeout=30 * 1000)
24+
25+
message_state_expected = tuple(
26+
[
27+
{"content": "foo", "role": "user"},
28+
{"content": "Simple response", "role": "assistant"},
29+
]
30+
)
31+
message_state.expect_value(str(message_state_expected))

0 commit comments

Comments
 (0)