Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/backend/chat/custom/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from backend.chat.base import BaseChat
from backend.chat.custom.tool_calls import async_call_tools
from backend.chat.custom.utils import get_deployment
from backend.chat.enums import StreamEvent
from backend.chat.enums import FinishReason, StreamEvent
from backend.config import Settings
from backend.config.tools import get_available_tools
from backend.database_models.file import File
Expand Down Expand Up @@ -87,7 +87,7 @@ async def chat(
)
yield {
"event_type": StreamEvent.STREAM_END,
"finish_reason": "ERROR",
"finish_reason": FinishReason.ERROR,
"error": str(e),
"status_code": 500,
}
Expand Down
9 changes: 9 additions & 0 deletions src/backend/chat/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,12 @@ class StreamEvent(StrEnum):
NON_STREAMED_CHAT_RESPONSE = "non-streamed-chat-response"
TOOL_CALLS_GENERATION = "tool-calls-generation"
TOOL_CALLS_CHUNK = "tool-calls-chunk"


class FinishReason(StrEnum):
    """
    Reasons why the model finished the request.

    As a StrEnum, each member compares equal to (and serializes as) its raw
    string value, so existing payloads that carried plain strings such as
    "COMPLETE" remain compatible.
    """
    ERROR = "ERROR"  # request terminated due to an error (surfaced with an error payload)
    COMPLETE = "COMPLETE"  # model finished generating normally
    MAX_TOKENS = "MAX_TOKENS"  # generation stopped after hitting the token limit
6 changes: 3 additions & 3 deletions src/backend/schemas/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from pydantic import BaseModel, Field

from backend.chat.enums import StreamEvent
from backend.chat.enums import FinishReason, StreamEvent
from backend.schemas.citation import Citation
from backend.schemas.document import Document
from backend.schemas.search_query import SearchQuery
Expand Down Expand Up @@ -288,7 +288,7 @@ class StreamEnd(ChatResponse):
title="Tool Calls",
description="List of tool calls generated for custom tools",
)
finish_reason: Optional[str] = Field(
finish_reason: Optional[FinishReason] = Field(
None,
title="Finish Reason",
description="Reason why the model finished the request",
Expand Down Expand Up @@ -322,7 +322,7 @@ class NonStreamedChatResponse(ChatResponse):
title="Chat History",
description="A list of previous messages between the user and the model, meant to give the model conversational context for responding to the user's message.",
)
finish_reason: str = Field(
finish_reason: FinishReason = Field(
...,
title="Finish Reason",
description="Reason the chat stream ended",
Expand Down
3 changes: 2 additions & 1 deletion src/backend/tests/integration/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from sqlalchemy import create_engine
from sqlalchemy.orm import Session

from backend.chat.enums import FinishReason
from backend.database_models import get_session
from backend.database_models.agent import Agent
from backend.database_models.deployment import Deployment
Expand Down Expand Up @@ -204,7 +205,7 @@ def mock_event_stream(inject_events: list[dict]) -> list[dict]:
"search_results": [],
"search_queries": [],
},
"finish_reason": "COMPLETE",
"finish_reason": FinishReason.COMPLETE,
}
])
return events
Expand Down
3 changes: 2 additions & 1 deletion src/backend/tests/unit/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from sqlalchemy.orm import Session
from sqlalchemy.sql import text

from backend.chat.enums import FinishReason
from backend.database_models import get_session
from backend.database_models.base import CustomFilterQuery
from backend.database_models.deployment import Deployment
Expand Down Expand Up @@ -252,7 +253,7 @@ def mock_event_stream(inject_events: list[dict]) -> list[dict]:
"search_results": [],
"search_queries": [],
},
"finish_reason": "COMPLETE",
"finish_reason": FinishReason.COMPLETE,
}
])
return events
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from cohere.types import StreamedChatResponse

from backend.chat.enums import FinishReason
from backend.schemas.cohere_chat import CohereChatRequest
from backend.schemas.context import Context
from backend.tests.unit.model_deployments.mock_deployments.mock_base import (
Expand Down Expand Up @@ -51,7 +52,7 @@ async def invoke_chat(
"is_search_required": None,
"search_queries": None,
"search_results": None,
"finish_reason": "MAX_TOKENS",
"finish_reason": FinishReason.MAX_TOKENS,
"tool_calls": None,
"chat_history": [
{"role": "USER", "message": "Hello"},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from cohere.types import StreamedChatResponse

from backend.chat.enums import FinishReason
from backend.schemas.cohere_chat import CohereChatRequest
from backend.schemas.context import Context
from backend.tests.unit.model_deployments.mock_deployments.mock_base import (
Expand Down Expand Up @@ -48,7 +49,7 @@ async def invoke_chat(
"is_search_required": None,
"search_queries": None,
"search_results": None,
"finish_reason": "MAX_TOKENS",
"finish_reason": FinishReason.MAX_TOKENS,
"tool_calls": None,
"chat_history": [
{"role": "USER", "message": "Hello"},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from cohere.types import StreamedChatResponse

from backend.chat.enums import FinishReason
from backend.schemas.cohere_chat import CohereChatRequest
from backend.schemas.context import Context
from backend.services.conversation import SEARCH_RELEVANCE_THRESHOLD
Expand Down Expand Up @@ -55,7 +56,7 @@ def invoke_chat(
"is_search_required": None,
"search_queries": None,
"search_results": None,
"finish_reason": "MAX_TOKENS",
"finish_reason": FinishReason.MAX_TOKENS,
"tool_calls": None,
"chat_history": [
{"role": "USER", "message": "Hello"},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from cohere.types import StreamedChatResponse

from backend.chat.enums import FinishReason
from backend.schemas.cohere_chat import CohereChatRequest
from backend.schemas.context import Context
from backend.tests.unit.model_deployments.mock_deployments.mock_base import (
Expand Down Expand Up @@ -48,7 +49,7 @@ async def invoke_chat(
"is_search_required": None,
"search_queries": None,
"search_results": None,
"finish_reason": "MAX_TOKENS",
"finish_reason": FinishReason.MAX_TOKENS,
"tool_calls": None,
"chat_history": [
{"role": "USER", "message": "Hello"},
Expand Down
4 changes: 2 additions & 2 deletions src/backend/tests/unit/routers/test_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from fastapi.testclient import TestClient
from sqlalchemy.orm import Session

from backend.chat.enums import StreamEvent
from backend.chat.enums import FinishReason, StreamEvent
from backend.database_models.conversation import Conversation
from backend.database_models.message import Message, MessageAgent
from backend.database_models.user import User
Expand Down Expand Up @@ -1110,7 +1110,7 @@ def validate_stream_end_event(
assert is_valid_uuid(data["response_id"])
assert is_valid_uuid(data["conversation_id"])
assert is_valid_uuid(data["generation_id"])
assert data["finish_reason"] == "COMPLETE" or data["finish_reason"] == "MAX_TOKENS"
assert data["finish_reason"] == FinishReason.COMPLETE or data["finish_reason"] == FinishReason.MAX_TOKENS

return data["conversation_id"]

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
// @todo: import from generated types when available
// @todo: Once backend FinishReason enum is merged, run `make generate-client-web`
// and remove this enum in favor of the generated type from backend
export enum FinishReason {
ERROR = 'ERROR',
COMPLETE = 'COMPLETE',
Expand Down