Skip to content

Commit ab68b94

Browse files
committed
refactor: migrate ChatResponse to discriminated union
1 parent ee680a4 commit ab68b94

File tree

9 files changed

+337
-85
lines changed

9 files changed

+337
-85
lines changed

packages/ragbits-chat/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
## Unreleased
44

5+
- Move ChatResponse to union of types (#809)
6+
57
## 1.3.0 (2025-09-11)
68

79
### Changed

packages/ragbits-chat/src/ragbits/chat/interface/types.py

Lines changed: 220 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from enum import Enum
2-
from typing import Any, cast
2+
from typing import Annotated, Any, Literal, cast, get_args, get_origin, overload
33

4-
from pydantic import BaseModel, ConfigDict, Field
4+
from pydantic import BaseModel, ConfigDict, Field, RootModel
55

66
from ragbits.chat.auth.types import User
77
from ragbits.chat.interface.forms import UserSettings
@@ -135,13 +135,217 @@ class ChatContext(BaseModel):
135135
model_config = ConfigDict(extra="allow")
136136

137137

138-
class ChatResponse(BaseModel):
139-
"""Container for different types of chat responses."""
138+
_CHAT_RESPONSE_REGISTRY: dict[ChatResponseType, type[BaseModel]] = {}
139+
140+
141+
class ChatResponseBase(BaseModel):
142+
"""Base class for all ChatResponse variants with auto-registration."""
140143

141144
type: ChatResponseType
142-
content: (
143-
str | Reference | StateUpdate | LiveUpdate | list[str] | Image | dict[str, MessageUsage] | ChunkedContent | None
144-
)
145+
146+
def __init_subclass__(cls, **kwargs: Any):
147+
super().__init_subclass__(**kwargs)
148+
type_ann = cls.model_fields["type"].annotation
149+
origin = get_origin(type_ann)
150+
value = get_args(type_ann)[0] if origin is Literal else getattr(cls, "type", None)
151+
152+
if value is None:
153+
raise ValueError(f"Cannot determine ChatResponseType for {cls.__name__}")
154+
155+
_CHAT_RESPONSE_REGISTRY[value] = cls
156+
157+
158+
class TextChatResponse(ChatResponseBase):
159+
"""Represents text chat response"""
160+
161+
type: Literal[ChatResponseType.TEXT] = ChatResponseType.TEXT
162+
content: str
163+
164+
165+
class ReferenceChatResponse(ChatResponseBase):
166+
"""Represents reference chat response"""
167+
168+
type: Literal[ChatResponseType.REFERENCE] = ChatResponseType.REFERENCE
169+
content: Reference
170+
171+
172+
class StateUpdateChatResponse(ChatResponseBase):
173+
"""Represents state update chat response"""
174+
175+
type: Literal[ChatResponseType.STATE_UPDATE] = ChatResponseType.STATE_UPDATE
176+
content: StateUpdate
177+
178+
179+
class ConversationIdChatResponse(ChatResponseBase):
180+
"""Represents conversation_id chat response"""
181+
182+
type: Literal[ChatResponseType.CONVERSATION_ID] = ChatResponseType.CONVERSATION_ID
183+
content: str
184+
185+
186+
class LiveUpdateChatResponse(ChatResponseBase):
187+
"""Represents live update chat response"""
188+
189+
type: Literal[ChatResponseType.LIVE_UPDATE] = ChatResponseType.LIVE_UPDATE
190+
content: LiveUpdate
191+
192+
193+
class FollowupMessagesChatResponse(ChatResponseBase):
194+
"""Represents followup messages chat response"""
195+
196+
type: Literal[ChatResponseType.FOLLOWUP_MESSAGES] = ChatResponseType.FOLLOWUP_MESSAGES
197+
content: list[str]
198+
199+
200+
class ImageChatResponse(ChatResponseBase):
201+
"""Represents image chat response"""
202+
203+
type: Literal[ChatResponseType.IMAGE] = ChatResponseType.IMAGE
204+
content: Image
205+
206+
207+
class ClearMessageChatResponse(ChatResponseBase):
208+
"""Represents clear message event"""
209+
210+
type: Literal[ChatResponseType.CLEAR_MESSAGE] = ChatResponseType.CLEAR_MESSAGE
211+
content: None = None
212+
213+
214+
class UsageChatResponse(ChatResponseBase):
215+
"""Represents usage chat response"""
216+
217+
type: Literal[ChatResponseType.USAGE] = ChatResponseType.USAGE
218+
content: dict[str, MessageUsage]
219+
220+
221+
class MessageIdChatResponse(ChatResponseBase):
222+
"""Represents message_id chat response"""
223+
224+
type: Literal[ChatResponseType.MESSAGE_ID] = ChatResponseType.MESSAGE_ID
225+
content: str
226+
227+
228+
class ChunkedContentChatResponse(ChatResponseBase):
229+
"""Represents chunked_content event that contains chunked event of different type"""
230+
231+
type: Literal[ChatResponseType.CHUNKED_CONTENT] = ChatResponseType.CHUNKED_CONTENT
232+
content: ChunkedContent
233+
234+
235+
ChatResponseUnion = Annotated[
236+
TextChatResponse
237+
| ReferenceChatResponse
238+
| StateUpdateChatResponse
239+
| ConversationIdChatResponse
240+
| LiveUpdateChatResponse
241+
| FollowupMessagesChatResponse
242+
| ImageChatResponse
243+
| ClearMessageChatResponse
244+
| UsageChatResponse
245+
| MessageIdChatResponse
246+
| ChunkedContentChatResponse,
247+
Field(discriminator="type"),
248+
]
249+
250+
251+
class ChatResponse(RootModel[ChatResponseUnion]):
252+
"""Container for different types of chat responses."""
253+
254+
root: ChatResponseUnion
255+
256+
@property
257+
def content(self) -> object:
258+
"""Returns content of a response, use dedicated `as_*` methods to get type hints."""
259+
return self.root.content
260+
261+
@property
262+
def type(self) -> ChatResponseType:
263+
"""Returns type of the ChatResponse"""
264+
return self.root.type
265+
266+
@overload
267+
def __init__(
268+
self,
269+
type: Literal[ChatResponseType.TEXT],
270+
content: str,
271+
) -> None: ...
272+
@overload
273+
def __init__(
274+
self,
275+
type: Literal[ChatResponseType.REFERENCE],
276+
content: Reference,
277+
) -> None: ...
278+
@overload
279+
def __init__(
280+
self,
281+
type: Literal[ChatResponseType.STATE_UPDATE],
282+
content: StateUpdate,
283+
) -> None: ...
284+
@overload
285+
def __init__(
286+
self,
287+
type: Literal[ChatResponseType.CONVERSATION_ID],
288+
content: str,
289+
) -> None: ...
290+
@overload
291+
def __init__(
292+
self,
293+
type: Literal[ChatResponseType.LIVE_UPDATE],
294+
content: LiveUpdate,
295+
) -> None: ...
296+
@overload
297+
def __init__(
298+
self,
299+
type: Literal[ChatResponseType.FOLLOWUP_MESSAGES],
300+
content: list[str],
301+
) -> None: ...
302+
@overload
303+
def __init__(
304+
self,
305+
type: Literal[ChatResponseType.IMAGE],
306+
content: Image,
307+
) -> None: ...
308+
@overload
309+
def __init__(
310+
self,
311+
type: Literal[ChatResponseType.CLEAR_MESSAGE],
312+
content: None,
313+
) -> None: ...
314+
@overload
315+
def __init__(
316+
self,
317+
type: Literal[ChatResponseType.USAGE],
318+
content: dict[str, MessageUsage],
319+
) -> None: ...
320+
@overload
321+
def __init__(
322+
self,
323+
type: Literal[ChatResponseType.MESSAGE_ID],
324+
content: str,
325+
) -> None: ...
326+
@overload
327+
def __init__(
328+
self,
329+
type: Literal[ChatResponseType.CHUNKED_CONTENT],
330+
content: ChunkedContent,
331+
) -> None: ...
332+
def __init__(
333+
self,
334+
type: ChatResponseType,
335+
content: Any,
336+
) -> None:
337+
"""
338+
Backward-compatible constructor.
339+
340+
Allows creating a ChatResponse directly with:
341+
ChatResponse(type=ChatResponseType.TEXT, content="hello")
342+
"""
343+
model_cls = _CHAT_RESPONSE_REGISTRY.get(type)
344+
if model_cls is None:
345+
raise ValueError(f"Unsupported ChatResponseType: {type}")
346+
347+
model_instance = model_cls(type=type, content=content)
348+
super().__init__(root=cast(ChatResponseUnion, model_instance))
145349

146350
def as_text(self) -> str | None:
147351
"""
@@ -151,7 +355,7 @@ def as_text(self) -> str | None:
151355
if text := response.as_text():
152356
print(f"Got text: {text}")
153357
"""
154-
return str(self.content) if self.type == ChatResponseType.TEXT else None
358+
return self.root.content if isinstance(self.root, TextChatResponse) else None
155359

156360
def as_reference(self) -> Reference | None:
157361
"""
@@ -161,7 +365,7 @@ def as_reference(self) -> Reference | None:
161365
if ref := response.as_reference():
162366
print(f"Got reference: {ref.title}")
163367
"""
164-
return cast(Reference, self.content) if self.type == ChatResponseType.REFERENCE else None
368+
return self.root.content if isinstance(self.root, ReferenceChatResponse) else None
165369

166370
def as_state_update(self) -> StateUpdate | None:
167371
"""
@@ -171,13 +375,13 @@ def as_state_update(self) -> StateUpdate | None:
171375
if state_update := response.as_state_update():
172376
state = verify_state(state_update)
173377
"""
174-
return cast(StateUpdate, self.content) if self.type == ChatResponseType.STATE_UPDATE else None
378+
return self.root.content if isinstance(self.root, StateUpdateChatResponse) else None
175379

176380
def as_conversation_id(self) -> str | None:
177381
"""
178382
Return the content as ConversationID if this is a conversation id, else None.
179383
"""
180-
return cast(str, self.content) if self.type == ChatResponseType.CONVERSATION_ID else None
384+
return self.root.content if isinstance(self.root, ConversationIdChatResponse) else None
181385

182386
def as_live_update(self) -> LiveUpdate | None:
183387
"""
@@ -187,7 +391,7 @@ def as_live_update(self) -> LiveUpdate | None:
187391
if live_update := response.as_live_update():
188392
print(f"Got live update: {live_update.content.label}")
189393
"""
190-
return cast(LiveUpdate, self.content) if self.type == ChatResponseType.LIVE_UPDATE else None
394+
return self.root.content if isinstance(self.root, LiveUpdateChatResponse) else None
191395

192396
def as_followup_messages(self) -> list[str] | None:
193397
"""
@@ -197,25 +401,25 @@ def as_followup_messages(self) -> list[str] | None:
197401
if followup_messages := response.as_followup_messages():
198402
print(f"Got followup messages: {followup_messages}")
199403
"""
200-
return cast(list[str], self.content) if self.type == ChatResponseType.FOLLOWUP_MESSAGES else None
404+
return self.root.content if isinstance(self.root, FollowupMessagesChatResponse) else None
201405

202406
def as_image(self) -> Image | None:
203407
"""
204408
Return the content as Image if this is an image response, else None.
205409
"""
206-
return cast(Image, self.content) if self.type == ChatResponseType.IMAGE else None
410+
return self.root.content if isinstance(self.root, ImageChatResponse) else None
207411

208412
def as_clear_message(self) -> None:
209413
"""
210414
Return the content of clear_message response, which is None
211415
"""
212-
return cast(None, self.content)
416+
return self.root.content if isinstance(self.root, ClearMessageChatResponse) else None
213417

214418
def as_usage(self) -> dict[str, MessageUsage] | None:
215419
"""
216420
Return the content as dict from model name to Usage if this is an usage response, else None
217421
"""
218-
return cast(dict[str, MessageUsage], self.content) if self.type == ChatResponseType.USAGE else None
422+
return self.root.content if isinstance(self.root, UsageChatResponse) else None
219423

220424

221425
class ChatMessageRequest(BaseModel):

packages/ragbits-chat/src/ragbits/chat/providers/model_provider.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,21 @@
1010

1111
from pydantic import BaseModel
1212

13-
from ragbits.chat.interface.types import AuthType
13+
from ragbits.chat.interface.types import (
14+
AuthType,
15+
ChatResponse,
16+
ChunkedContentChatResponse,
17+
ClearMessageChatResponse,
18+
ConversationIdChatResponse,
19+
FollowupMessagesChatResponse,
20+
ImageChatResponse,
21+
LiveUpdateChatResponse,
22+
MessageIdChatResponse,
23+
ReferenceChatResponse,
24+
StateUpdateChatResponse,
25+
TextChatResponse,
26+
UsageChatResponse,
27+
)
1428

1529

1630
class RagbitsChatModelProvider:
@@ -93,6 +107,7 @@ def get_models(self) -> dict[str, type[BaseModel | Enum]]:
93107
"FeedbackItem": FeedbackItem,
94108
"Image": Image,
95109
"MessageUsage": MessageUsage,
110+
"StateUpdate": StateUpdate,
96111
# Configuration models
97112
"HeaderCustomization": HeaderCustomization,
98113
"UICustomization": UICustomization,
@@ -114,6 +129,19 @@ def get_models(self) -> dict[str, type[BaseModel | Enum]]:
114129
"LoginResponse": LoginResponse,
115130
"LogoutRequest": LogoutRequest,
116131
"User": User,
132+
# Chat responses:
133+
"TextChatResponse": TextChatResponse,
134+
"ReferenceChatResponse": ReferenceChatResponse,
135+
"MessageIdChatResponse": MessageIdChatResponse,
136+
"ConversationIdChatResponse": ConversationIdChatResponse,
137+
"StateUpdateChatResponse": StateUpdateChatResponse,
138+
"LiveUpdateChatResponse": LiveUpdateChatResponse,
139+
"FollowupMessagesChatResponse": FollowupMessagesChatResponse,
140+
"ImageChatResponse": ImageChatResponse,
141+
"ClearMessageChatResponse": ClearMessageChatResponse,
142+
"UsageChatResponse": UsageChatResponse,
143+
"ChunkedContentChatResponse": ChunkedContentChatResponse,
144+
"ChatResponse": ChatResponse,
117145
}
118146

119147
return self._models_cache
@@ -163,6 +191,18 @@ def get_categories(self) -> dict[str, list[str]]:
163191
"FeedbackResponse",
164192
"ConfigResponse",
165193
"LoginResponse",
194+
"TextChatResponse",
195+
"ReferenceChatResponse",
196+
"MessageIdChatResponse",
197+
"ConversationIdChatResponse",
198+
"StateUpdateChatResponse",
199+
"LiveUpdateChatResponse",
200+
"FollowupMessagesChatResponse",
201+
"ImageChatResponse",
202+
"ClearMessageChatResponse",
203+
"UsageChatResponse",
204+
"ChunkedContentChatResponse",
205+
"ChatResponse",
166206
],
167207
"requests": [
168208
"ChatRequest",

0 commit comments

Comments
 (0)