1
1
from enum import Enum
2
- from typing import Any , cast
2
+ from typing import Annotated , Any , Literal , cast , get_args , get_origin , overload
3
3
4
- from pydantic import BaseModel , ConfigDict , Field
4
+ from pydantic import BaseModel , ConfigDict , Field , RootModel
5
5
6
6
from ragbits .chat .auth .types import User
7
7
from ragbits .chat .interface .forms import UserSettings
@@ -135,13 +135,217 @@ class ChatContext(BaseModel):
135
135
model_config = ConfigDict (extra = "allow" )
136
136
137
137
138
- class ChatResponse (BaseModel ):
139
- """Container for different types of chat responses."""
138
+ _CHAT_RESPONSE_REGISTRY : dict [ChatResponseType , type [BaseModel ]] = {}
139
+
140
+
141
+ class ChatResponseBase (BaseModel ):
142
+ """Base class for all ChatResponse variants with auto-registration."""
140
143
141
144
type : ChatResponseType
142
- content : (
143
- str | Reference | StateUpdate | LiveUpdate | list [str ] | Image | dict [str , MessageUsage ] | ChunkedContent | None
144
- )
145
+
146
+ def __init_subclass__ (cls , ** kwargs : Any ):
147
+ super ().__init_subclass__ (** kwargs )
148
+ type_ann = cls .model_fields ["type" ].annotation
149
+ origin = get_origin (type_ann )
150
+ value = get_args (type_ann )[0 ] if origin is Literal else getattr (cls , "type" , None )
151
+
152
+ if value is None :
153
+ raise ValueError (f"Cannot determine ChatResponseType for { cls .__name__ } " )
154
+
155
+ _CHAT_RESPONSE_REGISTRY [value ] = cls
156
+
157
+
158
+ class TextChatResponse (ChatResponseBase ):
159
+ """Represents text chat response"""
160
+
161
+ type : Literal [ChatResponseType .TEXT ] = ChatResponseType .TEXT
162
+ content : str
163
+
164
+
165
+ class ReferenceChatResponse (ChatResponseBase ):
166
+ """Represents reference chat response"""
167
+
168
+ type : Literal [ChatResponseType .REFERENCE ] = ChatResponseType .REFERENCE
169
+ content : Reference
170
+
171
+
172
+ class StateUpdateChatResponse (ChatResponseBase ):
173
+ """Represents state update chat response"""
174
+
175
+ type : Literal [ChatResponseType .STATE_UPDATE ] = ChatResponseType .STATE_UPDATE
176
+ content : StateUpdate
177
+
178
+
179
+ class ConversationIdChatResponse (ChatResponseBase ):
180
+ """Represents conversation_id chat response"""
181
+
182
+ type : Literal [ChatResponseType .CONVERSATION_ID ] = ChatResponseType .CONVERSATION_ID
183
+ content : str
184
+
185
+
186
+ class LiveUpdateChatResponse (ChatResponseBase ):
187
+ """Represents live update chat response"""
188
+
189
+ type : Literal [ChatResponseType .LIVE_UPDATE ] = ChatResponseType .LIVE_UPDATE
190
+ content : LiveUpdate
191
+
192
+
193
+ class FollowupMessagesChatResponse (ChatResponseBase ):
194
+ """Represents followup messages chat response"""
195
+
196
+ type : Literal [ChatResponseType .FOLLOWUP_MESSAGES ] = ChatResponseType .FOLLOWUP_MESSAGES
197
+ content : list [str ]
198
+
199
+
200
+ class ImageChatResponse (ChatResponseBase ):
201
+ """Represents image chat response"""
202
+
203
+ type : Literal [ChatResponseType .IMAGE ] = ChatResponseType .IMAGE
204
+ content : Image
205
+
206
+
207
+ class ClearMessageChatResponse (ChatResponseBase ):
208
+ """Represents clear message event"""
209
+
210
+ type : Literal [ChatResponseType .CLEAR_MESSAGE ] = ChatResponseType .CLEAR_MESSAGE
211
+ content : None = None
212
+
213
+
214
+ class UsageChatResponse (ChatResponseBase ):
215
+ """Represents usage chat response"""
216
+
217
+ type : Literal [ChatResponseType .USAGE ] = ChatResponseType .USAGE
218
+ content : dict [str , MessageUsage ]
219
+
220
+
221
+ class MessageIdChatResponse (ChatResponseBase ):
222
+ """Represents message_id chat response"""
223
+
224
+ type : Literal [ChatResponseType .MESSAGE_ID ] = ChatResponseType .MESSAGE_ID
225
+ content : str
226
+
227
+
228
+ class ChunkedContentChatResponse (ChatResponseBase ):
229
+ """Represents chunked_content event that contains chunked event of different type"""
230
+
231
+ type : Literal [ChatResponseType .CHUNKED_CONTENT ] = ChatResponseType .CHUNKED_CONTENT
232
+ content : ChunkedContent
233
+
234
+
235
+ ChatResponseUnion = Annotated [
236
+ TextChatResponse
237
+ | ReferenceChatResponse
238
+ | StateUpdateChatResponse
239
+ | ConversationIdChatResponse
240
+ | LiveUpdateChatResponse
241
+ | FollowupMessagesChatResponse
242
+ | ImageChatResponse
243
+ | ClearMessageChatResponse
244
+ | UsageChatResponse
245
+ | MessageIdChatResponse
246
+ | ChunkedContentChatResponse ,
247
+ Field (discriminator = "type" ),
248
+ ]
249
+
250
+
251
+ class ChatResponse (RootModel [ChatResponseUnion ]):
252
+ """Container for different types of chat responses."""
253
+
254
+ root : ChatResponseUnion
255
+
256
+ @property
257
+ def content (self ) -> object :
258
+ """Returns content of a response, use dedicated `as_*` methods to get type hints."""
259
+ return self .root .content
260
+
261
+ @property
262
+ def type (self ) -> ChatResponseType :
263
+ """Returns type of the ChatResponse"""
264
+ return self .root .type
265
+
266
+ @overload
267
+ def __init__ (
268
+ self ,
269
+ type : Literal [ChatResponseType .TEXT ],
270
+ content : str ,
271
+ ) -> None : ...
272
+ @overload
273
+ def __init__ (
274
+ self ,
275
+ type : Literal [ChatResponseType .REFERENCE ],
276
+ content : Reference ,
277
+ ) -> None : ...
278
+ @overload
279
+ def __init__ (
280
+ self ,
281
+ type : Literal [ChatResponseType .STATE_UPDATE ],
282
+ content : StateUpdate ,
283
+ ) -> None : ...
284
+ @overload
285
+ def __init__ (
286
+ self ,
287
+ type : Literal [ChatResponseType .CONVERSATION_ID ],
288
+ content : str ,
289
+ ) -> None : ...
290
+ @overload
291
+ def __init__ (
292
+ self ,
293
+ type : Literal [ChatResponseType .LIVE_UPDATE ],
294
+ content : LiveUpdate ,
295
+ ) -> None : ...
296
+ @overload
297
+ def __init__ (
298
+ self ,
299
+ type : Literal [ChatResponseType .FOLLOWUP_MESSAGES ],
300
+ content : list [str ],
301
+ ) -> None : ...
302
+ @overload
303
+ def __init__ (
304
+ self ,
305
+ type : Literal [ChatResponseType .IMAGE ],
306
+ content : Image ,
307
+ ) -> None : ...
308
+ @overload
309
+ def __init__ (
310
+ self ,
311
+ type : Literal [ChatResponseType .CLEAR_MESSAGE ],
312
+ content : None ,
313
+ ) -> None : ...
314
+ @overload
315
+ def __init__ (
316
+ self ,
317
+ type : Literal [ChatResponseType .USAGE ],
318
+ content : dict [str , MessageUsage ],
319
+ ) -> None : ...
320
+ @overload
321
+ def __init__ (
322
+ self ,
323
+ type : Literal [ChatResponseType .MESSAGE_ID ],
324
+ content : str ,
325
+ ) -> None : ...
326
+ @overload
327
+ def __init__ (
328
+ self ,
329
+ type : Literal [ChatResponseType .CHUNKED_CONTENT ],
330
+ content : ChunkedContent ,
331
+ ) -> None : ...
332
+ def __init__ (
333
+ self ,
334
+ type : ChatResponseType ,
335
+ content : Any ,
336
+ ) -> None :
337
+ """
338
+ Backward-compatible constructor.
339
+
340
+ Allows creating a ChatResponse directly with:
341
+ ChatResponse(type=ChatResponseType.TEXT, content="hello")
342
+ """
343
+ model_cls = _CHAT_RESPONSE_REGISTRY .get (type )
344
+ if model_cls is None :
345
+ raise ValueError (f"Unsupported ChatResponseType: { type } " )
346
+
347
+ model_instance = model_cls (type = type , content = content )
348
+ super ().__init__ (root = cast (ChatResponseUnion , model_instance ))
145
349
146
350
def as_text (self ) -> str | None :
147
351
"""
@@ -151,7 +355,7 @@ def as_text(self) -> str | None:
151
355
if text := response.as_text():
152
356
print(f"Got text: {text}")
153
357
"""
154
- return str ( self .content ) if self .type == ChatResponseType . TEXT else None
358
+ return self .root . content if isinstance ( self .root , TextChatResponse ) else None
155
359
156
360
def as_reference (self ) -> Reference | None :
157
361
"""
@@ -161,7 +365,7 @@ def as_reference(self) -> Reference | None:
161
365
if ref := response.as_reference():
162
366
print(f"Got reference: {ref.title}")
163
367
"""
164
- return cast ( Reference , self .content ) if self .type == ChatResponseType . REFERENCE else None
368
+ return self .root . content if isinstance ( self .root , ReferenceChatResponse ) else None
165
369
166
370
def as_state_update (self ) -> StateUpdate | None :
167
371
"""
@@ -171,13 +375,13 @@ def as_state_update(self) -> StateUpdate | None:
171
375
if state_update := response.as_state_update():
172
376
state = verify_state(state_update)
173
377
"""
174
- return cast ( StateUpdate , self .content ) if self .type == ChatResponseType . STATE_UPDATE else None
378
+ return self .root . content if isinstance ( self .root , StateUpdateChatResponse ) else None
175
379
176
380
def as_conversation_id (self ) -> str | None :
177
381
"""
178
382
Return the content as ConversationID if this is a conversation id, else None.
179
383
"""
180
- return cast ( str , self .content ) if self .type == ChatResponseType . CONVERSATION_ID else None
384
+ return self .root . content if isinstance ( self .root , ConversationIdChatResponse ) else None
181
385
182
386
def as_live_update (self ) -> LiveUpdate | None :
183
387
"""
@@ -187,7 +391,7 @@ def as_live_update(self) -> LiveUpdate | None:
187
391
if live_update := response.as_live_update():
188
392
print(f"Got live update: {live_update.content.label}")
189
393
"""
190
- return cast ( LiveUpdate , self .content ) if self .type == ChatResponseType . LIVE_UPDATE else None
394
+ return self .root . content if isinstance ( self .root , LiveUpdateChatResponse ) else None
191
395
192
396
def as_followup_messages (self ) -> list [str ] | None :
193
397
"""
@@ -197,25 +401,25 @@ def as_followup_messages(self) -> list[str] | None:
197
401
if followup_messages := response.as_followup_messages():
198
402
print(f"Got followup messages: {followup_messages}")
199
403
"""
200
- return cast ( list [ str ], self .content ) if self .type == ChatResponseType . FOLLOWUP_MESSAGES else None
404
+ return self .root . content if isinstance ( self .root , FollowupMessagesChatResponse ) else None
201
405
202
406
def as_image (self ) -> Image | None :
203
407
"""
204
408
Return the content as Image if this is an image response, else None.
205
409
"""
206
- return cast ( Image , self .content ) if self .type == ChatResponseType . IMAGE else None
410
+ return self .root . content if isinstance ( self .root , ImageChatResponse ) else None
207
411
208
412
def as_clear_message (self ) -> None :
209
413
"""
210
414
Return the content of clear_message response, which is None
211
415
"""
212
- return cast ( None , self .content )
416
+ return self .root . content if isinstance ( self . root , ClearMessageChatResponse ) else None
213
417
214
418
def as_usage (self ) -> dict [str , MessageUsage ] | None :
215
419
"""
216
420
Return the content as dict from model name to Usage if this is an usage response, else None
217
421
"""
218
- return cast ( dict [ str , MessageUsage ], self .content ) if self .type == ChatResponseType . USAGE else None
422
+ return self .root . content if isinstance ( self .root , UsageChatResponse ) else None
219
423
220
424
221
425
class ChatMessageRequest (BaseModel ):
0 commit comments