Skip to content

Commit 81bea10

Browse files
committed
First pass
1 parent d9d806b commit 81bea10

File tree

3 files changed

+104
-4
lines changed

3 files changed

+104
-4
lines changed

prepline_general/api/general.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,13 @@ def partition_pdf_splits(
166166
pages_per_pdf = int(os.environ.get("UNSTRUCTURED_PARALLEL_MODE_SPLIT_SIZE", 1))
167167

168168
# If it's small enough, just process locally
169+
# (Some kwargs need to be renamed for local partition)
169170
if len(pdf_pages) <= pages_per_pdf:
171+
# Get the test to fail first
172+
# if "hi_res_model_name" in partition_kwargs:
173+
# partition_kwargs['model_name'] = partition_kwargs['hi_res_model_name']
174+
# del partition_kwargs['hi_res_model_name']
175+
170176
return partition(
171177
file=file, file_filename=file_filename, content_type=content_type, **partition_kwargs
172178
)
@@ -317,6 +323,9 @@ def pipeline_api(
317323
)
318324
)
319325

326+
# Be careful of naming differences in api params vs partition params!
327+
# These kwargs are going back into the api
328+
# If needed, rename them in partition_pdf_splits
320329
if file_content_type == "application/pdf" and pdf_parallel_mode_enabled:
321330
elements = partition_pdf_splits(
322331
request,
@@ -328,7 +337,7 @@ def pipeline_api(
328337
# partition_kwargs
329338
encoding=encoding,
330339
include_page_breaks=include_page_breaks,
331-
model_name=hi_res_model_name,
340+
hi_res_model_name=hi_res_model_name,
332341
ocr_languages=ocr_languages,
333342
pdf_infer_table_structure=pdf_infer_table_structure,
334343
skip_infer_table_types=skip_infer_table_types,

scripts/parallel-mode-test.sh

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,10 @@ declare -a curl_params=(
1515
"-F files=@sample-docs/layout-parser-paper.pdf -F 'strategy=fast'"
1616
"-F files=@sample-docs/layout-parser-paper.pdf -F 'strategy=auto"
1717
"-F files=@sample-docs/layout-parser-paper.pdf -F 'strategy=hi_res'"
18-
"-F files=@sample-docs/layout-parser-paper.pdf -F 'coordinates=true' -F 'strategy=fast'"
19-
"-F files=@sample-docs/layout-parser-paper.pdf -F 'coordinates=true' -F 'strategy=fast' -F 'encoding=utf-8'"
20-
"-F files=@sample-docs/layout-parser-paper.pdf -F 'coordinates=true' -F 'strategy=fast' -F 'include_page_breaks=true'"
18+
"-F files=@sample-docs/layout-parser-paper.pdf -F 'coordinates=true'"
19+
"-F files=@sample-docs/layout-parser-paper.pdf -F 'encoding=utf-8'"
20+
"-F files=@sample-docs/layout-parser-paper.pdf -F 'include_page_breaks=true'"
21+
"-F files=@sample-docs/layout-parser-paper.pdf -F 'hi_res_model_name=yolox'"
2122
)
2223

2324
for params in "${curl_params[@]}"

test_general/api/test_app.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,12 @@
77
import pandas as pd
88
from fastapi.testclient import TestClient
99
from unstructured_api_tools.pipelines.api_conventions import get_pipeline_path
10+
from unittest.mock import patch, Mock
11+
12+
import unstructured
1013

1114
from prepline_general.api.app import app
15+
from prepline_general.api import general
1216
import tempfile
1317

1418
MAIN_API_ROUTE = get_pipeline_path("general")
@@ -397,6 +401,92 @@ def json(self):
397401
return self.body
398402

399403

404+
@pytest.mark.only
405+
def test_parallel_mode_params(monkeypatch):
406+
"""
407+
Verify that parallel mode passes all params to local partition.
408+
If you add something to partition_kwargs, you need to explicitly test it here.
409+
"""
410+
monkeypatch.setenv("UNSTRUCTURED_PARALLEL_MODE_ENABLED", "true")
411+
monkeypatch.setenv("UNSTRUCTURED_PARALLEL_MODE_THREADS", "1")
412+
monkeypatch.setenv("UNSTRUCTURED_PARALLEL_MODE_URL", "unused")
413+
414+
# Make this really big so we just call partition
415+
monkeypatch.setenv("UNSTRUCTURED_PARALLEL_MODE_SPLIT_SIZE", "500")
416+
417+
# Verify we can pass a non-default value for everything
418+
m_encoding = "foo"
419+
m_hi_res_model_name = "the_big_chipper"
420+
m_strategy = "highest_res"
421+
422+
m_ocr_languages = ["all", "of", "them"]
423+
m_skip_infer_table_types = []
424+
425+
m_pdf_infer_table_structure = True
426+
m_include_page_breaks = True
427+
m_xml_keep_tags = True
428+
429+
# use mock to assert called with
430+
def validate_local_params(
431+
file,
432+
file_filename,
433+
content_type,
434+
encoding,
435+
include_page_breaks,
436+
model_name,
437+
ocr_languages,
438+
pdf_infer_table_structure,
439+
skip_infer_table_types,
440+
strategy,
441+
xml_keep_tags,
442+
):
443+
logger.warn("validating")
444+
assert encoding == m_encoding
445+
assert include_page_breaks == m_include_page_breaks
446+
assert model_name == m_hi_res_model_name
447+
assert ocr_languages == m_ocr_languages
448+
assert pdf_infer_table_structure == m_pdf_infer_table_structure
449+
assert skip_infer_table_types == m_skip_infer_table_types
450+
assert strategy == m_strategy
451+
assert xml_keep_tags == m_xml_keep_tags
452+
453+
return []
454+
455+
mock_partition = Mock()
456+
457+
monkeypatch.setattr(
458+
general,
459+
"partition",
460+
mock_partition,
461+
)
462+
463+
client = TestClient(app)
464+
test_file = Path("sample-docs") / "layout-parser-paper.pdf"
465+
466+
# with patch.object(general, "partition") as mock_partition:
467+
# with patch.object(unstructured.partition.auto, "partition") as mock_partition:
468+
response = client.post(
469+
MAIN_API_ROUTE,
470+
files=[("files", (str(test_file), open(test_file, "rb"), "application/pdf"))],
471+
data={
472+
"encoding": m_encoding,
473+
"include_page_breaks": m_include_page_breaks,
474+
"hi_res_model_name": m_hi_res_model_name,
475+
"ocr_languages": m_ocr_languages,
476+
"pdf_infer_table_structure": m_pdf_infer_table_structure,
477+
"skip_infer_table_types": m_skip_infer_table_types,
478+
"strategy": m_strategy,
479+
# "xml_keep_tags": m_xml_keep_tags,
480+
}
481+
)
482+
483+
mock_partition.assert_called_once_with(
484+
encoding="foo"
485+
)
486+
487+
assert response.status_code == 200
488+
489+
400490
def test_parallel_mode_returns_errors(monkeypatch):
401491
"""
402492
If we get an error sending a page to the api, bubble it up

0 commit comments

Comments
 (0)