diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 1ac1fafb..ab2645d7 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -10,7 +10,7 @@ jobs: setup: strategy: matrix: - python-version: ["3.9","3.10","3.11", "3.12"] + python-version: ["3.10","3.11", "3.12"] runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 @@ -39,7 +39,7 @@ jobs: lint: strategy: matrix: - python-version: ["3.9","3.10","3.11", "3.12"] + python-version: ["3.10","3.11", "3.12"] runs-on: ubuntu-latest needs: setup steps: @@ -71,7 +71,7 @@ jobs: test: strategy: matrix: - python-version: ["3.9","3.10","3.11", "3.12"] + python-version: ["3.10","3.11", "3.12"] runs-on: ubuntu-latest needs: [setup, lint] steps: diff --git a/CHANGELOG.md b/CHANGELOG.md index 09ca336d..a7929abd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,11 @@ +## 1.0.0 + +* feat: support for Python 3.10+ + +## 0.8.11 + +* feat: remove `donut` model + ## 0.8.10 * feat: unpin `numpy` and bump minimum for `onnxruntime` to be compatible with `numpy>=2` diff --git a/requirements/base.txt b/requirements/base.txt index c65bbce6..e8bdaaeb 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -1,39 +1,39 @@ # -# This file is autogenerated by pip-compile with Python 3.9 +# This file is autogenerated by pip-compile with Python 3.12 # by the following command: # # pip-compile requirements/base.in # -certifi==2025.1.31 +certifi==2025.4.26 # via requests cffi==1.17.1 # via cryptography -charset-normalizer==3.4.1 +charset-normalizer==3.4.2 # via # pdfminer-six # requests coloredlogs==15.0.1 # via onnxruntime -contourpy==1.3.0 +contourpy==1.3.2 # via matplotlib -cryptography==44.0.2 +cryptography==44.0.3 # via pdfminer-six cycler==0.12.1 # via matplotlib -filelock==3.17.0 +filelock==3.18.0 # via # huggingface-hub # torch # transformers flatbuffers==25.2.10 # via onnxruntime -fonttools==4.56.0 +fonttools==4.58.0 # via matplotlib -fsspec==2025.3.0 +fsspec==2025.3.2 # via # huggingface-hub # torch -huggingface-hub==0.29.3 +huggingface-hub==0.31.2 # via # -r requirements/base.in # timm @@ -43,21 +43,19 @@ humanfriendly==10.0 # via coloredlogs idna==3.10 # via requests -importlib-resources==6.5.2 - # via matplotlib jinja2==3.1.6 # via torch -kiwisolver==1.4.7 +kiwisolver==1.4.8 # via matplotlib markupsafe==3.0.2 # via jinja2 -matplotlib==3.9.4 +matplotlib==3.10.3 # via -r requirements/base.in mpmath==1.3.0 # via sympy -networkx==3.2.1 +networkx==3.4.2 # via torch -numpy==2.0.2 +numpy==2.2.5 # via # -r requirements/base.in # contourpy @@ -69,13 +67,13 @@ numpy==2.0.2 # scipy # torchvision # transformers -onnx==1.17.0 +onnx==1.18.0 # via -r requirements/base.in -onnxruntime==1.19.2 +onnxruntime==1.22.0 # via -r requirements/base.in opencv-python==4.11.0.86 # via -r requirements/base.in -packaging==24.2 +packaging==25.0 # via # huggingface-hub # matplotlib @@ -83,19 +81,19 @@ packaging==24.2 # transformers pandas==2.2.3 # via -r requirements/base.in -pdfminer-six==20240706 +pdfminer-six==20250506 # via -r requirements/base.in -pillow==11.1.0 +pillow==11.2.1 # via # matplotlib # torchvision -protobuf==6.30.0 +protobuf==6.31.0 # via # onnx # onnxruntime pycparser==2.22 # via cffi -pyparsing==3.2.1 +pyparsing==3.2.3 # via matplotlib pypdfium2==4.30.1 # via -r requirements/base.in @@ -105,14 +103,14 @@ python-dateutil==2.9.0.post0 # pandas python-multipart==0.0.20 # via -r requirements/base.in -pytz==2025.1 +pytz==2025.2 # via pandas pyyaml==6.0.2 # via # huggingface-hub # timm # transformers -rapidfuzz==3.12.2 +rapidfuzz==3.13.0 # via -r requirements/base.in regex==2024.11.6 # via transformers @@ -124,11 +122,11 @@ safetensors==0.5.3 # via # timm # transformers -scipy==1.13.1 +scipy==1.15.3 # via -r requirements/base.in six==1.17.0 # via python-dateutil -sympy==1.13.1 +sympy==1.14.0 # via # onnxruntime # torch @@ -136,26 +134,28 @@ timm==1.0.15 # via -r requirements/base.in tokenizers==0.21.1 # via transformers -torch==2.6.0 +torch==2.7.0 # via # -r requirements/base.in # timm # torchvision -torchvision==0.21.0 +torchvision==0.22.0 # via timm tqdm==4.67.1 # via # huggingface-hub # transformers -transformers==4.49.0 +transformers==4.51.3 # via -r requirements/base.in -typing-extensions==4.12.2 +typing-extensions==4.13.2 # via # huggingface-hub + # onnx # torch -tzdata==2025.1 +tzdata==2025.2 # via pandas -urllib3==2.3.0 +urllib3==2.4.0 # via requests -zipp==3.21.0 - # via importlib-resources + +# The following packages are considered to be unsafe in a requirements file: +# setuptools diff --git a/requirements/dev.txt b/requirements/dev.txt index 26b8668d..7e7384d2 100644 --- a/requirements/dev.txt +++ b/requirements/dev.txt @@ -1,10 +1,10 @@ # -# This file is autogenerated by pip-compile with Python 3.9 +# This file is autogenerated by pip-compile with Python 3.12 # by the following command: # # pip-compile requirements/dev.in # -anyio==4.8.0 +anyio==4.9.0 # via # -c requirements/test.txt # httpx @@ -19,7 +19,7 @@ arrow==1.3.0 # via isoduration asttokens==3.0.0 # via stack-data -async-lru==2.0.4 +async-lru==2.0.5 # via jupyterlab attrs==25.3.0 # via @@ -27,13 +27,13 @@ attrs==25.3.0 # referencing babel==2.17.0 # via jupyterlab-server -beautifulsoup4==4.13.3 +beautifulsoup4==4.13.4 # via nbconvert bleach[css]==6.2.0 # via nbconvert build==1.2.2.post1 # via pip-tools -certifi==2025.1.31 +certifi==2025.4.26 # via # -c requirements/base.txt # -c requirements/test.txt @@ -44,12 +44,12 @@ cffi==1.17.1 # via # -c requirements/base.txt # argon2-cffi-bindings -charset-normalizer==3.4.1 +charset-normalizer==3.4.2 # via # -c requirements/base.txt # -c requirements/test.txt # requests -click==8.1.8 +click==8.2.0 # via # -c requirements/test.txt # pip-tools @@ -57,7 +57,7 @@ comm==0.2.2 # via # ipykernel # ipywidgets -contourpy==1.3.0 +contourpy==1.3.2 # via # -c requirements/base.txt # matplotlib @@ -65,32 +65,27 @@ cycler==0.12.1 # via # -c requirements/base.txt # matplotlib -debugpy==1.8.13 +debugpy==1.8.14 # via ipykernel decorator==5.2.1 # via ipython defusedxml==0.7.1 # via nbconvert -exceptiongroup==1.2.2 - # via - # -c requirements/test.txt - # anyio - # ipython executing==2.2.0 # via stack-data fastjsonschema==2.21.1 # via nbformat -fonttools==4.56.0 +fonttools==4.58.0 # via # -c requirements/base.txt # matplotlib fqdn==1.5.1 # via jsonschema -h11==0.14.0 +h11==0.16.0 # via # -c requirements/test.txt # httpcore -httpcore==1.0.7 +httpcore==1.0.9 # via # -c requirements/test.txt # httpx @@ -106,30 +101,20 @@ idna==3.10 # httpx # jsonschema # requests -importlib-metadata==8.6.1 - # via - # build - # jupyter-client - # jupyter-lsp - # jupyterlab - # jupyterlab-server - # nbconvert -importlib-resources==6.5.2 - # via - # -c requirements/base.txt - # matplotlib ipykernel==6.29.5 # via # jupyter # jupyter-console # jupyterlab -ipython==8.18.1 +ipython==9.2.0 # via # -r requirements/dev.in # ipykernel # ipywidgets # jupyter-console -ipywidgets==8.1.5 +ipython-pygments-lexers==1.1.1 + # via ipython +ipywidgets==8.1.7 # via jupyter isoduration==20.11.0 # via jsonschema @@ -142,7 +127,7 @@ jinja2==3.1.6 # jupyterlab # jupyterlab-server # nbconvert -json5==0.10.0 +json5==0.12.0 # via jupyterlab-server jsonpointer==3.0.0 # via jsonschema @@ -151,7 +136,7 @@ jsonschema[format-nongpl]==4.23.0 # jupyter-events # jupyterlab-server # nbformat -jsonschema-specifications==2024.10.1 +jsonschema-specifications==2025.4.1 # via jsonschema jupyter==1.1.1 # via -r requirements/dev.in @@ -177,7 +162,7 @@ jupyter-events==0.12.0 # via jupyter-server jupyter-lsp==2.2.5 # via jupyterlab -jupyter-server==2.15.0 +jupyter-server==2.16.0 # via # jupyter-lsp # jupyterlab @@ -186,7 +171,7 @@ jupyter-server==2.15.0 # notebook-shim jupyter-server-terminals==0.5.3 # via jupyter-server -jupyterlab==4.3.5 +jupyterlab==4.4.2 # via # jupyter # notebook @@ -196,9 +181,9 @@ jupyterlab-server==2.27.3 # via # jupyterlab # notebook -jupyterlab-widgets==3.0.13 +jupyterlab-widgets==3.0.15 # via ipywidgets -kiwisolver==1.4.7 +kiwisolver==1.4.8 # via # -c requirements/base.txt # matplotlib @@ -207,7 +192,7 @@ markupsafe==3.0.2 # -c requirements/base.txt # jinja2 # nbconvert -matplotlib==3.9.4 +matplotlib==3.10.3 # via # -c requirements/base.txt # -r requirements/dev.in @@ -215,7 +200,7 @@ matplotlib-inline==0.1.7 # via # ipykernel # ipython -mistune==3.1.2 +mistune==3.1.3 # via nbconvert nbclient==0.10.2 # via nbconvert @@ -230,20 +215,20 @@ nbformat==5.10.4 # nbconvert nest-asyncio==1.6.0 # via ipykernel -notebook==7.3.2 +notebook==7.4.2 # via jupyter notebook-shim==0.2.4 # via # jupyterlab # notebook -numpy==2.0.2 +numpy==2.2.5 # via # -c requirements/base.txt # contourpy # matplotlib overrides==7.7.0 # via jupyter-server -packaging==24.2 +packaging==25.0 # via # -c requirements/base.txt # -c requirements/test.txt @@ -261,20 +246,20 @@ parso==0.8.4 # via jedi pexpect==4.9.0 # via ipython -pillow==11.1.0 +pillow==11.2.1 # via # -c requirements/base.txt # -c requirements/test.txt # matplotlib pip-tools==7.4.1 # via -r requirements/dev.in -platformdirs==4.3.6 +platformdirs==4.3.8 # via # -c requirements/test.txt # jupyter-core prometheus-client==0.21.1 # via jupyter-server -prompt-toolkit==3.0.50 +prompt-toolkit==3.0.51 # via # ipython # jupyter-console @@ -293,9 +278,10 @@ pycparser==2.22 pygments==2.19.1 # via # ipython + # ipython-pygments-lexers # jupyter-console # nbconvert -pyparsing==3.2.1 +pyparsing==3.2.3 # via # -c requirements/base.txt # matplotlib @@ -316,7 +302,7 @@ pyyaml==6.0.2 # -c requirements/base.txt # -c requirements/test.txt # jupyter-events -pyzmq==26.3.0 +pyzmq==26.4.0 # via # ipykernel # jupyter-client @@ -340,7 +326,7 @@ rfc3986-validator==0.1.1 # via # jsonschema # jupyter-events -rpds-py==0.23.1 +rpds-py==0.25.0 # via # jsonschema # referencing @@ -355,7 +341,7 @@ sniffio==1.3.1 # via # -c requirements/test.txt # anyio -soupsieve==2.6 +soupsieve==2.7 # via beautifulsoup4 stack-data==0.6.3 # via ipython @@ -365,12 +351,6 @@ terminado==0.18.1 # jupyter-server-terminals tinycss2==1.4.0 # via bleach -tomli==2.2.1 - # via - # -c requirements/test.txt - # build - # jupyterlab - # pip-tools tornado==6.4.2 # via # ipykernel @@ -397,20 +377,16 @@ traitlets==5.14.3 # nbformat types-python-dateutil==2.9.0.20241206 # via arrow -typing-extensions==4.12.2 +typing-extensions==4.13.2 # via # -c requirements/base.txt # -c requirements/test.txt # anyio - # async-lru # beautifulsoup4 - # ipython - # mistune - # python-json-logger # referencing uri-template==1.3.0 # via jsonschema -urllib3==2.3.0 +urllib3==2.4.0 # via # -c requirements/base.txt # -c requirements/test.txt @@ -427,13 +403,8 @@ websocket-client==1.8.0 # via jupyter-server wheel==0.45.1 # via pip-tools -widgetsnbextension==4.0.13 +widgetsnbextension==4.0.14 # via ipywidgets -zipp==3.21.0 - # via - # -c requirements/base.txt - # importlib-metadata - # importlib-resources # The following packages are considered to be unsafe in a requirements file: # pip diff --git a/requirements/test.txt b/requirements/test.txt index 13ded3ce..1c474408 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -1,56 +1,52 @@ # -# This file is autogenerated by pip-compile with Python 3.9 +# This file is autogenerated by pip-compile with Python 3.12 # by the following command: # # pip-compile requirements/test.in # -anyio==4.8.0 +anyio==4.9.0 # via httpx black==25.1.0 # via -r requirements/test.in -certifi==2025.1.31 +certifi==2025.4.26 # via # -c requirements/base.txt # httpcore # httpx # requests -charset-normalizer==3.4.1 +charset-normalizer==3.4.2 # via # -c requirements/base.txt # requests -click==8.1.8 +click==8.2.0 # via # -r requirements/test.in # black -coverage[toml]==7.6.12 +coverage[toml]==7.8.0 # via # -r requirements/test.in # pytest-cov -exceptiongroup==1.2.2 - # via - # anyio - # pytest -filelock==3.17.0 +filelock==3.18.0 # via # -c requirements/base.txt # huggingface-hub -flake8==7.1.2 +flake8==7.2.0 # via # -r requirements/test.in # flake8-docstrings flake8-docstrings==1.7.0 # via -r requirements/test.in -fsspec==2025.3.0 +fsspec==2025.3.2 # via # -c requirements/base.txt # huggingface-hub -h11==0.14.0 +h11==0.16.0 # via httpcore -httpcore==1.0.7 +httpcore==1.0.9 # via httpx httpx==0.28.1 # via -r requirements/test.in -huggingface-hub==0.29.3 +huggingface-hub==0.31.2 # via # -c requirements/base.txt # -r requirements/test.in @@ -60,17 +56,17 @@ idna==3.10 # anyio # httpx # requests -iniconfig==2.0.0 +iniconfig==2.1.0 # via pytest mccabe==0.7.0 # via flake8 mypy==1.15.0 # via -r requirements/test.in -mypy-extensions==1.0.0 +mypy-extensions==1.1.0 # via # black # mypy -packaging==24.2 +packaging==25.0 # via # -c requirements/base.txt # black @@ -80,25 +76,25 @@ pathspec==0.12.1 # via black pdf2image==1.17.0 # via -r requirements/test.in -pillow==11.1.0 +pillow==11.2.1 # via # -c requirements/base.txt # pdf2image -platformdirs==4.3.6 +platformdirs==4.3.8 # via black -pluggy==1.5.0 +pluggy==1.6.0 # via pytest -pycodestyle==2.12.1 +pycodestyle==2.13.0 # via flake8 pydocstyle==6.3.0 # via flake8-docstrings -pyflakes==3.2.0 +pyflakes==3.3.2 # via flake8 pytest==8.3.5 # via # pytest-cov # pytest-mock -pytest-cov==6.0.0 +pytest-cov==6.1.1 # via -r requirements/test.in pytest-mock==3.14.0 # via -r requirements/test.in @@ -110,32 +106,25 @@ requests==2.32.3 # via # -c requirements/base.txt # huggingface-hub -ruff==0.10.0 +ruff==0.11.10 # via -r requirements/test.in sniffio==1.3.1 # via anyio -snowballstemmer==2.2.0 +snowballstemmer==3.0.1 # via pydocstyle -tomli==2.2.1 - # via - # black - # coverage - # mypy - # pytest tqdm==4.67.1 # via # -c requirements/base.txt # huggingface-hub -types-pyyaml==6.0.12.20241230 +types-pyyaml==6.0.12.20250402 # via -r requirements/test.in -typing-extensions==4.12.2 +typing-extensions==4.13.2 # via # -c requirements/base.txt # anyio - # black # huggingface-hub # mypy -urllib3==2.3.0 +urllib3==2.4.0 # via # -c requirements/base.txt # requests diff --git a/setup.py b/setup.py index 968316b1..99ca0532 100644 --- a/setup.py +++ b/setup.py @@ -54,7 +54,7 @@ def load_text_from_file(filename: str): long_description_content_type="text/markdown", keywords="NLP PDF HTML CV XML parsing preprocessing", url="https://github.com/Unstructured-IO/unstructured-inference", - python_requires=">=3.7.0", + python_requires=">=3.10", classifiers=[ "Development Status :: 4 - Beta", "Intended Audience :: Developers", @@ -63,8 +63,9 @@ def load_text_from_file(filename: str): "License :: OSI Approved :: Apache Software License", "Operating System :: OS Independent", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.8", - "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", "Topic :: Scientific/Engineering :: Artificial Intelligence", ], author="Unstructured Technologies", diff --git a/test_unstructured_inference/models/test_donut.py b/test_unstructured_inference/models/test_donut.py deleted file mode 100644 index 73861717..00000000 --- a/test_unstructured_inference/models/test_donut.py +++ /dev/null @@ -1,79 +0,0 @@ -import pytest -from PIL import Image -from transformers import DonutSwinModel - -from unstructured_inference.models import donut - - -@pytest.mark.parametrize( - ("model_path", "processor_path", "config_path"), - [ - ("crispy_donut_path", "crispy_proc", "crispy_config"), - ("cherry_donut_path", "cherry_proc", "cherry_config"), - ], -) -def test_load_donut_model_raises_when_not_available(model_path, processor_path, config_path): - with pytest.raises(ImportError): - donut_model = donut.UnstructuredDonutModel() - donut_model.initialize( - model=model_path, - processor=processor_path, - config=config_path, - task_prompt="", - ) - - -@pytest.mark.parametrize( - ("model_path", "processor_path", "config_path"), - [ - ( - "unstructuredio/donut-base-sroie", - "unstructuredio/donut-base-sroie", - "unstructuredio/donut-base-sroie", - ), - ], -) -def test_load_donut_model(model_path, processor_path, config_path): - donut_model = donut.UnstructuredDonutModel() - donut_model.initialize( - model=model_path, - processor=processor_path, - config=config_path, - task_prompt="", - ) - assert type(donut_model.model.encoder) is DonutSwinModel - - -@pytest.fixture() -def sample_receipt_transcript(): - return { - "total": "46.00", - "date": "20/03/2018", - "company": "UROKO JAPANESE CUISINE SDN BHD", - "address": "22A-1, JALAN 17/54, SECTION 17, 46400 PETALING JAYA, SELANGOR.", - } - - -@pytest.mark.skip() -@pytest.mark.parametrize( - ("model_path", "processor_path", "config_path"), - [ - ( - "unstructuredio/donut-base-sroie", - "unstructuredio/donut-base-sroie", - "unstructuredio/donut-base-sroie", - ), - ], -) -def test_donut_prediction(model_path, processor_path, config_path, sample_receipt_transcript): - donut_model = donut.UnstructuredDonutModel() - donut_model.initialize( - model=model_path, - processor=processor_path, - config=config_path, - task_prompt="", - ) - image_path = "./sample-docs/receipt-sample.jpg" - with Image.open(image_path) as image: - prediction = donut_model.predict(image) - assert prediction == sample_receipt_transcript diff --git a/test_unstructured_inference/models/test_tables.py b/test_unstructured_inference/models/test_tables.py index 3feb9ed5..03c9b1fd 100644 --- a/test_unstructured_inference/models/test_tables.py +++ b/test_unstructured_inference/models/test_tables.py @@ -566,7 +566,7 @@ def mocked_ocr_tokens(): ], ) def test_load_table_model_raises_when_not_available(model_path): - with pytest.raises(ImportError): + with pytest.raises(OSError): table_model = tables.UnstructuredTableTransformerModel() table_model.initialize(model=model_path) diff --git a/unstructured_inference/__version__.py b/unstructured_inference/__version__.py index 4ab59b42..fd3eb234 100644 --- a/unstructured_inference/__version__.py +++ b/unstructured_inference/__version__.py @@ -1 +1 @@ -__version__ = "0.8.10" # pragma: no cover +__version__ = "1.0.0" # pragma: no cover diff --git a/unstructured_inference/models/base.py b/unstructured_inference/models/base.py index eef0844c..826bf5a5 100644 --- a/unstructured_inference/models/base.py +++ b/unstructured_inference/models/base.py @@ -46,7 +46,7 @@ def get_model(model_name: Optional[str] = None) -> UnstructuredModel: # TODO(alan): These cases are similar enough that we can probably do them all together with # importlib - global models + global models # noqa if model_name is None: default_name_from_env = os.environ.get("UNSTRUCTURED_DEFAULT_MODEL_NAME") diff --git a/unstructured_inference/models/donut.py b/unstructured_inference/models/donut.py deleted file mode 100644 index bc60d2c6..00000000 --- a/unstructured_inference/models/donut.py +++ /dev/null @@ -1,79 +0,0 @@ -import logging -from pathlib import Path -from typing import Optional, Union - -import torch -from PIL import Image as PILImage -from transformers import ( - DonutProcessor, - VisionEncoderDecoderConfig, - VisionEncoderDecoderModel, -) - -from unstructured_inference.models.unstructuredmodel import UnstructuredModel - - -class UnstructuredDonutModel(UnstructuredModel): - """Unstructured model wrapper for Donut image transformer.""" - - def predict(self, x: PILImage.Image): - """Make prediction using donut model""" - super().predict(x) - return self.run_prediction(x) - - def initialize( - self, - model: Union[str, Path, VisionEncoderDecoderModel] = None, - processor: Union[str, Path, DonutProcessor] = None, - config: Optional[Union[str, Path, VisionEncoderDecoderConfig]] = None, - task_prompt: Optional[str] = "", - device: Optional[str] = "cuda" if torch.cuda.is_available() else "cpu", - ): - """Loads the donut model using the specified parameters""" - - self.task_prompt = task_prompt - self.device = device - - try: - if not isinstance(config, VisionEncoderDecoderModel): - config = VisionEncoderDecoderConfig.from_pretrained(config) - - logging.info("Loading the Donut model and processor...") - self.processor = DonutProcessor.from_pretrained(processor) - self.model = VisionEncoderDecoderModel.from_pretrained(model, config=config) - - except EnvironmentError: - logging.critical("Failed to initialize the model.") - logging.critical( - "Ensure that the Donut parameters config, model and processor are correct", - ) - raise ImportError("Review the parameters to initialize a UnstructuredDonutModel obj") - self.model.to(device) - - def run_prediction(self, x: PILImage.Image): - """Internal prediction method.""" - pixel_values = self.processor(x, return_tensors="pt").pixel_values - decoder_input_ids = self.processor.tokenizer( - self.task_prompt, - add_special_tokens=False, - return_tensors="pt", - ).input_ids - outputs = self.model.generate( - pixel_values.to(self.device), - decoder_input_ids=decoder_input_ids.to(self.device), - max_length=self.model.decoder.config.max_position_embeddings, - early_stopping=True, - pad_token_id=self.processor.tokenizer.pad_token_id, - eos_token_id=self.processor.tokenizer.eos_token_id, - use_cache=True, - num_beams=1, - bad_words_ids=[[self.processor.tokenizer.unk_token_id]], - return_dict_in_generate=True, - ) - prediction = self.processor.batch_decode(outputs.sequences)[0] - # NOTE(alan): As of right now I think this would not work if passed in as the model to - # DocumentLayout.from_file and similar functions that take a model object as input. This - # produces image-to-text inferences rather than image-to-bboxes, so we actually need to - # hook it up in a different way. - prediction = self.processor.token2json(prediction) - return prediction diff --git a/unstructured_inference/models/tables.py b/unstructured_inference/models/tables.py index c390378e..7c0dfbe6 100644 --- a/unstructured_inference/models/tables.py +++ b/unstructured_inference/models/tables.py @@ -56,12 +56,12 @@ def predict( def initialize( self, - model: Union[str, Path, TableTransformerForObjectDetection] = None, + model: Union[str, Path, TableTransformerForObjectDetection], device: Optional[str] = "cuda" if torch.cuda.is_available() else "cpu", ): """Loads the donut model using the specified parameters""" self.device = device - self.feature_extractor = DetrImageProcessor() + self.feature_extractor = DetrImageProcessor.from_pretrained(model) try: logger.info("Loading the table structure model ...") @@ -83,7 +83,7 @@ def get_structure( self, x: PILImage.Image, pad_for_structure_detection: int = inference_config.TABLE_IMAGE_BACKGROUND_PAD, - ) -> dict: + ) -> TableTransformerObjectDetectionOutput: """get the table structure as a dictionary contaning different types of elements as key-value pairs; check table-transformer documentation for more information""" with torch.no_grad(): @@ -135,7 +135,7 @@ def run_prediction( def load_agent(): """Loads the Table agent as a global variable to ensure that we only load it once.""" - global tables_agent + global tables_agent # noqa if not hasattr(tables_agent, "model"): logger.info("Loading the Table agent ...") @@ -173,7 +173,7 @@ def get_class_map(data_type: str): } -def recognize(outputs: dict, img: PILImage.Image, tokens: list): +def recognize(outputs: TableTransformerObjectDetectionOutput, img: PILImage.Image, tokens: list): """Recognize table elements.""" str_class_name2idx = get_class_map("structure") str_class_idx2name = {v: k for k, v in str_class_name2idx.items()}