diff --git a/packages/ragbits-chat/CHANGELOG.md b/packages/ragbits-chat/CHANGELOG.md index aebabd8c1..aa587b3d9 100644 --- a/packages/ragbits-chat/CHANGELOG.md +++ b/packages/ragbits-chat/CHANGELOG.md @@ -2,6 +2,8 @@ ## Unreleased +- Move ChatResponse to union of types (#809) + ## 1.3.0 (2025-09-11) ### Changed diff --git a/packages/ragbits-chat/src/ragbits/chat/interface/types.py b/packages/ragbits-chat/src/ragbits/chat/interface/types.py index 920003651..118bb4b88 100644 --- a/packages/ragbits-chat/src/ragbits/chat/interface/types.py +++ b/packages/ragbits-chat/src/ragbits/chat/interface/types.py @@ -1,7 +1,7 @@ from enum import Enum -from typing import Any, cast +from typing import Annotated, Any, Literal, cast, get_args, get_origin, overload -from pydantic import BaseModel, ConfigDict, Field +from pydantic import BaseModel, ConfigDict, Field, RootModel from ragbits.chat.auth.types import User from ragbits.chat.interface.forms import UserSettings @@ -135,13 +135,217 @@ class ChatContext(BaseModel): model_config = ConfigDict(extra="allow") -class ChatResponse(BaseModel): - """Container for different types of chat responses.""" +_CHAT_RESPONSE_REGISTRY: dict[ChatResponseType, type[BaseModel]] = {} + + +class ChatResponseBase(BaseModel): + """Base class for all ChatResponse variants with auto-registration.""" type: ChatResponseType - content: ( - str | Reference | StateUpdate | LiveUpdate | list[str] | Image | dict[str, MessageUsage] | ChunkedContent | None - ) + + def __init_subclass__(cls, **kwargs: Any): + super().__init_subclass__(**kwargs) + type_ann = cls.model_fields["type"].annotation + origin = get_origin(type_ann) + value = get_args(type_ann)[0] if origin is Literal else getattr(cls, "type", None) + + if value is None: + raise ValueError(f"Cannot determine ChatResponseType for {cls.__name__}") + + _CHAT_RESPONSE_REGISTRY[value] = cls + + +class TextChatResponse(ChatResponseBase): + """Represents text chat response""" + + type: Literal[ChatResponseType.TEXT] = ChatResponseType.TEXT + content: str + + +class ReferenceChatResponse(ChatResponseBase): + """Represents reference chat response""" + + type: Literal[ChatResponseType.REFERENCE] = ChatResponseType.REFERENCE + content: Reference + + +class StateUpdateChatResponse(ChatResponseBase): + """Represents state update chat response""" + + type: Literal[ChatResponseType.STATE_UPDATE] = ChatResponseType.STATE_UPDATE + content: StateUpdate + + +class ConversationIdChatResponse(ChatResponseBase): + """Represents conversation_id chat response""" + + type: Literal[ChatResponseType.CONVERSATION_ID] = ChatResponseType.CONVERSATION_ID + content: str + + +class LiveUpdateChatResponse(ChatResponseBase): + """Represents live update chat response""" + + type: Literal[ChatResponseType.LIVE_UPDATE] = ChatResponseType.LIVE_UPDATE + content: LiveUpdate + + +class FollowupMessagesChatResponse(ChatResponseBase): + """Represents followup messages chat response""" + + type: Literal[ChatResponseType.FOLLOWUP_MESSAGES] = ChatResponseType.FOLLOWUP_MESSAGES + content: list[str] + + +class ImageChatResponse(ChatResponseBase): + """Represents image chat response""" + + type: Literal[ChatResponseType.IMAGE] = ChatResponseType.IMAGE + content: Image + + +class ClearMessageChatResponse(ChatResponseBase): + """Represents clear message event""" + + type: Literal[ChatResponseType.CLEAR_MESSAGE] = ChatResponseType.CLEAR_MESSAGE + content: None = None + + +class UsageChatResponse(ChatResponseBase): + """Represents usage chat response""" + + type: Literal[ChatResponseType.USAGE] = ChatResponseType.USAGE + content: dict[str, MessageUsage] + + +class MessageIdChatResponse(ChatResponseBase): + """Represents message_id chat response""" + + type: Literal[ChatResponseType.MESSAGE_ID] = ChatResponseType.MESSAGE_ID + content: str + + +class ChunkedContentChatResponse(ChatResponseBase): + """Represents chunked_content event that contains chunked event of different type""" + + type: Literal[ChatResponseType.CHUNKED_CONTENT] = ChatResponseType.CHUNKED_CONTENT + content: ChunkedContent + + +ChatResponseUnion = Annotated[ + TextChatResponse + | ReferenceChatResponse + | StateUpdateChatResponse + | ConversationIdChatResponse + | LiveUpdateChatResponse + | FollowupMessagesChatResponse + | ImageChatResponse + | ClearMessageChatResponse + | UsageChatResponse + | MessageIdChatResponse + | ChunkedContentChatResponse, + Field(discriminator="type"), +] + + +class ChatResponse(RootModel[ChatResponseUnion]): + """Container for different types of chat responses.""" + + root: ChatResponseUnion + + @property + def content(self) -> object: + """Returns content of a response, use dedicated `as_*` methods to get type hints.""" + return self.root.content + + @property + def type(self) -> ChatResponseType: + """Returns type of the ChatResponse""" + return self.root.type + + @overload + def __init__( + self, + type: Literal[ChatResponseType.TEXT], + content: str, + ) -> None: ... + @overload + def __init__( + self, + type: Literal[ChatResponseType.REFERENCE], + content: Reference, + ) -> None: ... + @overload + def __init__( + self, + type: Literal[ChatResponseType.STATE_UPDATE], + content: StateUpdate, + ) -> None: ... + @overload + def __init__( + self, + type: Literal[ChatResponseType.CONVERSATION_ID], + content: str, + ) -> None: ... + @overload + def __init__( + self, + type: Literal[ChatResponseType.LIVE_UPDATE], + content: LiveUpdate, + ) -> None: ... + @overload + def __init__( + self, + type: Literal[ChatResponseType.FOLLOWUP_MESSAGES], + content: list[str], + ) -> None: ... + @overload + def __init__( + self, + type: Literal[ChatResponseType.IMAGE], + content: Image, + ) -> None: ... + @overload + def __init__( + self, + type: Literal[ChatResponseType.CLEAR_MESSAGE], + content: None, + ) -> None: ... + @overload + def __init__( + self, + type: Literal[ChatResponseType.USAGE], + content: dict[str, MessageUsage], + ) -> None: ... + @overload + def __init__( + self, + type: Literal[ChatResponseType.MESSAGE_ID], + content: str, + ) -> None: ... + @overload + def __init__( + self, + type: Literal[ChatResponseType.CHUNKED_CONTENT], + content: ChunkedContent, + ) -> None: ... + def __init__( + self, + type: ChatResponseType, + content: Any, + ) -> None: + """ + Backward-compatible constructor. + + Allows creating a ChatResponse directly with: + ChatResponse(type=ChatResponseType.TEXT, content="hello") + """ + model_cls = _CHAT_RESPONSE_REGISTRY.get(type) + if model_cls is None: + raise ValueError(f"Unsupported ChatResponseType: {type}") + + model_instance = model_cls(type=type, content=content) + super().__init__(root=cast(ChatResponseUnion, model_instance)) def as_text(self) -> str | None: """ @@ -151,7 +355,7 @@ def as_text(self) -> str | None: if text := response.as_text(): print(f"Got text: {text}") """ - return str(self.content) if self.type == ChatResponseType.TEXT else None + return self.root.content if isinstance(self.root, TextChatResponse) else None def as_reference(self) -> Reference | None: """ @@ -161,7 +365,7 @@ def as_reference(self) -> Reference | None: if ref := response.as_reference(): print(f"Got reference: {ref.title}") """ - return cast(Reference, self.content) if self.type == ChatResponseType.REFERENCE else None + return self.root.content if isinstance(self.root, ReferenceChatResponse) else None def as_state_update(self) -> StateUpdate | None: """ @@ -171,13 +375,13 @@ def as_state_update(self) -> StateUpdate | None: if state_update := response.as_state_update(): state = verify_state(state_update) """ - return cast(StateUpdate, self.content) if self.type == ChatResponseType.STATE_UPDATE else None + return self.root.content if isinstance(self.root, StateUpdateChatResponse) else None def as_conversation_id(self) -> str | None: """ Return the content as ConversationID if this is a conversation id, else None. """ - return cast(str, self.content) if self.type == ChatResponseType.CONVERSATION_ID else None + return self.root.content if isinstance(self.root, ConversationIdChatResponse) else None def as_live_update(self) -> LiveUpdate | None: """ @@ -187,7 +391,7 @@ def as_live_update(self) -> LiveUpdate | None: if live_update := response.as_live_update(): print(f"Got live update: {live_update.content.label}") """ - return cast(LiveUpdate, self.content) if self.type == ChatResponseType.LIVE_UPDATE else None + return self.root.content if isinstance(self.root, LiveUpdateChatResponse) else None def as_followup_messages(self) -> list[str] | None: """ @@ -197,25 +401,25 @@ def as_followup_messages(self) -> list[str] | None: if followup_messages := response.as_followup_messages(): print(f"Got followup messages: {followup_messages}") """ - return cast(list[str], self.content) if self.type == ChatResponseType.FOLLOWUP_MESSAGES else None + return self.root.content if isinstance(self.root, FollowupMessagesChatResponse) else None def as_image(self) -> Image | None: """ Return the content as Image if this is an image response, else None. """ - return cast(Image, self.content) if self.type == ChatResponseType.IMAGE else None + return self.root.content if isinstance(self.root, ImageChatResponse) else None def as_clear_message(self) -> None: """ Return the content of clear_message response, which is None """ - return cast(None, self.content) + return self.root.content if isinstance(self.root, ClearMessageChatResponse) else None def as_usage(self) -> dict[str, MessageUsage] | None: """ Return the content as dict from model name to Usage if this is an usage response, else None """ - return cast(dict[str, MessageUsage], self.content) if self.type == ChatResponseType.USAGE else None + return self.root.content if isinstance(self.root, UsageChatResponse) else None class ChatMessageRequest(BaseModel): diff --git a/packages/ragbits-chat/src/ragbits/chat/providers/model_provider.py b/packages/ragbits-chat/src/ragbits/chat/providers/model_provider.py index 6b1191c74..fac659818 100644 --- a/packages/ragbits-chat/src/ragbits/chat/providers/model_provider.py +++ b/packages/ragbits-chat/src/ragbits/chat/providers/model_provider.py @@ -10,7 +10,21 @@ from pydantic import BaseModel -from ragbits.chat.interface.types import AuthType +from ragbits.chat.interface.types import ( + AuthType, + ChatResponse, + ChunkedContentChatResponse, + ClearMessageChatResponse, + ConversationIdChatResponse, + FollowupMessagesChatResponse, + ImageChatResponse, + LiveUpdateChatResponse, + MessageIdChatResponse, + ReferenceChatResponse, + StateUpdateChatResponse, + TextChatResponse, + UsageChatResponse, +) class RagbitsChatModelProvider: @@ -93,6 +107,7 @@ def get_models(self) -> dict[str, type[BaseModel | Enum]]: "FeedbackItem": FeedbackItem, "Image": Image, "MessageUsage": MessageUsage, + "StateUpdate": StateUpdate, # Configuration models "HeaderCustomization": HeaderCustomization, "UICustomization": UICustomization, @@ -114,6 +129,19 @@ def get_models(self) -> dict[str, type[BaseModel | Enum]]: "LoginResponse": LoginResponse, "LogoutRequest": LogoutRequest, "User": User, + # Chat responses: + "TextChatResponse": TextChatResponse, + "ReferenceChatResponse": ReferenceChatResponse, + "MessageIdChatResponse": MessageIdChatResponse, + "ConversationIdChatResponse": ConversationIdChatResponse, + "StateUpdateChatResponse": StateUpdateChatResponse, + "LiveUpdateChatResponse": LiveUpdateChatResponse, + "FollowupMessagesChatResponse": FollowupMessagesChatResponse, + "ImageChatResponse": ImageChatResponse, + "ClearMessageChatResponse": ClearMessageChatResponse, + "UsageChatResponse": UsageChatResponse, + "ChunkedContentChatResponse": ChunkedContentChatResponse, + "ChatResponse": ChatResponse, } return self._models_cache @@ -163,6 +191,18 @@ def get_categories(self) -> dict[str, list[str]]: "FeedbackResponse", "ConfigResponse", "LoginResponse", + "TextChatResponse", + "ReferenceChatResponse", + "MessageIdChatResponse", + "ConversationIdChatResponse", + "StateUpdateChatResponse", + "LiveUpdateChatResponse", + "FollowupMessagesChatResponse", + "ImageChatResponse", + "ClearMessageChatResponse", + "UsageChatResponse", + "ChunkedContentChatResponse", + "ChatResponse", ], "requests": [ "ChatRequest", diff --git a/scripts/generate_typescript_from_json_schema.py b/scripts/generate_typescript_from_json_schema.py index 903e85e12..1174ad26a 100644 --- a/scripts/generate_typescript_from_json_schema.py +++ b/scripts/generate_typescript_from_json_schema.py @@ -181,50 +181,6 @@ def _generate_typescript_with_node(schema: dict[str, any], type_name: str) -> st raise RuntimeError("json-schema-to-typescript is required but not installed") from None -def _generate_chat_response_union_type() -> str: - """Generate ChatResponse union type and specific response interfaces.""" - lines = [] - - lines.append("/**") - lines.append(" * Specific chat response types") - lines.append(" */") - - # Generate specific response interfaces - response_interfaces = [ - ("TextChatResponse", "text", "string"), - ("ReferenceChatResponse", "reference", "Reference"), - ("MessageIdChatResponse", "message_id", "string"), - ("ConversationIdChatResponse", "conversation_id", "string"), - ("StateUpdateChatResponse", "state_update", "ServerState"), - ("LiveUpdateChatResponse", "live_update", "LiveUpdate"), - ("FollowupMessagesChatResponse", "followup_messages", "string[]"), - ("ImageChatResponse", "image", "Image"), - ("ClearMessageResponse", "clear_message", "never"), - ("MessageUsageChatResponse", "usage", "Record"), - ] - - internal_response_interfaces = [ - ("ChunkedChatResponse", "chunked_content", "ChunkedContent"), - ] - - for interface_name, response_type, content_type in [*response_interfaces, *internal_response_interfaces]: - lines.append(f"export interface {interface_name} {{") - lines.append(f" type: '{response_type}'") - lines.append(f" content: {content_type}") - lines.append("}") - lines.append("") - - lines.append("/**") - lines.append(" * Typed chat response union") - lines.append(" */") - lines.append("export type ChatResponse =") - - for interface_name, _, _ in response_interfaces: - lines.append(f" | {interface_name}") - - return "\n".join(lines) - - def _generate_ts_enum_object(enum_name: str, enum_values: list[str]) -> str: lines = [] lines.append("/**") @@ -306,10 +262,6 @@ def main() -> None: lines.append(_generate_typescript_with_node(schemas[name], name)) lines.append("") - # Generate ChatResponse union type - lines.append(_generate_chat_response_union_type()) - lines.append("") - # Write to output file output_file = Path("typescript/@ragbits/api-client/src/autogen.types.ts") output_file.parent.mkdir(parents=True, exist_ok=True) diff --git a/typescript/@ragbits/api-client/package.json b/typescript/@ragbits/api-client/package.json index 5d5a27567..f10063cf4 100644 --- a/typescript/@ragbits/api-client/package.json +++ b/typescript/@ragbits/api-client/package.json @@ -5,7 +5,7 @@ "repository": { "type": "git", "url": "https://github.com/deepsense-ai/ragbits" - }, + }, "main": "dist/index.cjs", "module": "dist/index.js", "types": "dist/index.d.ts", diff --git a/typescript/@ragbits/api-client/src/autogen.types.ts b/typescript/@ragbits/api-client/src/autogen.types.ts index 5ac4f55af..915ff577d 100644 --- a/typescript/@ragbits/api-client/src/autogen.types.ts +++ b/typescript/@ragbits/api-client/src/autogen.types.ts @@ -170,6 +170,16 @@ export interface MessageUsage { total_tokens: number } +/** + * Represents an update to conversation state. + */ +export interface StateUpdate { + state: { + [k: string]: unknown + } + signature: string +} + /** * Customization for the header section of the UI. */ @@ -413,74 +423,107 @@ export interface User { } /** - * Specific chat response types + * Represents text chat response */ export interface TextChatResponse { type: 'text' content: string } +/** + * Represents reference chat response + */ export interface ReferenceChatResponse { type: 'reference' content: Reference } +/** + * Represents message_id chat response + */ export interface MessageIdChatResponse { type: 'message_id' content: string } +/** + * Represents conversation_id chat response + */ export interface ConversationIdChatResponse { type: 'conversation_id' content: string } +/** + * Represents state update chat response + */ export interface StateUpdateChatResponse { type: 'state_update' - content: ServerState + content: StateUpdate } +/** + * Represents live update chat response + */ export interface LiveUpdateChatResponse { type: 'live_update' content: LiveUpdate } +/** + * Represents followup messages chat response + */ export interface FollowupMessagesChatResponse { type: 'followup_messages' content: string[] } +/** + * Represents image chat response + */ export interface ImageChatResponse { type: 'image' content: Image } -export interface ClearMessageResponse { +/** + * Represents clear message event + */ +export interface ClearMessageChatResponse { type: 'clear_message' - content: never + content: null } -export interface MessageUsageChatResponse { +/** + * Represents usage chat response + */ +export interface UsageChatResponse { type: 'usage' - content: Record + content: { + [k: string]: MessageUsage + } } -export interface ChunkedChatResponse { +/** + * Represents chunked_content event that contains chunked event of different type + */ +export interface ChunkedContentChatResponse { type: 'chunked_content' content: ChunkedContent } /** - * Typed chat response union + * Container for different types of chat responses. */ export type ChatResponse = | TextChatResponse | ReferenceChatResponse - | MessageIdChatResponse - | ConversationIdChatResponse | StateUpdateChatResponse + | ConversationIdChatResponse | LiveUpdateChatResponse | FollowupMessagesChatResponse | ImageChatResponse - | ClearMessageResponse - | MessageUsageChatResponse + | ClearMessageChatResponse + | UsageChatResponse + | MessageIdChatResponse + | ChunkedContentChatResponse diff --git a/typescript/@ragbits/api-client/src/index.ts b/typescript/@ragbits/api-client/src/index.ts index 57851fa8e..98dff3eb8 100644 --- a/typescript/@ragbits/api-client/src/index.ts +++ b/typescript/@ragbits/api-client/src/index.ts @@ -1,4 +1,8 @@ -import { ChatResponseType, ChunkedChatResponse, Image } from './autogen.types' +import { + ChatResponseType, + ChunkedContentChatResponse, + Image, +} from './autogen.types' import type { ClientConfig, StreamCallbacks, @@ -8,6 +12,7 @@ import type { RequestOptions, BaseStreamingEndpoints, EndpointRequest, + ChatResponse, } from './types' /** @@ -297,7 +302,7 @@ export class RagbitsClient { data: T, callbacks: StreamCallbacks ): Promise { - const response = data as ChunkedChatResponse + const response = data as ChunkedContentChatResponse const content = response.content const { @@ -364,3 +369,5 @@ export class RagbitsClient { // Re-export types export * from './types' export * from './autogen.types' +// Re-export the redefined ChatResponse +export type { ChatResponse } diff --git a/typescript/@ragbits/api-client/src/types.ts b/typescript/@ragbits/api-client/src/types.ts index 9e4da36f3..7da85435a 100644 --- a/typescript/@ragbits/api-client/src/types.ts +++ b/typescript/@ragbits/api-client/src/types.ts @@ -3,12 +3,16 @@ import { FeedbackRequest, FeedbackResponse, ChatRequest, - ChatResponse, + ChatResponse as _ChatResponse, LogoutRequest, LoginRequest, LoginResponse, + ChunkedContentChatResponse, } from './autogen.types' +// Redefine ChatResponse to exclude "internal" events handled by the library +export type ChatResponse = Exclude<_ChatResponse, ChunkedContentChatResponse> + export interface GenericResponse { success: boolean } diff --git a/typescript/ui/src/core/stores/HistoryStore/eventHandlers/messageHandlers.ts b/typescript/ui/src/core/stores/HistoryStore/eventHandlers/messageHandlers.ts index 2811625a3..85ac18334 100644 --- a/typescript/ui/src/core/stores/HistoryStore/eventHandlers/messageHandlers.ts +++ b/typescript/ui/src/core/stores/HistoryStore/eventHandlers/messageHandlers.ts @@ -1,12 +1,12 @@ import { - ClearMessageResponse, + ClearMessageChatResponse, ImageChatResponse, LiveUpdateChatResponse, LiveUpdateType, MessageIdChatResponse, - MessageUsageChatResponse, ReferenceChatResponse, TextChatResponse, + UsageChatResponse, } from "@ragbits/api-client-react"; import { PrimaryHandler } from "./eventHandlerRegistry"; import { produce } from "immer"; @@ -76,7 +76,7 @@ export const handleImage: PrimaryHandler = ( }); }; -export const handleClearMessage: PrimaryHandler = ( +export const handleClearMessage: PrimaryHandler = ( _, draft, ctx, @@ -89,7 +89,7 @@ export const handleClearMessage: PrimaryHandler = ( }; }; -export const handleUsage: PrimaryHandler = ( +export const handleUsage: PrimaryHandler = ( response, draft, ctx,