Skip to content

Commit 14d7f0d

Browse files
committed
refactor: migrate ChatResponse to discriminated union
1 parent 85f6528 commit 14d7f0d

File tree

8 files changed

+335
-84
lines changed

8 files changed

+335
-84
lines changed

packages/ragbits-chat/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# CHANGELOG
22

33
## Unreleased
4+
- Move ChatResponse to union of types (#809)
45
- Refactor chat handlers in the UI to use registry (#805)
56
- Add auth token storage and automatic logout on 401 (#802)
67
- Improve user settings storage when history is disabled (#799)

packages/ragbits-chat/src/ragbits/chat/interface/types.py

Lines changed: 220 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from enum import Enum
2-
from typing import Any, cast
2+
from typing import Annotated, Any, Literal, cast, get_args, get_origin, overload
33

4-
from pydantic import BaseModel, ConfigDict, Field
4+
from pydantic import BaseModel, ConfigDict, Field, RootModel
55

66
from ragbits.chat.interface.forms import UserSettings
77
from ragbits.chat.interface.ui_customization import UICustomization
@@ -133,13 +133,217 @@ class ChatContext(BaseModel):
133133
model_config = ConfigDict(extra="allow")
134134

135135

136-
class ChatResponse(BaseModel):
137-
"""Container for different types of chat responses."""
136+
_CHAT_RESPONSE_REGISTRY: dict[ChatResponseType, type[BaseModel]] = {}
137+
138+
139+
class ChatResponseBase(BaseModel):
140+
"""Base class for all ChatResponse variants with auto-registration."""
138141

139142
type: ChatResponseType
140-
content: (
141-
str | Reference | StateUpdate | LiveUpdate | list[str] | Image | dict[str, MessageUsage] | ChunkedContent | None
142-
)
143+
144+
def __init_subclass__(cls, **kwargs: Any):
145+
super().__init_subclass__(**kwargs)
146+
type_ann = cls.model_fields["type"].annotation
147+
origin = get_origin(type_ann)
148+
value = get_args(type_ann)[0] if origin is Literal else getattr(cls, "type", None)
149+
150+
if value is None:
151+
raise ValueError(f"Cannot determine ChatResponseType for {cls.__name__}")
152+
153+
_CHAT_RESPONSE_REGISTRY[value] = cls
154+
155+
156+
class TextChatResponse(ChatResponseBase):
157+
"""Represents text chat response"""
158+
159+
type: Literal[ChatResponseType.TEXT] = ChatResponseType.TEXT
160+
content: str
161+
162+
163+
class ReferenceChatResponse(ChatResponseBase):
164+
"""Represents reference chat response"""
165+
166+
type: Literal[ChatResponseType.REFERENCE] = ChatResponseType.REFERENCE
167+
content: Reference
168+
169+
170+
class StateUpdateChatResponse(ChatResponseBase):
171+
"""Represents state update chat response"""
172+
173+
type: Literal[ChatResponseType.STATE_UPDATE] = ChatResponseType.STATE_UPDATE
174+
content: StateUpdate
175+
176+
177+
class ConversationIdChatResponse(ChatResponseBase):
178+
"""Represents conversation_id chat response"""
179+
180+
type: Literal[ChatResponseType.CONVERSATION_ID] = ChatResponseType.CONVERSATION_ID
181+
content: str
182+
183+
184+
class LiveUpdateChatResponse(ChatResponseBase):
185+
"""Represents live update chat response"""
186+
187+
type: Literal[ChatResponseType.LIVE_UPDATE] = ChatResponseType.LIVE_UPDATE
188+
content: LiveUpdate
189+
190+
191+
class FollowupMessagesChatResponse(ChatResponseBase):
192+
"""Represents followup messages chat response"""
193+
194+
type: Literal[ChatResponseType.FOLLOWUP_MESSAGES] = ChatResponseType.FOLLOWUP_MESSAGES
195+
content: list[str]
196+
197+
198+
class ImageChatResponse(ChatResponseBase):
199+
"""Represents image chat response"""
200+
201+
type: Literal[ChatResponseType.IMAGE] = ChatResponseType.IMAGE
202+
content: Image
203+
204+
205+
class ClearMessageChatResponse(ChatResponseBase):
206+
"""Represents clear message event"""
207+
208+
type: Literal[ChatResponseType.CLEAR_MESSAGE] = ChatResponseType.CLEAR_MESSAGE
209+
content: None = None
210+
211+
212+
class UsageChatResponse(ChatResponseBase):
213+
"""Represents usage chat response"""
214+
215+
type: Literal[ChatResponseType.USAGE] = ChatResponseType.USAGE
216+
content: dict[str, MessageUsage]
217+
218+
219+
class MessageIdChatResponse(ChatResponseBase):
220+
"""Represents message_id chat response"""
221+
222+
type: Literal[ChatResponseType.MESSAGE_ID] = ChatResponseType.MESSAGE_ID
223+
content: str
224+
225+
226+
class ChunkedContentChatResponse(ChatResponseBase):
227+
"""Represents chunked_content event that contains chunked event of different type"""
228+
229+
type: Literal[ChatResponseType.CHUNKED_CONTENT] = ChatResponseType.CHUNKED_CONTENT
230+
content: ChunkedContent
231+
232+
233+
ChatResponseUnion = Annotated[
234+
TextChatResponse
235+
| ReferenceChatResponse
236+
| StateUpdateChatResponse
237+
| ConversationIdChatResponse
238+
| LiveUpdateChatResponse
239+
| FollowupMessagesChatResponse
240+
| ImageChatResponse
241+
| ClearMessageChatResponse
242+
| UsageChatResponse
243+
| MessageIdChatResponse
244+
| ChunkedContentChatResponse,
245+
Field(discriminator="type"),
246+
]
247+
248+
249+
class ChatResponse(RootModel[ChatResponseUnion]):
250+
"""Container for different types of chat responses."""
251+
252+
root: ChatResponseUnion
253+
254+
@property
255+
def content(self) -> object:
256+
"""Returns content of a response, use dedicated `as_*` methods to get type hints."""
257+
return self.root.content
258+
259+
@property
260+
def type(self) -> ChatResponseType:
261+
"""Returns type of the ChatResponse"""
262+
return self.root.type
263+
264+
@overload
265+
def __init__(
266+
self,
267+
type: Literal[ChatResponseType.TEXT],
268+
content: str,
269+
) -> None: ...
270+
@overload
271+
def __init__(
272+
self,
273+
type: Literal[ChatResponseType.REFERENCE],
274+
content: Reference,
275+
) -> None: ...
276+
@overload
277+
def __init__(
278+
self,
279+
type: Literal[ChatResponseType.STATE_UPDATE],
280+
content: StateUpdate,
281+
) -> None: ...
282+
@overload
283+
def __init__(
284+
self,
285+
type: Literal[ChatResponseType.CONVERSATION_ID],
286+
content: str,
287+
) -> None: ...
288+
@overload
289+
def __init__(
290+
self,
291+
type: Literal[ChatResponseType.LIVE_UPDATE],
292+
content: LiveUpdate,
293+
) -> None: ...
294+
@overload
295+
def __init__(
296+
self,
297+
type: Literal[ChatResponseType.FOLLOWUP_MESSAGES],
298+
content: list[str],
299+
) -> None: ...
300+
@overload
301+
def __init__(
302+
self,
303+
type: Literal[ChatResponseType.IMAGE],
304+
content: Image,
305+
) -> None: ...
306+
@overload
307+
def __init__(
308+
self,
309+
type: Literal[ChatResponseType.CLEAR_MESSAGE],
310+
content: None,
311+
) -> None: ...
312+
@overload
313+
def __init__(
314+
self,
315+
type: Literal[ChatResponseType.USAGE],
316+
content: dict[str, MessageUsage],
317+
) -> None: ...
318+
@overload
319+
def __init__(
320+
self,
321+
type: Literal[ChatResponseType.MESSAGE_ID],
322+
content: str,
323+
) -> None: ...
324+
@overload
325+
def __init__(
326+
self,
327+
type: Literal[ChatResponseType.CHUNKED_CONTENT],
328+
content: ChunkedContent,
329+
) -> None: ...
330+
def __init__(
331+
self,
332+
type: ChatResponseType,
333+
content: Any,
334+
) -> None:
335+
"""
336+
Backward-compatible constructor.
337+
338+
Allows creating a ChatResponse directly with:
339+
ChatResponse(type=ChatResponseType.TEXT, content="hello")
340+
"""
341+
model_cls = _CHAT_RESPONSE_REGISTRY.get(type)
342+
if model_cls is None:
343+
raise ValueError(f"Unsupported ChatResponseType: {type}")
344+
345+
model_instance = model_cls(type=type, content=content)
346+
super().__init__(root=cast(ChatResponseUnion, model_instance))
143347

144348
def as_text(self) -> str | None:
145349
"""
@@ -149,7 +353,7 @@ def as_text(self) -> str | None:
149353
if text := response.as_text():
150354
print(f"Got text: {text}")
151355
"""
152-
return str(self.content) if self.type == ChatResponseType.TEXT else None
356+
return self.root.content if isinstance(self.root, TextChatResponse) else None
153357

154358
def as_reference(self) -> Reference | None:
155359
"""
@@ -159,7 +363,7 @@ def as_reference(self) -> Reference | None:
159363
if ref := response.as_reference():
160364
print(f"Got reference: {ref.title}")
161365
"""
162-
return cast(Reference, self.content) if self.type == ChatResponseType.REFERENCE else None
366+
return self.root.content if isinstance(self.root, ReferenceChatResponse) else None
163367

164368
def as_state_update(self) -> StateUpdate | None:
165369
"""
@@ -169,13 +373,13 @@ def as_state_update(self) -> StateUpdate | None:
169373
if state_update := response.as_state_update():
170374
state = verify_state(state_update)
171375
"""
172-
return cast(StateUpdate, self.content) if self.type == ChatResponseType.STATE_UPDATE else None
376+
return self.root.content if isinstance(self.root, StateUpdateChatResponse) else None
173377

174378
def as_conversation_id(self) -> str | None:
175379
"""
176380
Return the content as ConversationID if this is a conversation id, else None.
177381
"""
178-
return cast(str, self.content) if self.type == ChatResponseType.CONVERSATION_ID else None
382+
return self.root.content if isinstance(self.root, ConversationIdChatResponse) else None
179383

180384
def as_live_update(self) -> LiveUpdate | None:
181385
"""
@@ -185,7 +389,7 @@ def as_live_update(self) -> LiveUpdate | None:
185389
if live_update := response.as_live_update():
186390
print(f"Got live update: {live_update.content.label}")
187391
"""
188-
return cast(LiveUpdate, self.content) if self.type == ChatResponseType.LIVE_UPDATE else None
392+
return self.root.content if isinstance(self.root, LiveUpdateChatResponse) else None
189393

190394
def as_followup_messages(self) -> list[str] | None:
191395
"""
@@ -195,25 +399,25 @@ def as_followup_messages(self) -> list[str] | None:
195399
if followup_messages := response.as_followup_messages():
196400
print(f"Got followup messages: {followup_messages}")
197401
"""
198-
return cast(list[str], self.content) if self.type == ChatResponseType.FOLLOWUP_MESSAGES else None
402+
return self.root.content if isinstance(self.root, FollowupMessagesChatResponse) else None
199403

200404
def as_image(self) -> Image | None:
201405
"""
202406
Return the content as Image if this is an image response, else None.
203407
"""
204-
return cast(Image, self.content) if self.type == ChatResponseType.IMAGE else None
408+
return self.root.content if isinstance(self.root, ImageChatResponse) else None
205409

206410
def as_clear_message(self) -> None:
207411
"""
208412
Return the content of clear_message response, which is None
209413
"""
210-
return cast(None, self.content)
414+
return self.root.content if isinstance(self.root, ClearMessageChatResponse) else None
211415

212416
def as_usage(self) -> dict[str, MessageUsage] | None:
213417
"""
214418
Return the content as dict from model name to Usage if this is an usage response, else None
215419
"""
216-
return cast(dict[str, MessageUsage], self.content) if self.type == ChatResponseType.USAGE else None
420+
return self.root.content if isinstance(self.root, UsageChatResponse) else None
217421

218422

219423
class ChatMessageRequest(BaseModel):

packages/ragbits-chat/src/ragbits/chat/providers/model_provider.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,21 @@
1010

1111
from pydantic import BaseModel
1212

13-
from ragbits.chat.interface.types import AuthType
13+
from ragbits.chat.interface.types import (
14+
AuthType,
15+
ChatResponse,
16+
ChunkedContentChatResponse,
17+
ClearMessageChatResponse,
18+
ConversationIdChatResponse,
19+
FollowupMessagesChatResponse,
20+
ImageChatResponse,
21+
LiveUpdateChatResponse,
22+
MessageIdChatResponse,
23+
ReferenceChatResponse,
24+
StateUpdateChatResponse,
25+
TextChatResponse,
26+
UsageChatResponse,
27+
)
1428

1529

1630
class RagbitsChatModelProvider:
@@ -93,6 +107,7 @@ def get_models(self) -> dict[str, type[BaseModel | Enum]]:
93107
"FeedbackItem": FeedbackItem,
94108
"Image": Image,
95109
"MessageUsage": MessageUsage,
110+
"StateUpdate": StateUpdate,
96111
# Configuration models
97112
"HeaderCustomization": HeaderCustomization,
98113
"UICustomization": UICustomization,
@@ -114,6 +129,19 @@ def get_models(self) -> dict[str, type[BaseModel | Enum]]:
114129
"LoginResponse": LoginResponse,
115130
"LogoutRequest": LogoutRequest,
116131
"User": User,
132+
# Chat responses:
133+
"TextChatResponse": TextChatResponse,
134+
"ReferenceChatResponse": ReferenceChatResponse,
135+
"MessageIdChatResponse": MessageIdChatResponse,
136+
"ConversationIdChatResponse": ConversationIdChatResponse,
137+
"StateUpdateChatResponse": StateUpdateChatResponse,
138+
"LiveUpdateChatResponse": LiveUpdateChatResponse,
139+
"FollowupMessagesChatResponse": FollowupMessagesChatResponse,
140+
"ImageChatResponse": ImageChatResponse,
141+
"ClearMessageChatResponse": ClearMessageChatResponse,
142+
"UsageChatResponse": UsageChatResponse,
143+
"ChunkedContentChatResponse": ChunkedContentChatResponse,
144+
"ChatResponse": ChatResponse,
117145
}
118146

119147
return self._models_cache
@@ -163,6 +191,18 @@ def get_categories(self) -> dict[str, list[str]]:
163191
"FeedbackResponse",
164192
"ConfigResponse",
165193
"LoginResponse",
194+
"TextChatResponse",
195+
"ReferenceChatResponse",
196+
"MessageIdChatResponse",
197+
"ConversationIdChatResponse",
198+
"StateUpdateChatResponse",
199+
"LiveUpdateChatResponse",
200+
"FollowupMessagesChatResponse",
201+
"ImageChatResponse",
202+
"ClearMessageChatResponse",
203+
"UsageChatResponse",
204+
"ChunkedContentChatResponse",
205+
"ChatResponse",
166206
],
167207
"requests": [
168208
"ChatRequest",

0 commit comments

Comments
 (0)