Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions packages/ragbits-chat/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

## Unreleased

- Move ChatResponse to union of types (#809)

## 1.3.0 (2025-09-11)

### Changed
Expand Down
236 changes: 220 additions & 16 deletions packages/ragbits-chat/src/ragbits/chat/interface/types.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from enum import Enum
from typing import Any, cast
from typing import Annotated, Any, Literal, cast, get_args, get_origin, overload

from pydantic import BaseModel, ConfigDict, Field
from pydantic import BaseModel, ConfigDict, Field, RootModel

from ragbits.chat.auth.types import User
from ragbits.chat.interface.forms import UserSettings
Expand Down Expand Up @@ -135,13 +135,217 @@ class ChatContext(BaseModel):
model_config = ConfigDict(extra="allow")


class ChatResponse(BaseModel):
"""Container for different types of chat responses."""
_CHAT_RESPONSE_REGISTRY: dict[ChatResponseType, type[BaseModel]] = {}


class ChatResponseBase(BaseModel):
"""Base class for all ChatResponse variants with auto-registration."""

type: ChatResponseType
content: (
str | Reference | StateUpdate | LiveUpdate | list[str] | Image | dict[str, MessageUsage] | ChunkedContent | None
)

def __init_subclass__(cls, **kwargs: Any):
super().__init_subclass__(**kwargs)
type_ann = cls.model_fields["type"].annotation
origin = get_origin(type_ann)
value = get_args(type_ann)[0] if origin is Literal else getattr(cls, "type", None)

if value is None:
raise ValueError(f"Cannot determine ChatResponseType for {cls.__name__}")

_CHAT_RESPONSE_REGISTRY[value] = cls


class TextChatResponse(ChatResponseBase):
"""Represents text chat response"""

type: Literal[ChatResponseType.TEXT] = ChatResponseType.TEXT
content: str


class ReferenceChatResponse(ChatResponseBase):
"""Represents reference chat response"""

type: Literal[ChatResponseType.REFERENCE] = ChatResponseType.REFERENCE
content: Reference


class StateUpdateChatResponse(ChatResponseBase):
"""Represents state update chat response"""

type: Literal[ChatResponseType.STATE_UPDATE] = ChatResponseType.STATE_UPDATE
content: StateUpdate


class ConversationIdChatResponse(ChatResponseBase):
"""Represents conversation_id chat response"""

type: Literal[ChatResponseType.CONVERSATION_ID] = ChatResponseType.CONVERSATION_ID
content: str


class LiveUpdateChatResponse(ChatResponseBase):
"""Represents live update chat response"""

type: Literal[ChatResponseType.LIVE_UPDATE] = ChatResponseType.LIVE_UPDATE
content: LiveUpdate


class FollowupMessagesChatResponse(ChatResponseBase):
"""Represents followup messages chat response"""

type: Literal[ChatResponseType.FOLLOWUP_MESSAGES] = ChatResponseType.FOLLOWUP_MESSAGES
content: list[str]


class ImageChatResponse(ChatResponseBase):
"""Represents image chat response"""

type: Literal[ChatResponseType.IMAGE] = ChatResponseType.IMAGE
content: Image


class ClearMessageChatResponse(ChatResponseBase):
"""Represents clear message event"""

type: Literal[ChatResponseType.CLEAR_MESSAGE] = ChatResponseType.CLEAR_MESSAGE
content: None = None


class UsageChatResponse(ChatResponseBase):
"""Represents usage chat response"""

type: Literal[ChatResponseType.USAGE] = ChatResponseType.USAGE
content: dict[str, MessageUsage]


class MessageIdChatResponse(ChatResponseBase):
"""Represents message_id chat response"""

type: Literal[ChatResponseType.MESSAGE_ID] = ChatResponseType.MESSAGE_ID
content: str


class ChunkedContentChatResponse(ChatResponseBase):
"""Represents chunked_content event that contains chunked event of different type"""

type: Literal[ChatResponseType.CHUNKED_CONTENT] = ChatResponseType.CHUNKED_CONTENT
content: ChunkedContent


ChatResponseUnion = Annotated[
TextChatResponse
| ReferenceChatResponse
| StateUpdateChatResponse
| ConversationIdChatResponse
| LiveUpdateChatResponse
| FollowupMessagesChatResponse
| ImageChatResponse
| ClearMessageChatResponse
| UsageChatResponse
| MessageIdChatResponse
| ChunkedContentChatResponse,
Field(discriminator="type"),
]


class ChatResponse(RootModel[ChatResponseUnion]):
"""Container for different types of chat responses."""

root: ChatResponseUnion

@property
def content(self) -> object:
"""Returns content of a response, use dedicated `as_*` methods to get type hints."""
return self.root.content

@property
def type(self) -> ChatResponseType:
"""Returns type of the ChatResponse"""
return self.root.type

@overload
def __init__(
self,
type: Literal[ChatResponseType.TEXT],
content: str,
) -> None: ...
@overload
def __init__(
self,
type: Literal[ChatResponseType.REFERENCE],
content: Reference,
) -> None: ...
@overload
def __init__(
self,
type: Literal[ChatResponseType.STATE_UPDATE],
content: StateUpdate,
) -> None: ...
@overload
def __init__(
self,
type: Literal[ChatResponseType.CONVERSATION_ID],
content: str,
) -> None: ...
@overload
def __init__(
self,
type: Literal[ChatResponseType.LIVE_UPDATE],
content: LiveUpdate,
) -> None: ...
@overload
def __init__(
self,
type: Literal[ChatResponseType.FOLLOWUP_MESSAGES],
content: list[str],
) -> None: ...
@overload
def __init__(
self,
type: Literal[ChatResponseType.IMAGE],
content: Image,
) -> None: ...
@overload
def __init__(
self,
type: Literal[ChatResponseType.CLEAR_MESSAGE],
content: None,
) -> None: ...
@overload
def __init__(
self,
type: Literal[ChatResponseType.USAGE],
content: dict[str, MessageUsage],
) -> None: ...
@overload
def __init__(
self,
type: Literal[ChatResponseType.MESSAGE_ID],
content: str,
) -> None: ...
@overload
def __init__(
self,
type: Literal[ChatResponseType.CHUNKED_CONTENT],
content: ChunkedContent,
) -> None: ...
def __init__(
self,
type: ChatResponseType,
content: Any,
) -> None:
"""
Backward-compatible constructor.
Allows creating a ChatResponse directly with:
ChatResponse(type=ChatResponseType.TEXT, content="hello")
"""
model_cls = _CHAT_RESPONSE_REGISTRY.get(type)
if model_cls is None:
raise ValueError(f"Unsupported ChatResponseType: {type}")

model_instance = model_cls(type=type, content=content)
super().__init__(root=cast(ChatResponseUnion, model_instance))

def as_text(self) -> str | None:
"""
Expand All @@ -151,7 +355,7 @@ def as_text(self) -> str | None:
if text := response.as_text():
print(f"Got text: {text}")
"""
return str(self.content) if self.type == ChatResponseType.TEXT else None
return self.root.content if isinstance(self.root, TextChatResponse) else None

def as_reference(self) -> Reference | None:
"""
Expand All @@ -161,7 +365,7 @@ def as_reference(self) -> Reference | None:
if ref := response.as_reference():
print(f"Got reference: {ref.title}")
"""
return cast(Reference, self.content) if self.type == ChatResponseType.REFERENCE else None
return self.root.content if isinstance(self.root, ReferenceChatResponse) else None

def as_state_update(self) -> StateUpdate | None:
"""
Expand All @@ -171,13 +375,13 @@ def as_state_update(self) -> StateUpdate | None:
if state_update := response.as_state_update():
state = verify_state(state_update)
"""
return cast(StateUpdate, self.content) if self.type == ChatResponseType.STATE_UPDATE else None
return self.root.content if isinstance(self.root, StateUpdateChatResponse) else None

def as_conversation_id(self) -> str | None:
"""
Return the content as ConversationID if this is a conversation id, else None.
"""
return cast(str, self.content) if self.type == ChatResponseType.CONVERSATION_ID else None
return self.root.content if isinstance(self.root, ConversationIdChatResponse) else None

def as_live_update(self) -> LiveUpdate | None:
"""
Expand All @@ -187,7 +391,7 @@ def as_live_update(self) -> LiveUpdate | None:
if live_update := response.as_live_update():
print(f"Got live update: {live_update.content.label}")
"""
return cast(LiveUpdate, self.content) if self.type == ChatResponseType.LIVE_UPDATE else None
return self.root.content if isinstance(self.root, LiveUpdateChatResponse) else None

def as_followup_messages(self) -> list[str] | None:
"""
Expand All @@ -197,25 +401,25 @@ def as_followup_messages(self) -> list[str] | None:
if followup_messages := response.as_followup_messages():
print(f"Got followup messages: {followup_messages}")
"""
return cast(list[str], self.content) if self.type == ChatResponseType.FOLLOWUP_MESSAGES else None
return self.root.content if isinstance(self.root, FollowupMessagesChatResponse) else None

def as_image(self) -> Image | None:
"""
Return the content as Image if this is an image response, else None.
"""
return cast(Image, self.content) if self.type == ChatResponseType.IMAGE else None
return self.root.content if isinstance(self.root, ImageChatResponse) else None

def as_clear_message(self) -> None:
"""
Return the content of clear_message response, which is None
"""
return cast(None, self.content)
return self.root.content if isinstance(self.root, ClearMessageChatResponse) else None

def as_usage(self) -> dict[str, MessageUsage] | None:
"""
Return the content as dict from model name to Usage if this is an usage response, else None
"""
return cast(dict[str, MessageUsage], self.content) if self.type == ChatResponseType.USAGE else None
return self.root.content if isinstance(self.root, UsageChatResponse) else None


class ChatMessageRequest(BaseModel):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,21 @@

from pydantic import BaseModel

from ragbits.chat.interface.types import AuthType
from ragbits.chat.interface.types import (
AuthType,
ChatResponse,
ChunkedContentChatResponse,
ClearMessageChatResponse,
ConversationIdChatResponse,
FollowupMessagesChatResponse,
ImageChatResponse,
LiveUpdateChatResponse,
MessageIdChatResponse,
ReferenceChatResponse,
StateUpdateChatResponse,
TextChatResponse,
UsageChatResponse,
)


class RagbitsChatModelProvider:
Expand Down Expand Up @@ -93,6 +107,7 @@ def get_models(self) -> dict[str, type[BaseModel | Enum]]:
"FeedbackItem": FeedbackItem,
"Image": Image,
"MessageUsage": MessageUsage,
"StateUpdate": StateUpdate,
# Configuration models
"HeaderCustomization": HeaderCustomization,
"UICustomization": UICustomization,
Expand All @@ -114,6 +129,19 @@ def get_models(self) -> dict[str, type[BaseModel | Enum]]:
"LoginResponse": LoginResponse,
"LogoutRequest": LogoutRequest,
"User": User,
# Chat responses:
"TextChatResponse": TextChatResponse,
"ReferenceChatResponse": ReferenceChatResponse,
"MessageIdChatResponse": MessageIdChatResponse,
"ConversationIdChatResponse": ConversationIdChatResponse,
"StateUpdateChatResponse": StateUpdateChatResponse,
"LiveUpdateChatResponse": LiveUpdateChatResponse,
"FollowupMessagesChatResponse": FollowupMessagesChatResponse,
"ImageChatResponse": ImageChatResponse,
"ClearMessageChatResponse": ClearMessageChatResponse,
"UsageChatResponse": UsageChatResponse,
"ChunkedContentChatResponse": ChunkedContentChatResponse,
"ChatResponse": ChatResponse,
}

return self._models_cache
Expand Down Expand Up @@ -163,6 +191,18 @@ def get_categories(self) -> dict[str, list[str]]:
"FeedbackResponse",
"ConfigResponse",
"LoginResponse",
"TextChatResponse",
"ReferenceChatResponse",
"MessageIdChatResponse",
"ConversationIdChatResponse",
"StateUpdateChatResponse",
"LiveUpdateChatResponse",
"FollowupMessagesChatResponse",
"ImageChatResponse",
"ClearMessageChatResponse",
"UsageChatResponse",
"ChunkedContentChatResponse",
"ChatResponse",
],
"requests": [
"ChatRequest",
Expand Down
Loading
Loading