|
5 | 5 |
|
6 | 6 | import httpx |
7 | 7 | import json |
| 8 | +import os |
8 | 9 | import pytest |
9 | 10 | import requests |
10 | 11 | from deepdiff import DeepDiff |
|
19 | 20 | from unstructured_client._hooks.custom import form_utils |
20 | 21 | from unstructured_client._hooks.custom import split_pdf_hook |
21 | 22 |
|
22 | | -FAKE_KEY = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" |
| 23 | +FAKE_KEY = os.getenv("UNSTRUCTURED_API_KEY") or "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" |
23 | 24 |
|
24 | 25 |
|
25 | 26 | @pytest.mark.parametrize("concurrency_level", [1, 2, 5]) |
@@ -472,3 +473,43 @@ async def mock_send(_, request: httpx.Request, **kwargs): |
472 | 473 | assert mock_endpoint_called |
473 | 474 |
|
474 | 475 | assert res.status_code == 200 |
| 476 | + |
| 477 | + |
| 478 | +@pytest.mark.parametrize( |
| 479 | + ("filename", "chunking_strategy", "expected_elements_num"), |
| 480 | + [ |
| 481 | + ## -- Paid strategy -- |
| 482 | + ("_sample_docs/layout-parser-paper.pdf", "by_page", 16), # 16 pages, 133 elements w/o chunking |
| 483 | + ("_sample_docs/layout-parser-paper.pdf", shared.ChunkingStrategy.BY_PAGE, 16), |
| 484 | + # -- Open source strategy -- |
| 485 | + ("_sample_docs/layout-parser-paper.pdf", "by_title", -1), # unsure what the correct number is atm |
| 486 | + ("_sample_docs/layout-parser-paper.pdf", shared.ChunkingStrategy.BY_TITLE, -1), |
| 487 | + ], |
| 488 | +) |
| 489 | +def test_chunking( |
| 490 | + filename: str, |
| 491 | + chunking_strategy: str| shared.ChunkingStrategy, |
| 492 | + expected_elements_num: int, |
| 493 | +): |
| 494 | + |
| 495 | + client = UnstructuredClient(api_key_auth=FAKE_KEY) |
| 496 | + |
| 497 | + with open(filename, "rb") as f: |
| 498 | + files = shared.Files( |
| 499 | + content=f.read(), |
| 500 | + file_name=filename, |
| 501 | + ) |
| 502 | + |
| 503 | + parameters = shared.PartitionParameters( |
| 504 | + files=files, |
| 505 | + chunking_strategy=chunking_strategy, # type: ignore |
| 506 | + ) |
| 507 | + |
| 508 | + req = operations.PartitionRequest( |
| 509 | + partition_parameters=parameters |
| 510 | + ) |
| 511 | + |
| 512 | + resp = client.general.partition(request=req) |
| 513 | + assert len(resp.elements) == expected_elements_num |
| 514 | + assert all(element.type == "CompositeElement" for element in resp.elements) |
| 515 | + |
0 commit comments