44import json
55import logging
66from functools import partial
7- from typing import Any , Callable , Optional
7+ from typing import Any , Callable , Optional , Union
88
99from fastapi import FastAPI , status
1010from fastapi .responses import StreamingResponse
1111from opentelemetry .instrumentation .fastapi import FastAPIInstrumentor
12- from pydantic import BaseModel
12+ from pydantic import BaseModel , Field , create_model
1313from starlette .responses import RedirectResponse
1414from uvicorn .config import LOG_LEVELS
1515from uvicorn .importer import import_from_string
2323 get_schema_dict ,
2424 map_inputs ,
2525)
26- from unstructured_platform_plugins .schema import FileDataMeta , UsageData
26+ from unstructured_platform_plugins .schema import FileDataMeta , NewRecord , UsageData
2727from unstructured_platform_plugins .schema .json_schema import (
2828 schema_to_base_model ,
2929)
@@ -67,6 +67,30 @@ def check_precheck_func(precheck_func: Callable):
6767 raise ValueError (f"no output should exist for precheck function, found: { outputs } " )
6868
6969
70+ def is_optional (t : Any ) -> bool :
71+ return (
72+ hasattr (t , "__origin__" )
73+ and t .__origin__ is Union
74+ and hasattr (t , "__args__" )
75+ and type (None ) in t .__args__
76+ )
77+
78+
79+ def update_filedata_model (new_type ) -> BaseModel :
80+ field_info = NewRecord .model_fields ["contents" ]
81+ if is_optional (new_type ):
82+ field_info .default = None
83+ new_record_model = create_model (
84+ NewRecord .__name__ , contents = (new_type , field_info ), __base__ = NewRecord
85+ )
86+ new_filedata_model = create_model (
87+ FileDataMeta .__name__ ,
88+ new_records = (list [new_record_model ], Field (default_factory = list )),
89+ __base__ = FileDataMeta ,
90+ )
91+ return new_filedata_model
92+
93+
7094def wrap_in_fastapi (
7195 func : Callable ,
7296 plugin_id : str ,
@@ -80,11 +104,12 @@ def wrap_in_fastapi(
80104 fastapi_app = FastAPI ()
81105
82106 response_type = get_output_sig (func )
107+ filedata_meta_model = update_filedata_model (response_type )
83108
84109 class InvokeResponse (BaseModel ):
85110 usage : list [UsageData ]
86111 status_code : int
87- filedata_meta : FileDataMeta
112+ filedata_meta : filedata_meta_model
88113 status_code_text : Optional [str ] = None
89114 output : Optional [response_type ] = None
90115
@@ -113,7 +138,9 @@ async def _stream_response():
113138 async for output in func (** (request_dict or {})):
114139 yield InvokeResponse (
115140 usage = usage ,
116- filedata_meta = filedata_meta ,
141+ filedata_meta = filedata_meta_model .model_validate (
142+ filedata_meta .model_dump ()
143+ ),
117144 status_code = status .HTTP_200_OK ,
118145 output = output ,
119146 ).model_dump_json () + "\n "
@@ -123,15 +150,15 @@ async def _stream_response():
123150 output = await invoke_func (func = func , kwargs = request_dict )
124151 return InvokeResponse (
125152 usage = usage ,
126- filedata_meta = filedata_meta ,
153+ filedata_meta = filedata_meta_model . model_validate ( filedata_meta . model_dump ()) ,
127154 status_code = status .HTTP_200_OK ,
128155 output = output ,
129156 )
130157 except Exception as invoke_error :
131158 logger .error (f"failed to invoke plugin: { invoke_error } " , exc_info = True )
132159 return InvokeResponse (
133160 usage = usage ,
134- filedata_meta = filedata_meta ,
161+ filedata_meta = filedata_meta_model . model_validate ( filedata_meta . model_dump ()) ,
135162 status_code = status .HTTP_500_INTERNAL_SERVER_ERROR ,
136163 status_code_text = f"failed to invoke plugin: "
137164 f"[{ invoke_error .__class__ .__name__ } ] { invoke_error } " ,
0 commit comments