diff --git a/CHANGELOG.md b/CHANGELOG.md index c3bcc2a99..2c2bf330c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,9 +9,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Bug fixes -* Fixed a bug with `Chat()` sometimes silently dropping errors. (#1672) - -* Fixed a bug with `Chat()` sometimes not removing it's loading icon (on error or a `None` transform). (#1679) +* A few fixes for `ui.Chat()`, including: + * Fixed a bug with `Chat()` sometimes silently dropping errors. (#1672) + * Fixed a bug with `Chat()` sometimes not removing its loading icon (on error or a `None` transform). (#1679) + * `.messages(format="anthropic")` correctly removes non-user starting messages (once again). (#1685) * `shiny create` now uses the template `id` rather than the directory name as the default directory. (#1666) diff --git a/shiny/templates/chat/production/anthropic/app.py b/shiny/templates/chat/production/anthropic/app.py index cac452a32..f563c8809 100644 --- a/shiny/templates/chat/production/anthropic/app.py +++ b/shiny/templates/chat/production/anthropic/app.py @@ -52,8 +52,14 @@ @chat.on_user_submit async def _(): - messages = chat.messages(format="openai", token_limits=MODEL_INFO["token_limits"]) - response = await llm.chat.completions.create( - model=MODEL_INFO["name"], messages=messages, stream=True + messages = chat.messages( + format="anthropic", + token_limits=MODEL_INFO["token_limits"], + ) + response = await llm.messages.create( + model=MODEL_INFO["name"], + messages=messages, + stream=True, + max_tokens=MODEL_INFO["token_limits"][1], ) await chat.append_message_stream(response) diff --git a/shiny/ui/_chat.py b/shiny/ui/_chat.py index ddb12fd2b..d1f0c6104 100644 --- a/shiny/ui/_chat.py +++ b/shiny/ui/_chat.py @@ -471,6 +471,11 @@ def messages( """ messages = self._messages() + + # Anthropic requires a user message first and no system messages + if format == "anthropic": + messages = self._trim_anthropic_messages(messages) + 
if token_limits is not None: messages = self._trim_messages(messages, token_limits, format) @@ -868,17 +873,6 @@ def _trim_messages( messages2.append(m) n_other_messages2 += 1 - # Anthropic doesn't support `role: system` and requires a user message to come 1st - if format == "anthropic": - if n_system_messages > 0: - raise ValueError( - "Anthropic requires a system prompt to be specified in it's `.create()` method " - "(not in the chat messages with `role: system`)." - ) - while n_other_messages2 > 0 and messages2[-1]["role"] != "user": - messages2.pop() - n_other_messages2 -= 1 - messages2.reverse() if len(messages2) == n_system_messages and n_other_messages2 > 0: @@ -890,6 +884,22 @@ return tuple(messages2) + def _trim_anthropic_messages( + self, + messages: tuple[TransformedMessage, ...], + ) -> tuple[TransformedMessage, ...]: + + if any(m["role"] == "system" for m in messages): + raise ValueError( + "Anthropic requires a system prompt to be specified in its `.create()` method " + "(not in the chat messages with `role: system`)." + ) + for i, m in enumerate(messages): + if m["role"] == "user": + return messages[i:] + + return () + def _get_token_count( self, content: str,