diff --git a/CHANGELOG.md b/CHANGELOG.md
index 888c75e..08beff7 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,3 +1,7 @@
+## 0.0.24
+
+* **Add support for passing messages back other than errors**
+
 ## 0.0.23
 
 * **Handle errors in streaming responses**
diff --git a/unstructured_platform_plugins/__version__.py b/unstructured_platform_plugins/__version__.py
index 7ec71da..d406886 100644
--- a/unstructured_platform_plugins/__version__.py
+++ b/unstructured_platform_plugins/__version__.py
@@ -1 +1 @@
-__version__ = "0.0.23"  # pragma: no cover
+__version__ = "0.0.24"  # pragma: no cover
diff --git a/unstructured_platform_plugins/etl_uvicorn/api_generator.py b/unstructured_platform_plugins/etl_uvicorn/api_generator.py
index 132146b..312b1da 100644
--- a/unstructured_platform_plugins/etl_uvicorn/api_generator.py
+++ b/unstructured_platform_plugins/etl_uvicorn/api_generator.py
@@ -39,6 +39,11 @@ class EtlApiException(Exception):
 logger = logging.getLogger("uvicorn.error")
 
 
+class MessageChannels(BaseModel):
+    infos: list[str] = Field(default_factory=list)
+    warnings: list[str] = Field(default_factory=list)
+
+
 def log_func_and_body(func: Callable, body: Optional[str] = None) -> None:
     msg = None
     if logger.level == LOG_LEVELS.get("debug", logging.NOTSET):
@@ -135,6 +140,7 @@ class InvokeResponse(BaseModel):
         filedata_meta: Optional[filedata_meta_model]
         status_code_text: Optional[str] = None
         output: Optional[response_type] = None
+        message_channels: MessageChannels = Field(default_factory=MessageChannels)
 
     input_schema = get_input_schema(func, omit=["usage", "filedata_meta"])
     input_schema_model = schema_to_base_model(input_schema)
@@ -146,11 +152,14 @@ class InvokeResponse(BaseModel):
     async def wrap_fn(func: Callable, kwargs: Optional[dict[str, Any]] = None) -> ResponseType:
         usage: list[UsageData] = []
         filedata_meta = FileDataMeta()
+        message_channels = MessageChannels()
         request_dict = kwargs if kwargs else {}
         if "usage" in inspect.signature(func).parameters:
             request_dict["usage"] = usage
         else:
             logger.warning("usage data not an expected parameter, omitting")
+        if "message_channels" in inspect.signature(func).parameters:
+            request_dict["message_channels"] = message_channels
         if "filedata_meta" in inspect.signature(func).parameters:
             request_dict["filedata_meta"] = filedata_meta
         try:
@@ -161,6 +170,7 @@ async def _stream_response():
                         async for output in func(**(request_dict or {})):
                             yield InvokeResponse(
                                 usage=usage,
+                                message_channels=message_channels,
                                 filedata_meta=filedata_meta_model.model_validate(
                                     filedata_meta.model_dump()
                                 ),
@@ -171,6 +181,7 @@ async def _stream_response():
                         logger.error(f"Failure streaming response: {e}", exc_info=True)
                         yield InvokeResponse(
                             usage=usage,
+                            message_channels=message_channels,
                             filedata_meta=None,
                             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                             status_code_text=f"[{e.__class__.__name__}] {e}",
@@ -181,6 +192,7 @@ async def _stream_response():
                 output = await invoke_func(func=func, kwargs=request_dict)
                 return InvokeResponse(
                     usage=usage,
+                    message_channels=message_channels,
                     filedata_meta=filedata_meta_model.model_validate(filedata_meta.model_dump()),
                     status_code=status.HTTP_200_OK,
                     output=output,
@@ -189,6 +201,7 @@ async def _stream_response():
             logger.info("Unrecoverable error occurred during plugin invocation")
             return InvokeResponse(
                 usage=usage,
+                message_channels=message_channels,
                 status_code=512,
                 status_code_text=ex.message,
                 filedata_meta=filedata_meta_model.model_validate(filedata_meta.model_dump()),
@@ -198,6 +211,7 @@ async def _stream_response():
             http_error = wrap_error(invoke_error)
             return InvokeResponse(
                 usage=usage,
+                message_channels=message_channels,
                 filedata_meta=filedata_meta_model.model_validate(filedata_meta.model_dump()),
                 status_code=http_error.status_code,
                 status_code_text=f"[{invoke_error.__class__.__name__}] {invoke_error}",
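
For context, a minimal sketch of how a plugin entrypoint might consume the new channel. Only `MessageChannels` and the injection of a `message_channels` argument come from this diff; the `process` function, its parameters, and the messages it records are hypothetical.

```python
# Hypothetical plugin entrypoint -- not part of this diff. If the function
# declares a `message_channels` parameter, wrap_fn injects a shared
# MessageChannels instance; anything appended to it is echoed back in
# InvokeResponse.message_channels alongside the normal output.
from typing import Any, Optional

from unstructured_platform_plugins.etl_uvicorn.api_generator import MessageChannels


async def process(
    records: list[dict[str, Any]],
    message_channels: Optional[MessageChannels] = None,  # injected by wrap_fn when declared
) -> list[dict[str, Any]]:
    kept = [r for r in records if r]
    if message_channels is not None:
        message_channels.infos.append(f"processed {len(kept)} records")
        if len(kept) != len(records):
            message_channels.warnings.append(
                f"dropped {len(records) - len(kept)} empty records"
            )
    return kept
```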