Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 21 additions & 14 deletions api/controllers/web/saved_message.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,30 @@
from flask_restx import reqparse
from flask_restx.inputs import int_range
from pydantic import TypeAdapter
from flask import request
from pydantic import BaseModel, Field, TypeAdapter
from werkzeug.exceptions import NotFound

from controllers.common.schema import register_schema_models
from controllers.web import web_ns
from controllers.web.error import NotCompletionAppError
from controllers.web.wraps import WebApiResource
from fields.conversation_fields import ResultResponse
from fields.message_fields import SavedMessageInfiniteScrollPagination, SavedMessageItem
from libs.helper import uuid_value
from libs.helper import UUIDStrOrEmpty
from services.errors.message import MessageNotExistsError
from services.saved_message_service import SavedMessageService


class SavedMessageListQuery(BaseModel):
last_id: UUIDStrOrEmpty | None = None
limit: int = Field(default=20, ge=1, le=100)


class SavedMessageCreatePayload(BaseModel):
message_id: UUIDStrOrEmpty


register_schema_models(web_ns, SavedMessageListQuery, SavedMessageCreatePayload)


@web_ns.route("/saved-messages")
class SavedMessageListApi(WebApiResource):
@web_ns.doc("Get Saved Messages")
Expand Down Expand Up @@ -42,14 +54,10 @@ def get(self, app_model, end_user):
if app_model.mode != "completion":
raise NotCompletionAppError()

parser = (
reqparse.RequestParser()
.add_argument("last_id", type=uuid_value, location="args")
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
)
args = parser.parse_args()
raw_args = request.args.to_dict()
query = SavedMessageListQuery.model_validate(raw_args)

pagination = SavedMessageService.pagination_by_last_id(app_model, end_user, args["last_id"], args["limit"])
pagination = SavedMessageService.pagination_by_last_id(app_model, end_user, query.last_id, query.limit)
adapter = TypeAdapter(SavedMessageItem)
items = [adapter.validate_python(message, from_attributes=True) for message in pagination.data]
return SavedMessageInfiniteScrollPagination(
Expand Down Expand Up @@ -79,11 +87,10 @@ def post(self, app_model, end_user):
if app_model.mode != "completion":
raise NotCompletionAppError()

parser = reqparse.RequestParser().add_argument("message_id", type=uuid_value, required=True, location="json")
args = parser.parse_args()
payload = SavedMessageCreatePayload.model_validate(web_ns.payload or {})

try:
SavedMessageService.save(app_model, end_user, args["message_id"])
SavedMessageService.save(app_model, end_user, payload.message_id)
except MessageNotExistsError:
raise NotFound("Message Not Exists.")

Expand Down
Loading