|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
3 | | -import importlib.util |
4 | 3 | import inspect |
5 | 4 | from contextlib import asynccontextmanager |
6 | 5 | from typing import ( |
|
12 | 11 | Iterable, |
13 | 12 | Literal, |
14 | 13 | Optional, |
15 | | - Protocol, |
16 | 14 | Sequence, |
17 | 15 | Tuple, |
18 | 16 | Union, |
19 | 17 | cast, |
20 | 18 | overload, |
21 | | - runtime_checkable, |
22 | 19 | ) |
23 | 20 | from weakref import WeakValueDictionary |
24 | 21 |
|
|
34 | 31 | from ..session import get_current_session, require_active_session, session_context |
35 | 32 | from ..types import MISSING, MISSING_TYPE, Jsonifiable, NotifyException |
36 | 33 | from ..ui.css import CssUnit, as_css_unit |
| 34 | +from ._chat_bookmark import ( |
| 35 | + BookmarkCancelCallback, |
| 36 | + ClientWithState, |
| 37 | + get_chatlas_state, |
| 38 | + is_chatlas_chat_client, |
| 39 | + set_chatlas_state, |
| 40 | +) |
37 | 41 | from ._chat_normalize import normalize_message, normalize_message_chunk |
38 | 42 | from ._chat_provider_types import ( |
39 | 43 | AnthropicMessage, |
|
65 | 69 | ) |
66 | 70 |
|
67 | 71 |
|
68 | | -chatlas_is_installed = importlib.util.find_spec("chatlas") is not None |
69 | | - |
70 | | - |
71 | | -def is_chatlas_chat_client(client: Any) -> bool: |
72 | | - if not chatlas_is_installed: |
73 | | - return False |
74 | | - import chatlas |
75 | | - |
76 | | - return isinstance(client, chatlas.Chat) |
77 | | - |
78 | | - |
79 | | -@runtime_checkable |
80 | | -class ClientWithState(Protocol): |
81 | | - async def get_state(self) -> Jsonifiable: ... |
82 | | - |
83 | | - """ |
84 | | - Retrieve JSON-like representation of chat client state. |
85 | | -
|
86 | | - This method is used to retrieve the state of the client object when saving a bookmark. |
87 | | -
|
88 | | - Returns |
89 | | - ------- |
90 | | - : |
91 | | - A JSON-like representation of the current state of the client. It is not required to be a JSON string but something that can be serialized to JSON without further conversion. |
92 | | - """ |
93 | | - |
94 | | - async def set_state(self, state: Jsonifiable): ... |
95 | | - |
96 | | - """ |
97 | | - Method to set the chat client state. |
98 | | -
|
99 | | - This method is used to restore the state of the client when the app is restored from |
100 | | - a bookmark. |
101 | | -
|
102 | | - Parameters |
103 | | - ---------- |
104 | | - state |
105 | | - The value to infer the state from. This value will be the JSON capable value |
106 | | - returned by the `get_state()` method (after a round trip through JSON |
107 | | - serialization and unserialization). |
108 | | - """ |
109 | | - |
110 | | - |
111 | | -class BookmarkCancelCallback: |
112 | | - def __init__(self, cancel: CancelCallback): |
113 | | - self.cancel = cancel |
114 | | - |
115 | | - def __call__(self): |
116 | | - self.cancel() |
117 | | - |
118 | | - def tagify(self) -> TagChild: |
119 | | - return "" |
120 | | - |
121 | | - |
122 | 72 | # TODO: UserInput might need to be a list of dicts if we want to support multiple |
123 | 73 | # user input content types |
124 | 74 | TransformUserInput = Callable[[str], Union[str, None]] |
@@ -1476,35 +1426,8 @@ def enable_bookmarking( |
1476 | 1426 |
|
1477 | 1427 | elif is_chatlas_chat_client(client): |
1478 | 1428 |
|
1479 | | - from chatlas import Turn as ChatlasTurn |
1480 | | - |
1481 | | - # Chatlas specific implementation |
1482 | | - async def get_chatlas_state() -> Jsonifiable: |
1483 | | - turns: list[ChatlasTurn[Any]] = client.get_turns() |
1484 | | - turns_json_str: list[str] = [turn.model_dump_json() for turn in turns] |
1485 | | - return cast(Jsonifiable, turns_json_str) |
1486 | | - |
1487 | | - async def set_chatlas_state(value: Jsonifiable) -> None: |
1488 | | - if not isinstance(value, list): |
1489 | | - raise ValueError( |
1490 | | - "Chatlas bookmark value must be a list of JSON strings" |
1491 | | - ) |
1492 | | - for v in value: |
1493 | | - if not isinstance(v, str): |
1494 | | - raise ValueError( |
1495 | | - "Chat bookmark value must be a list of strings" |
1496 | | - ) |
1497 | | - |
1498 | | - turns_json_str = cast(list[str], value) |
1499 | | - |
1500 | | - turns: list[ChatlasTurn[Any]] = [ |
1501 | | - ChatlasTurn.model_validate_json(turn_json_str) |
1502 | | - for turn_json_str in turns_json_str |
1503 | | - ] |
1504 | | - client.set_turns(turns) # pyright: ignore[reportUnknownMemberType] |
1505 | | - |
1506 | | - get_state = get_chatlas_state |
1507 | | - set_state = set_chatlas_state |
| 1429 | + get_state = get_chatlas_state(client) |
| 1430 | + set_state = set_chatlas_state(client) |
1508 | 1431 |
|
1509 | 1432 | else: |
1510 | 1433 | raise ValueError( |
|
0 commit comments