|
7 | 7 | import pandas as pd |
8 | 8 | from fastapi.testclient import TestClient |
9 | 9 | from unstructured_api_tools.pipelines.api_conventions import get_pipeline_path |
| 10 | +from unittest.mock import patch, Mock |
| 11 | + |
| 12 | +import unstructured |
10 | 13 |
|
11 | 14 | from prepline_general.api.app import app |
| 15 | +from prepline_general.api import general |
12 | 16 | import tempfile |
13 | 17 |
|
14 | 18 | MAIN_API_ROUTE = get_pipeline_path("general") |
@@ -397,6 +401,92 @@ def json(self): |
397 | 401 | return self.body |
398 | 402 |
|
399 | 403 |
|
def test_parallel_mode_params(monkeypatch):
    """
    Verify that parallel mode passes all params to local partition.

    If you add something to partition_kwargs, you need to explicitly test it
    here with a non-default value.

    ``general.partition`` is replaced with a mock so no document is actually
    parsed; the test only inspects the keyword arguments the mock received.
    """
    monkeypatch.setenv("UNSTRUCTURED_PARALLEL_MODE_ENABLED", "true")
    monkeypatch.setenv("UNSTRUCTURED_PARALLEL_MODE_THREADS", "1")
    monkeypatch.setenv("UNSTRUCTURED_PARALLEL_MODE_URL", "unused")

    # Make this really big so we just call partition
    monkeypatch.setenv("UNSTRUCTURED_PARALLEL_MODE_SPLIT_SIZE", "500")

    # Non-default value for every param, keyed by the keyword name that
    # partition is expected to receive.
    # NOTE(review): these placeholder values must still pass any request
    # validation the API performs (e.g. on strategy) - confirm against
    # prepline_general.api.general.
    # NOTE(review): an empty list is typically omitted from encoded form data,
    # so skip_infer_table_types may arrive as the API default - confirm.
    expected_kwargs = {
        "encoding": "foo",
        "include_page_breaks": True,
        "model_name": "the_big_chipper",
        "ocr_languages": ["all", "of", "them"],
        "pdf_infer_table_structure": True,
        "skip_infer_table_types": [],
        "strategy": "highest_res",
        "xml_keep_tags": True,
    }

    # Form fields use the public API names; hi_res_model_name is the only one
    # that differs from the keyword name expected at partition.
    form_data = {
        "encoding": expected_kwargs["encoding"],
        "include_page_breaks": expected_kwargs["include_page_breaks"],
        "hi_res_model_name": expected_kwargs["model_name"],
        "ocr_languages": expected_kwargs["ocr_languages"],
        "pdf_infer_table_structure": expected_kwargs["pdf_infer_table_structure"],
        "skip_infer_table_types": expected_kwargs["skip_infer_table_types"],
        "strategy": expected_kwargs["strategy"],
        "xml_keep_tags": expected_kwargs["xml_keep_tags"],
    }

    # Return an empty element list so the endpoint has a list to work with
    # instead of a bare Mock object.
    mock_partition = Mock(return_value=[])
    monkeypatch.setattr(general, "partition", mock_partition)

    client = TestClient(app)
    test_file = Path("sample-docs") / "layout-parser-paper.pdf"

    # Context manager so the upload handle is closed even if the request raises.
    with open(test_file, "rb") as pdf_file:
        response = client.post(
            MAIN_API_ROUTE,
            files=[("files", (str(test_file), pdf_file, "application/pdf"))],
            data=form_data,
        )

    # Check the HTTP result first: a failed request explains a missing call.
    assert response.status_code == 200

    mock_partition.assert_called_once()
    _, actual_kwargs = mock_partition.call_args

    # Compare only the params under test; partition also receives the file
    # and its metadata, which are not pinned here.
    missing = object()
    mismatched = {
        name: (expected, actual_kwargs.get(name, "<missing>"))
        for name, expected in expected_kwargs.items()
        if actual_kwargs.get(name, missing) != expected
    }
    assert not mismatched, f"partition got unexpected params (expected, actual): {mismatched}"
| 488 | + |
| 489 | + |
400 | 490 | def test_parallel_mode_returns_errors(monkeypatch): |
401 | 491 | """ |
402 | 492 | If we get an error sending a page to the api, bubble it up |
|
0 commit comments