1
1
from enum import Enum
2
- from typing import Any , cast
2
+ from typing import Annotated , Any , Literal , cast , get_args , get_origin , overload
3
3
4
- from pydantic import BaseModel , ConfigDict , Field
4
+ from pydantic import BaseModel , ConfigDict , Field , RootModel
5
5
6
6
from ragbits .chat .interface .forms import UserSettings
7
7
from ragbits .chat .interface .ui_customization import UICustomization
@@ -133,13 +133,217 @@ class ChatContext(BaseModel):
133
133
model_config = ConfigDict (extra = "allow" )
134
134
135
135
136
- class ChatResponse (BaseModel ):
137
- """Container for different types of chat responses."""
136
+ _CHAT_RESPONSE_REGISTRY : dict [ChatResponseType , type [BaseModel ]] = {}
137
+
138
+
139
+ class ChatResponseBase (BaseModel ):
140
+ """Base class for all ChatResponse variants with auto-registration."""
138
141
139
142
type : ChatResponseType
140
- content : (
141
- str | Reference | StateUpdate | LiveUpdate | list [str ] | Image | dict [str , MessageUsage ] | ChunkedContent | None
142
- )
143
+
144
+ def __init_subclass__ (cls , ** kwargs : Any ):
145
+ super ().__init_subclass__ (** kwargs )
146
+ type_ann = cls .model_fields ["type" ].annotation
147
+ origin = get_origin (type_ann )
148
+ value = get_args (type_ann )[0 ] if origin is Literal else getattr (cls , "type" , None )
149
+
150
+ if value is None :
151
+ raise ValueError (f"Cannot determine ChatResponseType for { cls .__name__ } " )
152
+
153
+ _CHAT_RESPONSE_REGISTRY [value ] = cls
154
+
155
+
156
+ class TextChatResponse (ChatResponseBase ):
157
+ """Represents text chat response"""
158
+
159
+ type : Literal [ChatResponseType .TEXT ] = ChatResponseType .TEXT
160
+ content : str
161
+
162
+
163
+ class ReferenceChatResponse (ChatResponseBase ):
164
+ """Represents reference chat response"""
165
+
166
+ type : Literal [ChatResponseType .REFERENCE ] = ChatResponseType .REFERENCE
167
+ content : Reference
168
+
169
+
170
+ class StateUpdateChatResponse (ChatResponseBase ):
171
+ """Represents state update chat response"""
172
+
173
+ type : Literal [ChatResponseType .STATE_UPDATE ] = ChatResponseType .STATE_UPDATE
174
+ content : StateUpdate
175
+
176
+
177
+ class ConversationIdChatResponse (ChatResponseBase ):
178
+ """Represents conversation_id chat response"""
179
+
180
+ type : Literal [ChatResponseType .CONVERSATION_ID ] = ChatResponseType .CONVERSATION_ID
181
+ content : str
182
+
183
+
184
+ class LiveUpdateChatResponse (ChatResponseBase ):
185
+ """Represents live update chat response"""
186
+
187
+ type : Literal [ChatResponseType .LIVE_UPDATE ] = ChatResponseType .LIVE_UPDATE
188
+ content : LiveUpdate
189
+
190
+
191
+ class FollowupMessagesChatResponse (ChatResponseBase ):
192
+ """Represents followup messages chat response"""
193
+
194
+ type : Literal [ChatResponseType .FOLLOWUP_MESSAGES ] = ChatResponseType .FOLLOWUP_MESSAGES
195
+ content : list [str ]
196
+
197
+
198
+ class ImageChatResponse (ChatResponseBase ):
199
+ """Represents image chat response"""
200
+
201
+ type : Literal [ChatResponseType .IMAGE ] = ChatResponseType .IMAGE
202
+ content : Image
203
+
204
+
205
+ class ClearMessageChatResponse (ChatResponseBase ):
206
+ """Represents clear message event"""
207
+
208
+ type : Literal [ChatResponseType .CLEAR_MESSAGE ] = ChatResponseType .CLEAR_MESSAGE
209
+ content : None = None
210
+
211
+
212
+ class UsageChatResponse (ChatResponseBase ):
213
+ """Represents usage chat response"""
214
+
215
+ type : Literal [ChatResponseType .USAGE ] = ChatResponseType .USAGE
216
+ content : dict [str , MessageUsage ]
217
+
218
+
219
+ class MessageIdChatResponse (ChatResponseBase ):
220
+ """Represents message_id chat response"""
221
+
222
+ type : Literal [ChatResponseType .MESSAGE_ID ] = ChatResponseType .MESSAGE_ID
223
+ content : str
224
+
225
+
226
+ class ChunkedContentChatResponse (ChatResponseBase ):
227
+ """Represents chunked_content event that contains chunked event of different type"""
228
+
229
+ type : Literal [ChatResponseType .CHUNKED_CONTENT ] = ChatResponseType .CHUNKED_CONTENT
230
+ content : ChunkedContent
231
+
232
+
233
+ ChatResponseUnion = Annotated [
234
+ TextChatResponse
235
+ | ReferenceChatResponse
236
+ | StateUpdateChatResponse
237
+ | ConversationIdChatResponse
238
+ | LiveUpdateChatResponse
239
+ | FollowupMessagesChatResponse
240
+ | ImageChatResponse
241
+ | ClearMessageChatResponse
242
+ | UsageChatResponse
243
+ | MessageIdChatResponse
244
+ | ChunkedContentChatResponse ,
245
+ Field (discriminator = "type" ),
246
+ ]
247
+
248
+
249
+ class ChatResponse (RootModel [ChatResponseUnion ]):
250
+ """Container for different types of chat responses."""
251
+
252
+ root : ChatResponseUnion
253
+
254
+ @property
255
+ def content (self ) -> object :
256
+ """Returns content of a response, use dedicated `as_*` methods to get type hints."""
257
+ return self .root .content
258
+
259
+ @property
260
+ def type (self ) -> ChatResponseType :
261
+ """Returns type of the ChatResponse"""
262
+ return self .root .type
263
+
264
+ @overload
265
+ def __init__ (
266
+ self ,
267
+ type : Literal [ChatResponseType .TEXT ],
268
+ content : str ,
269
+ ) -> None : ...
270
+ @overload
271
+ def __init__ (
272
+ self ,
273
+ type : Literal [ChatResponseType .REFERENCE ],
274
+ content : Reference ,
275
+ ) -> None : ...
276
+ @overload
277
+ def __init__ (
278
+ self ,
279
+ type : Literal [ChatResponseType .STATE_UPDATE ],
280
+ content : StateUpdate ,
281
+ ) -> None : ...
282
+ @overload
283
+ def __init__ (
284
+ self ,
285
+ type : Literal [ChatResponseType .CONVERSATION_ID ],
286
+ content : str ,
287
+ ) -> None : ...
288
+ @overload
289
+ def __init__ (
290
+ self ,
291
+ type : Literal [ChatResponseType .LIVE_UPDATE ],
292
+ content : LiveUpdate ,
293
+ ) -> None : ...
294
+ @overload
295
+ def __init__ (
296
+ self ,
297
+ type : Literal [ChatResponseType .FOLLOWUP_MESSAGES ],
298
+ content : list [str ],
299
+ ) -> None : ...
300
+ @overload
301
+ def __init__ (
302
+ self ,
303
+ type : Literal [ChatResponseType .IMAGE ],
304
+ content : Image ,
305
+ ) -> None : ...
306
+ @overload
307
+ def __init__ (
308
+ self ,
309
+ type : Literal [ChatResponseType .CLEAR_MESSAGE ],
310
+ content : None ,
311
+ ) -> None : ...
312
+ @overload
313
+ def __init__ (
314
+ self ,
315
+ type : Literal [ChatResponseType .USAGE ],
316
+ content : dict [str , MessageUsage ],
317
+ ) -> None : ...
318
+ @overload
319
+ def __init__ (
320
+ self ,
321
+ type : Literal [ChatResponseType .MESSAGE_ID ],
322
+ content : str ,
323
+ ) -> None : ...
324
+ @overload
325
+ def __init__ (
326
+ self ,
327
+ type : Literal [ChatResponseType .CHUNKED_CONTENT ],
328
+ content : ChunkedContent ,
329
+ ) -> None : ...
330
+ def __init__ (
331
+ self ,
332
+ type : ChatResponseType ,
333
+ content : Any ,
334
+ ) -> None :
335
+ """
336
+ Backward-compatible constructor.
337
+
338
+ Allows creating a ChatResponse directly with:
339
+ ChatResponse(type=ChatResponseType.TEXT, content="hello")
340
+ """
341
+ model_cls = _CHAT_RESPONSE_REGISTRY .get (type )
342
+ if model_cls is None :
343
+ raise ValueError (f"Unsupported ChatResponseType: { type } " )
344
+
345
+ model_instance = model_cls (type = type , content = content )
346
+ super ().__init__ (root = cast (ChatResponseUnion , model_instance ))
143
347
144
348
def as_text (self ) -> str | None :
145
349
"""
@@ -149,7 +353,7 @@ def as_text(self) -> str | None:
149
353
if text := response.as_text():
150
354
print(f"Got text: {text}")
151
355
"""
152
- return str ( self .content ) if self .type == ChatResponseType . TEXT else None
356
+ return self .root . content if isinstance ( self .root , TextChatResponse ) else None
153
357
154
358
def as_reference (self ) -> Reference | None :
155
359
"""
@@ -159,7 +363,7 @@ def as_reference(self) -> Reference | None:
159
363
if ref := response.as_reference():
160
364
print(f"Got reference: {ref.title}")
161
365
"""
162
- return cast ( Reference , self .content ) if self .type == ChatResponseType . REFERENCE else None
366
+ return self .root . content if isinstance ( self .root , ReferenceChatResponse ) else None
163
367
164
368
def as_state_update (self ) -> StateUpdate | None :
165
369
"""
@@ -169,13 +373,13 @@ def as_state_update(self) -> StateUpdate | None:
169
373
if state_update := response.as_state_update():
170
374
state = verify_state(state_update)
171
375
"""
172
- return cast ( StateUpdate , self .content ) if self .type == ChatResponseType . STATE_UPDATE else None
376
+ return self .root . content if isinstance ( self .root , StateUpdateChatResponse ) else None
173
377
174
378
def as_conversation_id (self ) -> str | None :
175
379
"""
176
380
Return the content as ConversationID if this is a conversation id, else None.
177
381
"""
178
- return cast ( str , self .content ) if self .type == ChatResponseType . CONVERSATION_ID else None
382
+ return self .root . content if isinstance ( self .root , ConversationIdChatResponse ) else None
179
383
180
384
def as_live_update (self ) -> LiveUpdate | None :
181
385
"""
@@ -185,7 +389,7 @@ def as_live_update(self) -> LiveUpdate | None:
185
389
if live_update := response.as_live_update():
186
390
print(f"Got live update: {live_update.content.label}")
187
391
"""
188
- return cast ( LiveUpdate , self .content ) if self .type == ChatResponseType . LIVE_UPDATE else None
392
+ return self .root . content if isinstance ( self .root , LiveUpdateChatResponse ) else None
189
393
190
394
def as_followup_messages (self ) -> list [str ] | None :
191
395
"""
@@ -195,25 +399,25 @@ def as_followup_messages(self) -> list[str] | None:
195
399
if followup_messages := response.as_followup_messages():
196
400
print(f"Got followup messages: {followup_messages}")
197
401
"""
198
- return cast ( list [ str ], self .content ) if self .type == ChatResponseType . FOLLOWUP_MESSAGES else None
402
+ return self .root . content if isinstance ( self .root , FollowupMessagesChatResponse ) else None
199
403
200
404
def as_image (self ) -> Image | None :
201
405
"""
202
406
Return the content as Image if this is an image response, else None.
203
407
"""
204
- return cast ( Image , self .content ) if self .type == ChatResponseType . IMAGE else None
408
+ return self .root . content if isinstance ( self .root , ImageChatResponse ) else None
205
409
206
410
def as_clear_message (self ) -> None :
207
411
"""
208
412
Return the content of clear_message response, which is None
209
413
"""
210
- return cast ( None , self .content )
414
+ return self .root . content if isinstance ( self . root , ClearMessageChatResponse ) else None
211
415
212
416
def as_usage (self ) -> dict [str , MessageUsage ] | None :
213
417
"""
214
418
Return the content as dict from model name to Usage if this is an usage response, else None
215
419
"""
216
- return cast ( dict [ str , MessageUsage ], self .content ) if self .type == ChatResponseType . USAGE else None
420
+ return self .root . content if isinstance ( self .root , UsageChatResponse ) else None
217
421
218
422
219
423
class ChatMessageRequest (BaseModel ):
0 commit comments