44from ellar .common .interfaces import IExecutionContext
55from ellar .common .logging import request_logger
66from ellar .common .serializer import Serializer , SerializerFilter
7+ from ellar .pydantic import field_validator
78
89from ..response_types import FileResponse , Response , StreamingResponse
910from .base import ResponseModel , ResponseModelField
1011
1112
12- class StreamingResponseModelInvalidContent (RuntimeError ):
13- pass
14-
15-
1613class ContentDispositionType (str , Enum ):
1714 inline = "inline"
1815 attachment = "attachment"
@@ -25,6 +22,19 @@ class FileResponseModelSchema(Serializer):
2522 method : t .Optional [str ] = None
2623 content_disposition_type : ContentDispositionType = ContentDispositionType .attachment
2724
25+ class StreamResponseModelSchema (Serializer ):
26+ media_type : t .Optional [str ] = None
27+ content : t .Any
28+
29+ @field_validator ('content' , mode = 'before' )
30+ def pre_validate_content (cls , value : t .Dict ) -> t .Any :
31+ if not isinstance (value , (t .AsyncGenerator , t .Generator )):
32+ raise ValueError (
33+ "Content must typing.AsyncIterable OR typing.Iterable"
34+ )
35+ return value
36+
37+
2838
2939class FileResponseModel (ResponseModel ):
3040 __slots__ = ("_file_init_schema" ,)
@@ -79,11 +89,23 @@ def serialize(
7989
8090class StreamingResponseModel (ResponseModel ):
8191 response_type = StreamingResponse
92+ file_schema_type = StreamResponseModelSchema
8293
8394 def get_model_field (self ) -> t .Optional [t .Union [ResponseModelField , t .Any ]]:
8495 # We don't want any schema for this.
8596 return None
8697
98+ def serialize (
99+ self ,
100+ response_obj : t .Any ,
101+ serializer_filter : t .Optional [SerializerFilter ] = None ,
102+ ) -> t .Union [t .List [t .Dict ], t .Dict , t .Any , StreamResponseModelSchema ]:
103+ if isinstance (response_obj , (t .AsyncGenerator , t .Generator )):
104+ response_obj = {"content" : response_obj , "media_type" : self .media_type }
105+
106+ value = self .file_schema_type .from_orm (response_obj )
107+ return value
108+
87109 def create_response (
88110 self , context : IExecutionContext , response_obj : t .Any , status_code : int
89111 ) -> Response :
@@ -94,12 +116,11 @@ def create_response(
94116 response_args , headers = self .get_context_response (
95117 context = context , status_code = status_code
96118 )
97- if not isinstance (response_obj , (t .AsyncGenerator , t .Generator )):
98- raise StreamingResponseModelInvalidContent (
99- "Content must typing.AsyncIterable OR typing.Iterable"
100- )
119+ data = self .serialize (response_obj )
101120
102121 response = self ._response_type (
103- ** response_args , headers = headers , content = response_obj
122+ ** response_args ,
123+ headers = headers , content = data .content ,
124+ media_type = data .media_type or self .media_type
104125 )
105126 return response
0 commit comments