1111from opentelemetry .instrumentation .fastapi import FastAPIInstrumentor
1212from pydantic import BaseModel , Field , create_model
1313from starlette .responses import RedirectResponse
14- from unstructured_ingest .data_types .file_data import file_data_from_dict
14+ from unstructured_ingest .data_types .file_data import BatchFileData , FileData , file_data_from_dict
1515from uvicorn .config import LOG_LEVELS
1616from uvicorn .importer import import_from_string
1717
3131 schema_to_base_model ,
3232)
3333
34+ FileDataType = Union [FileData , BatchFileData ]
35+
3436
3537class EtlApiException (Exception ):
3638 pass
@@ -137,7 +139,8 @@ def _wrap_in_fastapi(
137139 class InvokeResponse (BaseModel ):
138140 usage : list [UsageData ]
139141 status_code : int
140- filedata_meta : Optional [filedata_meta_model ]
142+ file_data : Optional [FileDataType ] = None
143+ filedata_meta : Optional [filedata_meta_model ] = None
141144 status_code_text : Optional [str ] = None
142145 output : Optional [response_type ] = None
143146 message_channels : MessageChannels = Field (default_factory = MessageChannels )
@@ -177,6 +180,7 @@ async def _stream_response():
177180 ),
178181 status_code = status .HTTP_200_OK ,
179182 output = output ,
183+ file_data = request_dict .get ("file_data" , None ),
180184 ).model_dump_json ()
181185 + "\n "
182186 )
@@ -202,6 +206,7 @@ async def _stream_response():
202206 filedata_meta = filedata_meta_model .model_validate (filedata_meta .model_dump ()),
203207 status_code = status .HTTP_200_OK ,
204208 output = output ,
209+ file_data = request_dict .get ("file_data" , None ),
205210 )
206211 except UnrecoverableException as ex :
207212 logger .info ("Unrecoverable error occurred during plugin invocation" )
@@ -211,6 +216,7 @@ async def _stream_response():
211216 status_code = 512 ,
212217 status_code_text = ex .message ,
213218 filedata_meta = filedata_meta_model .model_validate (filedata_meta .model_dump ()),
219+ file_data = request_dict .get ("file_data" , None ),
214220 )
215221 except Exception as invoke_error :
216222 logger .error (f"failed to invoke plugin: { invoke_error } " , exc_info = True )
@@ -221,6 +227,7 @@ async def _stream_response():
221227 filedata_meta = filedata_meta_model .model_validate (filedata_meta .model_dump ()),
222228 status_code = http_error .status_code ,
223229 status_code_text = f"[{ invoke_error .__class__ .__name__ } ] { invoke_error } " ,
230+ file_data = request_dict .get ("file_data" , None ),
224231 )
225232
226233 if input_schema_model .model_fields :
0 commit comments