44import json
55import logging
66from functools import partial
7- from typing import Any , Callable , Optional
7+ from typing import Any , Callable , Optional , Union
88
99from fastapi import FastAPI , status
1010from fastapi .responses import StreamingResponse
1111from opentelemetry .instrumentation .fastapi import FastAPIInstrumentor
12- from pydantic import BaseModel
12+ from pydantic import BaseModel , Field , create_model
1313from starlette .responses import RedirectResponse
14+ from unstructured_ingest .v2 .interfaces import FileData
1415from uvicorn .config import LOG_LEVELS
1516from uvicorn .importer import import_from_string
1617
2324 get_schema_dict ,
2425 map_inputs ,
2526)
26- from unstructured_platform_plugins .exceptions import UnrecoverableException
27+ from unstructured_platform_plugins .schema import FileDataMeta , NewRecord , UsageData
2728from unstructured_platform_plugins .schema .json_schema import (
2829 schema_to_base_model ,
2930)
30- from unstructured_platform_plugins .schema .usage import UsageData
3131
3232logger = logging .getLogger ("uvicorn.error" )
3333
@@ -68,6 +68,30 @@ def check_precheck_func(precheck_func: Callable):
6868 raise ValueError (f"no output should exist for precheck function, found: { outputs } " )
6969
7070
71+ def is_optional (t : Any ) -> bool :
72+ return (
73+ hasattr (t , "__origin__" )
74+ and t .__origin__ is Union
75+ and hasattr (t , "__args__" )
76+ and type (None ) in t .__args__
77+ )
78+
79+
80+ def update_filedata_model (new_type ) -> BaseModel :
81+ field_info = NewRecord .model_fields ["contents" ]
82+ if is_optional (new_type ):
83+ field_info .default = None
84+ new_record_model = create_model (
85+ NewRecord .__name__ , contents = (new_type , field_info ), __base__ = NewRecord
86+ )
87+ new_filedata_model = create_model (
88+ FileDataMeta .__name__ ,
89+ new_records = (list [new_record_model ], Field (default_factory = list )),
90+ __base__ = FileDataMeta ,
91+ )
92+ return new_filedata_model
93+
94+
7195def wrap_in_fastapi (
7296 func : Callable ,
7397 plugin_id : str ,
@@ -81,14 +105,16 @@ def wrap_in_fastapi(
81105 fastapi_app = FastAPI ()
82106
83107 response_type = get_output_sig (func )
108+ filedata_meta_model = update_filedata_model (response_type )
84109
85110 class InvokeResponse (BaseModel ):
86111 usage : list [UsageData ]
87112 status_code : int
113+ filedata_meta : filedata_meta_model
88114 status_code_text : Optional [str ] = None
89115 output : Optional [response_type ] = None
90116
91- input_schema = get_input_schema (func , omit = ["usage" ])
117+ input_schema = get_input_schema (func , omit = ["usage" , "filedata_meta" ])
92118 input_schema_model = schema_to_base_model (input_schema )
93119
94120 logging .getLogger ("etl_uvicorn.fastapi" )
@@ -97,36 +123,43 @@ class InvokeResponse(BaseModel):
97123
98124 async def wrap_fn (func : Callable , kwargs : Optional [dict [str , Any ]] = None ) -> ResponseType :
99125 usage : list [UsageData ] = []
126+ filedata_meta = FileDataMeta ()
100127 request_dict = kwargs if kwargs else {}
101128 if "usage" in inspect .signature (func ).parameters :
102129 request_dict ["usage" ] = usage
103130 else :
104- logger .debug ("usage data not an expected parameter, omitting" )
131+ logger .warning ("usage data not an expected parameter, omitting" )
132+ if "filedata_meta" in inspect .signature (func ).parameters :
133+ request_dict ["filedata_meta" ] = filedata_meta
105134 try :
106135 if inspect .isasyncgenfunction (func ):
107136 # Stream response if function is an async generator
108137
109138 async def _stream_response ():
110139 async for output in func (** (request_dict or {})):
111140 yield InvokeResponse (
112- usage = usage , status_code = status .HTTP_200_OK , output = output
141+ usage = usage ,
142+ filedata_meta = filedata_meta_model .model_validate (
143+ filedata_meta .model_dump ()
144+ ),
145+ status_code = status .HTTP_200_OK ,
146+ output = output ,
113147 ).model_dump_json () + "\n "
114148
115149 return StreamingResponse (_stream_response (), media_type = "application/x-ndjson" )
116150 else :
117- try :
118- output = await invoke_func (func = func , kwargs = request_dict )
119- return InvokeResponse (
120- usage = usage , status_code = status .HTTP_200_OK , output = output
121- )
122- except UnrecoverableException as ex :
123- # Thrower of this exception is responsible for logging necessary information
124- logger .info ("Unrecoverable error occurred during plugin invocation" )
125- return InvokeResponse (usage = usage , status_code = 512 , status_code_text = ex .message )
151+ output = await invoke_func (func = func , kwargs = request_dict )
152+ return InvokeResponse (
153+ usage = usage ,
154+ filedata_meta = filedata_meta_model .model_validate (filedata_meta .model_dump ()),
155+ status_code = status .HTTP_200_OK ,
156+ output = output ,
157+ )
126158 except Exception as invoke_error :
127159 logger .error (f"failed to invoke plugin: { invoke_error } " , exc_info = True )
128160 return InvokeResponse (
129161 usage = usage ,
162+ filedata_meta = filedata_meta_model .model_validate (filedata_meta .model_dump ()),
130163 status_code = status .HTTP_500_INTERNAL_SERVER_ERROR ,
131164 status_code_text = f"failed to invoke plugin: "
132165 f"[{ invoke_error .__class__ .__name__ } ] { invoke_error } " ,
@@ -139,6 +172,11 @@ async def run_job(request: input_schema_model) -> ResponseType:
139172 log_func_and_body (func = func , body = request .json ())
140173 # Create dictionary from pydantic model while preserving underlying types
141174 request_dict = {f : getattr (request , f ) for f in request .model_fields }
175+ # Map FileData back to original dataclass if present
176+ if "file_data" in request_dict :
177+ request_dict ["file_data" ] = FileData .from_dict (
178+ request_dict ["file_data" ].model_dump ()
179+ )
142180 map_inputs (func = func , raw_inputs = request_dict )
143181 if logger .level == LOG_LEVELS .get ("trace" , logging .NOTSET ):
144182 logger .log (level = logger .level , msg = f"passing inputs to function: { request_dict } " )
0 commit comments