|
1 | 1 | import importlib.util |
2 | | -from typing import ( |
3 | | - TYPE_CHECKING, |
4 | | - Any, |
5 | | - Awaitable, |
6 | | - Callable, |
7 | | - Protocol, |
8 | | - cast, |
9 | | - runtime_checkable, |
10 | | -) |
| 2 | +from typing import Any, Awaitable, Callable, Protocol, runtime_checkable |
11 | 3 |
|
12 | 4 | from htmltools import TagChild |
13 | 5 |
|
14 | 6 | from .._utils import CancelCallback |
15 | 7 | from ..types import Jsonifiable |
16 | 8 |
|
17 | | -if TYPE_CHECKING: |
18 | | - |
19 | | - import chatlas |
20 | | - |
21 | | -else: |
22 | | - chatlas = object |
23 | | - |
24 | | - |
25 | 9 | chatlas_is_installed = importlib.util.find_spec("chatlas") is not None |
26 | 10 |
|
27 | 11 |
|
@@ -78,39 +62,34 @@ def tagify(self) -> TagChild: |
78 | 62 |
|
79 | 63 | # Chatlas specific implementation |
80 | 64 | def get_chatlas_state( |
81 | | - client: chatlas.Chat[Any, Any], |
| 65 | + client: Any, |
82 | 66 | ) -> Callable[[], Awaitable[Jsonifiable]]: |
83 | 67 |
|
84 | | - from chatlas import Turn as ChatlasTurn |
| 68 | + from chatlas import Chat, Turn |
| 69 | + |
| 70 | + assert isinstance(client, Chat) |
85 | 71 |
|
86 | 72 | async def get_state() -> Jsonifiable: |
87 | 73 |
|
88 | | - turns: list[ChatlasTurn[Any]] = client.get_turns() |
89 | | - turns_json_str: list[str] = [turn.model_dump_json() for turn in turns] |
90 | | - return cast(Jsonifiable, turns_json_str) |
| 74 | + turns: list[Turn[Any]] = client.get_turns() |
| 75 | + return [turn.model_dump(mode="json") for turn in turns] |
91 | 76 |
|
92 | 77 | return get_state |
93 | 78 |
|
94 | 79 |
|
95 | 80 | def set_chatlas_state( |
96 | | - client: chatlas.Chat[Any, Any], |
| 81 | + client: Any, |
97 | 82 | ) -> Callable[[Jsonifiable], Awaitable[None]]: |
98 | | - from chatlas import Turn as ChatlasTurn |
| 83 | + from chatlas import Chat, Turn |
| 84 | + |
| 85 | + assert isinstance(client, Chat) |
99 | 86 |
|
100 | 87 | async def set_state(value: Jsonifiable) -> None: |
101 | 88 |
|
102 | 89 | if not isinstance(value, list): |
103 | | - raise ValueError("Chatlas bookmark value must be a list of JSON strings") |
104 | | - for v in value: |
105 | | - if not isinstance(v, str): |
106 | | - raise ValueError("Chat bookmark value must be a list of strings") |
107 | | - |
108 | | - turns_json_str = cast(list[str], value) |
| 90 | + raise ValueError("Chatlas bookmark value was not a list of objects") |
109 | 91 |
|
110 | | - turns: list[ChatlasTurn[Any]] = [ |
111 | | - ChatlasTurn.model_validate_json(turn_json_str) |
112 | | - for turn_json_str in turns_json_str |
113 | | - ] |
| 92 | + turns: list[Turn[Any]] = [Turn.model_validate(turn_obj) for turn_obj in value] |
114 | 93 | client.set_turns(turns) # pyright: ignore[reportUnknownMemberType] |
115 | 94 |
|
116 | 95 | return set_state |
0 commit comments