44from ellar .common .interfaces import IExecutionContext
55from ellar .common .logging import request_logger
66from ellar .common .serializer import Serializer , SerializerFilter
7+ from ellar .pydantic import field_validator
78
89from ..response_types import FileResponse , Response , StreamingResponse
910from .base import ResponseModel , ResponseModelField
1011
1112
12- class StreamingResponseModelInvalidContent (RuntimeError ):
13- pass
14-
15-
1613class ContentDispositionType (str , Enum ):
1714 inline = "inline"
1815 attachment = "attachment"
@@ -26,6 +23,17 @@ class FileResponseModelSchema(Serializer):
2623 content_disposition_type : ContentDispositionType = ContentDispositionType .attachment
2724
2825
26+ class StreamResponseModelSchema (Serializer ):
27+ media_type : t .Optional [str ] = None
28+ content : t .Any
29+
30+ @field_validator ("content" , mode = "before" )
31+ def pre_validate_content (cls , value : t .Dict ) -> t .Any :
32+ if not isinstance (value , (t .AsyncGenerator , t .Generator )):
33+ raise ValueError ("Content must typing.AsyncIterable OR typing.Iterable" )
34+ return value
35+
36+
2937class FileResponseModel (ResponseModel ):
3038 __slots__ = ("_file_init_schema" ,)
3139
@@ -79,11 +87,23 @@ def serialize(
7987
8088class StreamingResponseModel (ResponseModel ):
8189 response_type = StreamingResponse
90+ file_schema_type = StreamResponseModelSchema
8291
8392 def get_model_field (self ) -> t .Optional [t .Union [ResponseModelField , t .Any ]]:
8493 # We don't want any schema for this.
8594 return None
8695
96+ def serialize (
97+ self ,
98+ response_obj : t .Any ,
99+ serializer_filter : t .Optional [SerializerFilter ] = None ,
100+ ) -> t .Union [t .List [t .Dict ], t .Dict , t .Any , StreamResponseModelSchema ]:
101+ if isinstance (response_obj , (t .AsyncGenerator , t .Generator )):
102+ response_obj = {"content" : response_obj , "media_type" : self .media_type }
103+
104+ value = self .file_schema_type .from_orm (response_obj )
105+ return value
106+
87107 def create_response (
88108 self , context : IExecutionContext , response_obj : t .Any , status_code : int
89109 ) -> Response :
@@ -94,12 +114,12 @@ def create_response(
94114 response_args , headers = self .get_context_response (
95115 context = context , status_code = status_code
96116 )
97- if not isinstance (response_obj , (t .AsyncGenerator , t .Generator )):
98- raise StreamingResponseModelInvalidContent (
99- "Content must typing.AsyncIterable OR typing.Iterable"
100- )
117+ data = t .cast (StreamResponseModelSchema , self .serialize (response_obj ))
101118
102119 response = self ._response_type (
103- ** response_args , headers = headers , content = response_obj
120+ ** response_args ,
121+ headers = headers ,
122+ content = data .content ,
123+ media_type = data .media_type or self .media_type ,
104124 )
105125 return response
0 commit comments