diff --git a/CHANGELOG.md b/CHANGELOG.md index 952e410..a27e42f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +## 0.0.29 + +* **Support persisting file data changes** + ## 0.0.28 * **Isolate what gets bundled in package** diff --git a/test/api/test_api.py b/test/api/test_api.py index cf6fdf5..8bec1df 100644 --- a/test/api/test_api.py +++ b/test/api/test_api.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Any, Optional +from typing import Any, Optional, Union import pytest from fastapi.testclient import TestClient @@ -25,6 +25,7 @@ class InvokeResponse(BaseModel): filedata_meta: FileDataMeta status_code_text: Optional[str] = None output: Optional[Any] = None + file_data: Optional[Union[FileData, BatchFileData]] = None def generic_validation(self): assert self.status_code == 200 @@ -121,6 +122,9 @@ def test_filedata_meta(file_data): filedata_meta = invoke_response.filedata_meta assert len(filedata_meta.new_records) == 15 assert filedata_meta.terminate_current + file_data = invoke_response.file_data + assert file_data + assert file_data.metadata.record_locator.get("key") == "value" assert not invoke_response.output diff --git a/test/assets/filedata_meta.py b/test/assets/filedata_meta.py index bfaa613..b3f4694 100644 --- a/test/assets/filedata_meta.py +++ b/test/assets/filedata_meta.py @@ -18,6 +18,7 @@ class Output(BaseModel): def process_input( i: Input, file_data: Union[FileData, BatchFileData], filedata_meta: FileDataMeta ) -> Optional[Output]: + file_data.metadata.record_locator = {"key": "value"} if i.m > 10: filedata_meta.terminate_current = True new_content = [ diff --git a/unstructured_platform_plugins/__version__.py b/unstructured_platform_plugins/__version__.py index 3c17069..54489aa 100644 --- a/unstructured_platform_plugins/__version__.py +++ b/unstructured_platform_plugins/__version__.py @@ -1 +1 @@ -__version__ = "0.0.28" # pragma: no cover +__version__ = "0.0.29" # pragma: no cover diff --git a/unstructured_platform_plugins/etl_uvicorn/api_generator.py b/unstructured_platform_plugins/etl_uvicorn/api_generator.py index aabc084..4d22b8b 100644 --- a/unstructured_platform_plugins/etl_uvicorn/api_generator.py +++ b/unstructured_platform_plugins/etl_uvicorn/api_generator.py @@ -11,7 +11,7 @@ from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor from pydantic import BaseModel, Field, create_model from starlette.responses import RedirectResponse -from unstructured_ingest.data_types.file_data import file_data_from_dict +from unstructured_ingest.data_types.file_data import BatchFileData, FileData, file_data_from_dict from uvicorn.config import LOG_LEVELS from uvicorn.importer import import_from_string @@ -31,6 +31,8 @@ schema_to_base_model, ) +FileDataType = Union[FileData, BatchFileData] + class EtlApiException(Exception): pass @@ -137,7 +139,8 @@ def _wrap_in_fastapi( class InvokeResponse(BaseModel): usage: list[UsageData] status_code: int - filedata_meta: Optional[filedata_meta_model] + file_data: Optional[FileDataType] = None + filedata_meta: Optional[filedata_meta_model] = None status_code_text: Optional[str] = None output: Optional[response_type] = None message_channels: MessageChannels = Field(default_factory=MessageChannels) @@ -177,6 +180,7 @@ async def _stream_response(): ), status_code=status.HTTP_200_OK, output=output, + file_data=request_dict.get("file_data", None), ).model_dump_json() + "\n" ) @@ -202,6 +206,7 @@ async def _stream_response(): filedata_meta=filedata_meta_model.model_validate(filedata_meta.model_dump()), status_code=status.HTTP_200_OK, output=output, + file_data=request_dict.get("file_data", None), ) except UnrecoverableException as ex: logger.info("Unrecoverable error occurred during plugin invocation") @@ -211,6 +216,7 @@ async def _stream_response(): status_code=512, status_code_text=ex.message, filedata_meta=filedata_meta_model.model_validate(filedata_meta.model_dump()), + file_data=request_dict.get("file_data", None), ) except Exception as invoke_error: logger.error(f"failed to invoke plugin: {invoke_error}", exc_info=True) @@ -221,6 +227,7 @@ async def _stream_response(): filedata_meta=filedata_meta_model.model_validate(filedata_meta.model_dump()), status_code=http_error.status_code, status_code_text=f"[{invoke_error.__class__.__name__}] {invoke_error}", + file_data=request_dict.get("file_data", None), ) if input_schema_model.model_fields: