Skip to content

Commit 237397e

Browse files
committed
test: added/updated tests
1 parent 004d576 commit 237397e

File tree

2 files changed

+109
-59
lines changed

2 files changed

+109
-59
lines changed
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
# Get unit tests for request_utils.py module
2+
import httpx
3+
import pytest
4+
5+
from unstructured_client._hooks.custom.request_utils import create_pdf_chunk_request_params, get_multipart_stream_fields
6+
from unstructured_client.models import shared
7+
8+
9+
# make the above test using @pytest.mark.parametrize
10+
@pytest.mark.parametrize(("input_request", "expected"), [
11+
(httpx.Request("POST", "http://localhost:8000", data={}, headers={"Content-Type": "multipart/form-data"}), {}),
12+
(httpx.Request("POST", "http://localhost:8000", data={"hello": "world"}, headers={"Content-Type": "application/json"}), {}),
13+
(httpx.Request(
14+
"POST",
15+
"http://localhost:8000",
16+
data={"hello": "world"},
17+
files={"files": ("hello.pdf", b"hello", "application/pdf")},
18+
headers={"Content-Type": "multipart/form-data; boundary=----WebKitFormBoundary7MA4YWxkTrZu0gW"}),
19+
{
20+
"hello": "world",
21+
"files": {
22+
"content_type":"application/pdf",
23+
"filename": "hello.pdf",
24+
"file": b"hello",
25+
}
26+
}
27+
),
28+
])
29+
def test_get_multipart_stream_fields(input_request, expected):
30+
fields = get_multipart_stream_fields(input_request)
31+
assert fields == expected
32+
33+
def test_multipart_stream_fields_raises_value_error_when_filename_is_not_set():
34+
with pytest.raises(ValueError):
35+
get_multipart_stream_fields(httpx.Request(
36+
"POST",
37+
"http://localhost:8000",
38+
data={"hello": "world"},
39+
files={"files": ("", b"hello", "application/pdf")},
40+
headers={"Content-Type": "multipart/form-data; boundary=----WebKitFormBoundary7MA4YWxkTrZu0gW"}),
41+
)
42+
43+
@pytest.mark.parametrize(("input_form_data", "page_number", "expected_form_data"), [
44+
(
45+
{"hello": "world"},
46+
2,
47+
{"hello": "world", "split_pdf_page": "false", "starting_page_number": "2"}
48+
),
49+
(
50+
{"hello": "world", "split_pdf_page": "true"},
51+
2,
52+
{"hello": "world", "split_pdf_page": "false", "starting_page_number": "2"}
53+
),
54+
(
55+
{"hello": "world", "split_pdf_page": "true", "files": "dummy_file"},
56+
3,
57+
{"hello": "world", "split_pdf_page": "false", "starting_page_number": "3"}
58+
),
59+
(
60+
{"split_pdf_page_range[]": [1, 3], "hello": "world", "split_pdf_page": "true", "files": "dummy_file"},
61+
3,
62+
{"hello": "world", "split_pdf_page": "false", "starting_page_number": "3"}
63+
),
64+
(
65+
{"split_pdf_page_range": [1, 3], "hello": "world", "split_pdf_page": "true", "files": "dummy_file"},
66+
4,
67+
{"hello": "world", "split_pdf_page": "false", "starting_page_number": "4"}
68+
),
69+
])
70+
def test_create_pdf_chunk_request_params(input_form_data, page_number, expected_form_data):
71+
form_data = create_pdf_chunk_request_params(input_form_data, page_number)
72+
assert form_data == expected_form_data

_test_unstructured_client/unit/test_split_pdf_hook.py

Lines changed: 37 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import logging
66
from asyncio import Task
77
from collections import Counter
8+
from functools import partial
89
from typing import Coroutine
910

1011
import httpx
@@ -201,61 +202,31 @@ def test_unit_parse_form_data_none_filename_error():
201202
form_utils.parse_form_data(decoded_data)
202203

203204

204-
def test_unit_is_pdf_valid_pdf():
205-
"""Test is pdf method returns True for valid pdf file with filename."""
205+
def test_unit_is_pdf_valid_pdf_when_passing_file_object():
206+
"""Test is pdf method returns pdf object for valid pdf file with filename."""
206207
filename = "_sample_docs/layout-parser-paper-fast.pdf"
207208

208209
with open(filename, "rb") as f:
209-
file = shared.Files(
210-
content=f.read(),
211-
file_name=filename,
212-
)
213-
214-
result = pdf_utils.read_pdf(file)
210+
result = pdf_utils.read_pdf(f)
215211

216212
assert result is not None
217213

218214

219-
def test_unit_is_pdf_valid_pdf_without_file_extension():
220-
"""Test is pdf method returns True for file with valid pdf content without basing on file extension."""
215+
def test_unit_is_pdf_valid_pdf_when_passing_binary_content():
216+
"""Test is pdf method returns pdf object for file with valid pdf content"""
221217
filename = "_sample_docs/layout-parser-paper-fast.pdf"
222218

223219
with open(filename, "rb") as f:
224-
file = shared.Files(
225-
content=f.read(),
226-
file_name="uuid1234",
227-
)
228-
229-
result = pdf_utils.read_pdf(file)
220+
result = pdf_utils.read_pdf(f.read())
230221

231222
assert result is not None
232223

233224

234-
def test_unit_is_pdf_invalid_extension():
235-
"""Test is pdf method returns False for file with invalid extension."""
236-
file = shared.Files(content=b"txt_content", file_name="test_file.txt")
237-
238-
result = pdf_utils.read_pdf(file)
239-
240-
assert result is None
241-
242-
243225
def test_unit_is_pdf_invalid_pdf():
244-
"""Test is pdf method returns False for file with invalid pdf content."""
245-
file = shared.Files(content=b"invalid_pdf_content", file_name="test_file.pdf")
246-
247-
result = pdf_utils.read_pdf(file)
226+
"""Test is pdf method returns False for file with invalid extension."""
227+
result = pdf_utils.read_pdf(b"txt_content")
248228

249229
assert result is None
250-
251-
252-
def test_unit_is_pdf_invalid_pdf_without_file_extension():
253-
"""Test is pdf method returns False for file with invalid pdf content without basing on file extension."""
254-
file = shared.Files(content=b"invalid_pdf_content", file_name="uuid1234")
255-
256-
result = pdf_utils.read_pdf(file)
257-
258-
assert result is not None
259230

260231

261232
def test_unit_get_starting_page_number_missing_key():
@@ -365,7 +336,10 @@ def test_unit_get_page_range_returns_valid_range(page_range, expected_result):
365336
assert result == expected_result
366337

367338

368-
async def _request_mock(fails: bool, content: str) -> requests.Response:
339+
async def _request_mock(
340+
async_client: httpx.AsyncClient, # not used by mock
341+
fails: bool,
342+
content: str) -> requests.Response:
369343
response = requests.Response()
370344
response.status_code = 500 if fails else 200
371345
response._content = content.encode()
@@ -376,40 +350,40 @@ async def _request_mock(fails: bool, content: str) -> requests.Response:
376350
("allow_failed", "tasks", "expected_responses"), [
377351
pytest.param(
378352
True, [
379-
_request_mock(fails=False, content="1"),
380-
_request_mock(fails=False, content="2"),
381-
_request_mock(fails=False, content="3"),
382-
_request_mock(fails=False, content="4"),
353+
partial(_request_mock, fails=False, content="1"),
354+
partial(_request_mock, fails=False, content="2"),
355+
partial(_request_mock, fails=False, content="3"),
356+
partial(_request_mock, fails=False, content="4"),
383357
],
384358
["1", "2", "3", "4"],
385359
id="no failures, fails allower"
386360
),
387361
pytest.param(
388362
True, [
389-
_request_mock(fails=False, content="1"),
390-
_request_mock(fails=True, content="2"),
391-
_request_mock(fails=False, content="3"),
392-
_request_mock(fails=True, content="4"),
363+
partial(_request_mock, fails=False, content="1"),
364+
partial(_request_mock, fails=True, content="2"),
365+
partial(_request_mock, fails=False, content="3"),
366+
partial(_request_mock, fails=True, content="4"),
393367
],
394368
["1", "2", "3", "4"],
395369
id="failures, fails allowed"
396370
),
397371
pytest.param(
398372
False, [
399-
_request_mock(fails=True, content="failure"),
400-
_request_mock(fails=False, content="2"),
401-
_request_mock(fails=True, content="failure"),
402-
_request_mock(fails=False, content="4"),
373+
partial(_request_mock, fails=True, content="failure"),
374+
partial(_request_mock, fails=False, content="2"),
375+
partial(_request_mock, fails=True, content="failure"),
376+
partial(_request_mock, fails=False, content="4"),
403377
],
404378
["failure"],
405379
id="failures, fails disallowed"
406380
),
407381
pytest.param(
408382
False, [
409-
_request_mock(fails=False, content="1"),
410-
_request_mock(fails=False, content="2"),
411-
_request_mock(fails=False, content="3"),
412-
_request_mock(fails=False, content="4"),
383+
partial(_request_mock, fails=False, content="1"),
384+
partial(_request_mock, fails=False, content="2"),
385+
partial(_request_mock, fails=False, content="3"),
386+
partial(_request_mock, fails=False, content="4"),
413387
],
414388
["1", "2", "3", "4"],
415389
id="no failures, fails disallowed"
@@ -428,14 +402,18 @@ async def test_unit_disallow_failed_coroutines(
428402
assert response_contents == expected_responses
429403

430404

431-
async def _fetch_canceller_error(fails: bool, content: str, cancelled_counter: Counter):
405+
async def _fetch_canceller_error(
406+
async_client: httpx.AsyncClient, # not used by mock
407+
fails: bool,
408+
content: str,
409+
cancelled_counter: Counter):
432410
try:
433411
if not fails:
434412
await asyncio.sleep(0.01)
435413
print("Doesn't fail")
436414
else:
437415
print("Fails")
438-
return await _request_mock(fails=fails, content=content)
416+
return await _request_mock(async_client=async_client, fails=fails, content=content)
439417
except asyncio.CancelledError:
440418
cancelled_counter.update(["cancelled"])
441419
print(cancelled_counter["cancelled"])
@@ -446,8 +424,8 @@ async def _fetch_canceller_error(fails: bool, content: str, cancelled_counter: C
446424
async def test_remaining_tasks_cancelled_when_fails_disallowed():
447425
cancelled_counter = Counter()
448426
tasks = [
449-
_fetch_canceller_error(fails=True, content="1", cancelled_counter=cancelled_counter),
450-
*[_fetch_canceller_error(fails=False, content=f"{i}", cancelled_counter=cancelled_counter)
427+
partial(_fetch_canceller_error, fails=True, content="1", cancelled_counter=cancelled_counter),
428+
*[partial(_fetch_canceller_error, fails=False, content=f"{i}", cancelled_counter=cancelled_counter)
451429
for i in range(2, 200)],
452430
]
453431

0 commit comments

Comments
 (0)