Skip to content

Commit f86abb8

Browse files
committed
Move most of chatlas logic to _chat_bookmark.py
1 parent 3f6f65f commit f86abb8

File tree

2 files changed

+125
-86
lines changed

2 files changed

+125
-86
lines changed

shiny/ui/_chat.py

Lines changed: 9 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from __future__ import annotations
22

3-
import importlib.util
43
import inspect
54
from contextlib import asynccontextmanager
65
from typing import (
@@ -12,13 +11,11 @@
1211
Iterable,
1312
Literal,
1413
Optional,
15-
Protocol,
1614
Sequence,
1715
Tuple,
1816
Union,
1917
cast,
2018
overload,
21-
runtime_checkable,
2219
)
2320
from weakref import WeakValueDictionary
2421

@@ -34,6 +31,13 @@
3431
from ..session import get_current_session, require_active_session, session_context
3532
from ..types import MISSING, MISSING_TYPE, Jsonifiable, NotifyException
3633
from ..ui.css import CssUnit, as_css_unit
34+
from ._chat_bookmark import (
35+
BookmarkCancelCallback,
36+
ClientWithState,
37+
get_chatlas_state,
38+
is_chatlas_chat_client,
39+
set_chatlas_state,
40+
)
3741
from ._chat_normalize import normalize_message, normalize_message_chunk
3842
from ._chat_provider_types import (
3943
AnthropicMessage,
@@ -65,60 +69,6 @@
6569
)
6670

6771

68-
chatlas_is_installed = importlib.util.find_spec("chatlas") is not None
69-
70-
71-
def is_chatlas_chat_client(client: Any) -> bool:
72-
if not chatlas_is_installed:
73-
return False
74-
import chatlas
75-
76-
return isinstance(client, chatlas.Chat)
77-
78-
79-
@runtime_checkable
80-
class ClientWithState(Protocol):
81-
async def get_state(self) -> Jsonifiable: ...
82-
83-
"""
84-
Retrieve JSON-like representation of chat client state.
85-
86-
This method is used to retrieve the state of the client object when saving a bookmark.
87-
88-
Returns
89-
-------
90-
:
91-
A JSON-like representation of the current state of the client. It is not required to be a JSON string but something that can be serialized to JSON without further conversion.
92-
"""
93-
94-
async def set_state(self, state: Jsonifiable): ...
95-
96-
"""
97-
Method to set the chat client state.
98-
99-
This method is used to restore the state of the client when the app is restored from
100-
a bookmark.
101-
102-
Parameters
103-
----------
104-
state
105-
The value to infer the state from. This value will be the JSON capable value
106-
returned by the `get_state()` method (after a round trip through JSON
107-
serialization and unserialization).
108-
"""
109-
110-
111-
class BookmarkCancelCallback:
112-
def __init__(self, cancel: CancelCallback):
113-
self.cancel = cancel
114-
115-
def __call__(self):
116-
self.cancel()
117-
118-
def tagify(self) -> TagChild:
119-
return ""
120-
121-
12272
# TODO: UserInput might need to be a list of dicts if we want to support multiple
12373
# user input content types
12474
TransformUserInput = Callable[[str], Union[str, None]]
@@ -1476,35 +1426,8 @@ def enable_bookmarking(
14761426

14771427
elif is_chatlas_chat_client(client):
14781428

1479-
from chatlas import Turn as ChatlasTurn
1480-
1481-
# Chatlas specific implementation
1482-
async def get_chatlas_state() -> Jsonifiable:
1483-
turns: list[ChatlasTurn[Any]] = client.get_turns()
1484-
turns_json_str: list[str] = [turn.model_dump_json() for turn in turns]
1485-
return cast(Jsonifiable, turns_json_str)
1486-
1487-
async def set_chatlas_state(value: Jsonifiable) -> None:
1488-
if not isinstance(value, list):
1489-
raise ValueError(
1490-
"Chatlas bookmark value must be a list of JSON strings"
1491-
)
1492-
for v in value:
1493-
if not isinstance(v, str):
1494-
raise ValueError(
1495-
"Chat bookmark value must be a list of strings"
1496-
)
1497-
1498-
turns_json_str = cast(list[str], value)
1499-
1500-
turns: list[ChatlasTurn[Any]] = [
1501-
ChatlasTurn.model_validate_json(turn_json_str)
1502-
for turn_json_str in turns_json_str
1503-
]
1504-
client.set_turns(turns) # pyright: ignore[reportUnknownMemberType]
1505-
1506-
get_state = get_chatlas_state
1507-
set_state = set_chatlas_state
1429+
get_state = get_chatlas_state(client)
1430+
set_state = set_chatlas_state(client)
15081431

15091432
else:
15101433
raise ValueError(

shiny/ui/_chat_bookmark.py

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
import importlib.util
2+
from typing import (
3+
TYPE_CHECKING,
4+
Any,
5+
Awaitable,
6+
Callable,
7+
Protocol,
8+
cast,
9+
runtime_checkable,
10+
)
11+
12+
from htmltools import TagChild
13+
14+
from .._utils import CancelCallback
15+
from ..types import Jsonifiable
16+
17+
if TYPE_CHECKING:
18+
19+
import chatlas
20+
21+
else:
22+
chatlas = object
23+
24+
25+
chatlas_is_installed = importlib.util.find_spec("chatlas") is not None
26+
27+
28+
def is_chatlas_chat_client(client: Any) -> bool:
29+
if not chatlas_is_installed:
30+
return False
31+
import chatlas
32+
33+
return isinstance(client, chatlas.Chat)
34+
35+
36+
@runtime_checkable
37+
class ClientWithState(Protocol):
38+
async def get_state(self) -> Jsonifiable: ...
39+
40+
"""
41+
Retrieve JSON-like representation of chat client state.
42+
43+
This method is used to retrieve the state of the client object when saving a bookmark.
44+
45+
Returns
46+
-------
47+
:
48+
A JSON-like representation of the current state of the client. It is not required to be a JSON string but something that can be serialized to JSON without further conversion.
49+
"""
50+
51+
async def set_state(self, state: Jsonifiable): ...
52+
53+
"""
54+
Method to set the chat client state.
55+
56+
This method is used to restore the state of the client when the app is restored from
57+
a bookmark.
58+
59+
Parameters
60+
----------
61+
state
62+
The value to infer the state from. This value will be the JSON capable value
63+
returned by the `get_state()` method (after a round trip through JSON
64+
serialization and unserialization).
65+
"""
66+
67+
68+
class BookmarkCancelCallback:
69+
def __init__(self, cancel: CancelCallback):
70+
self.cancel = cancel
71+
72+
def __call__(self):
73+
self.cancel()
74+
75+
def tagify(self) -> TagChild:
76+
return ""
77+
78+
79+
# Chatlas specific implementation
80+
def get_chatlas_state(
81+
client: chatlas.Chat[Any, Any],
82+
) -> Callable[[], Awaitable[Jsonifiable]]:
83+
84+
from chatlas import Turn as ChatlasTurn
85+
86+
async def get_state() -> Jsonifiable:
87+
88+
turns: list[ChatlasTurn[Any]] = client.get_turns()
89+
turns_json_str: list[str] = [turn.model_dump_json() for turn in turns]
90+
return cast(Jsonifiable, turns_json_str)
91+
92+
return get_state
93+
94+
95+
def set_chatlas_state(
96+
client: chatlas.Chat[Any, Any],
97+
) -> Callable[[Jsonifiable], Awaitable[None]]:
98+
from chatlas import Turn as ChatlasTurn
99+
100+
async def set_state(value: Jsonifiable) -> None:
101+
102+
if not isinstance(value, list):
103+
raise ValueError("Chatlas bookmark value must be a list of JSON strings")
104+
for v in value:
105+
if not isinstance(v, str):
106+
raise ValueError("Chat bookmark value must be a list of strings")
107+
108+
turns_json_str = cast(list[str], value)
109+
110+
turns: list[ChatlasTurn[Any]] = [
111+
ChatlasTurn.model_validate_json(turn_json_str)
112+
for turn_json_str in turns_json_str
113+
]
114+
client.set_turns(turns) # pyright: ignore[reportUnknownMemberType]
115+
116+
return set_state

0 commit comments

Comments
 (0)