
Commit 14e561e

feat NEXUS-703: simplified asyncio operations
1 parent 130d960 commit 14e561e

File tree: 1 file changed, +11 −31 lines

src/unstructured_client/_hooks/custom/split_pdf_hook.py
Lines changed: 11 additions & 31 deletions
@@ -57,33 +57,10 @@
 HI_RES_STRATEGY = 'hi_res'
 MAX_PAGE_LENGTH = 4000

-def _get_asyncio_loop() -> asyncio.AbstractEventLoop:
-    if sys.version_info < (3, 10):
-        try:
-            loop = asyncio.get_event_loop()
-        except RuntimeError:
-            loop = asyncio.new_event_loop()
-            asyncio.set_event_loop(loop)
-    else:
-        try:
-            loop = asyncio.get_running_loop()
-        except RuntimeError:
-            loop = asyncio.new_event_loop()
-            asyncio.set_event_loop(loop)
-    return loop
-
 def _run_coroutines_in_separate_thread(
     coroutines_task: Coroutine[Any, Any, list[tuple[int, httpx.Response]]],
 ) -> list[tuple[int, httpx.Response]]:
-    loop = _get_asyncio_loop()
-    return loop.run_until_complete(coroutines_task)
-
-def _get_limiter(concurrency_level: int, executor: futures.ThreadPoolExecutor) -> asyncio.Semaphore:
-    def _setup_limiter_in_thread_loop():
-        _get_asyncio_loop()
-        return asyncio.Semaphore(concurrency_level)
-    return executor.submit(_setup_limiter_in_thread_loop).result()
-
+    return asyncio.run(coroutines_task)


 async def _order_keeper(index: int, coro: Awaitable) -> Tuple[int, httpx.Response]:
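The deleted helpers managed the event loop by hand (with separate paths for Python < 3.10) and built the semaphore on that loop; asyncio.run() now creates a fresh loop, runs the batch to completion, and closes the loop each time the worker thread calls _run_coroutines_in_separate_thread. A minimal sketch of that pattern, with _demo_task and run_in_thread as hypothetical names:

import asyncio
from concurrent import futures


async def _demo_task() -> str:
    await asyncio.sleep(0)
    return "done"


def run_in_thread() -> str:
    # asyncio.run() creates a new event loop, runs the coroutine to
    # completion, and closes the loop, so no per-version get/new/set
    # loop bookkeeping is needed.
    return asyncio.run(_demo_task())


with futures.ThreadPoolExecutor(max_workers=1) as executor:
    print(executor.submit(run_in_thread).result())  # -> done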
@@ -93,7 +70,8 @@ async def _order_keeper(index: int, coro: Awaitable) -> Tuple[int, httpx.Response]:


 async def run_tasks(
     coroutines: list[partial[Coroutine[Any, Any, httpx.Response]]],
-    allow_failed: bool = False
+    allow_failed: bool = False,
+    concurrency_level: int = 10,
 ) -> list[tuple[int, httpx.Response]]:
     """Run a list of coroutines in parallel and return the results in order.

@@ -109,13 +87,14 @@ async def run_tasks(
     # Use a variable to adjust the httpx client timeout, or default to 30 minutes
     # When we're able to reuse the SDK to make these calls, we can remove this var
     # The SDK timeout will be controlled by parameter
+    limiter = asyncio.Semaphore(concurrency_level)
     client_timeout_minutes = 60
     if timeout_var := os.getenv("UNSTRUCTURED_CLIENT_TIMEOUT_MINUTES"):
         client_timeout_minutes = int(timeout_var)
     client_timeout = httpx.Timeout(60 * client_timeout_minutes)

     async with httpx.AsyncClient(timeout=client_timeout) as client:
-        armed_coroutines = [coro(async_client=client) for coro in coroutines]  # type: ignore
+        armed_coroutines = [coro(async_client=client, limiter=limiter) for coro in coroutines]  # type: ignore
         if allow_failed:
             responses = await asyncio.gather(*armed_coroutines, return_exceptions=False)
             return list(enumerate(responses, 1))
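Creating the asyncio.Semaphore inside run_tasks ties it to the event loop that asyncio.run() starts, instead of building it ahead of time on a loop prepared in the executor thread. A sketch of the gating this enables, assuming each armed coroutine acquires the limiter around its request (the body of call_api_partial is not shown in this diff):

import asyncio

import httpx


async def fetch(client: httpx.AsyncClient, limiter: asyncio.Semaphore, url: str) -> httpx.Response:
    async with limiter:  # at most concurrency_level requests in flight
        return await client.get(url)


async def run_all(urls: list[str], concurrency_level: int = 10) -> list[httpx.Response]:
    limiter = asyncio.Semaphore(concurrency_level)  # created on the running loop
    async with httpx.AsyncClient() as client:
        return await asyncio.gather(*(fetch(client, limiter, u) for u in urls))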
@@ -192,6 +171,7 @@ def __init__(self) -> None:
         self.coroutines_to_execute: dict[
             str, list[partial[Coroutine[Any, Any, httpx.Response]]]
         ] = {}
+        self.concurrency_level: dict[str, int] = {}
         self.api_successful_responses: dict[str, list[httpx.Response]] = {}
         self.api_failed_responses: dict[str, list[httpx.Response]] = {}
         self.executors: dict[str, futures.ThreadPoolExecutor] = {}
@@ -341,7 +321,7 @@ def before_request(
             fallback_value=DEFAULT_ALLOW_FAILED,
         )

-        concurrency_level = form_utils.get_split_pdf_concurrency_level_param(
+        self.concurrency_level[operation_id] = form_utils.get_split_pdf_concurrency_level_param(
             form_data,
             key=PARTITION_FORM_CONCURRENCY_LEVEL_KEY,
             fallback_value=DEFAULT_CONCURRENCY_LEVEL,
@@ -350,7 +330,6 @@ def before_request(

         executor = futures.ThreadPoolExecutor(max_workers=1)
         self.executors[operation_id] = executor
-        limiter = _get_limiter(concurrency_level, executor)

         self.cache_tmp_data_feature = form_utils.get_split_pdf_cache_tmp_data(
             form_data,
@@ -373,7 +352,7 @@ def before_request(
         page_count = page_range_end - page_range_start + 1

         split_size = get_optimal_split_size(
-            num_pages=page_count, concurrency_level=concurrency_level
+            num_pages=page_count, concurrency_level=self.concurrency_level[operation_id]
         )

         # If the doc is small enough, and we aren't slicing it with a page range:
@@ -416,7 +395,6 @@ def before_request(
             # in `after_success`.
             coroutine = partial(
                 self.call_api_partial,
-                limiter=limiter,
                 operation_id=operation_id,
                 pdf_chunk_request=pdf_chunk_request,
                 pdf_chunk_file=pdf_chunk_file,
@@ -634,7 +612,8 @@ def _await_elements(self, operation_id: str) -> Optional[list]:
         if tasks is None:
             return None

-        coroutines = run_tasks(tasks, allow_failed=self.allow_failed)
+        concurrency_level = self.concurrency_level.get(operation_id)
+        coroutines = run_tasks(tasks, allow_failed=self.allow_failed, concurrency_level=concurrency_level)

         # sending the coroutines to a separate thread to avoid blocking the current event loop
         # this operation should be removed when the SDK is updated to support async hooks
@@ -749,6 +728,7 @@ def _clear_operation(self, operation_id: str) -> None:
         """
         self.coroutines_to_execute.pop(operation_id, None)
         self.api_successful_responses.pop(operation_id, None)
+        self.concurrency_level.pop(operation_id, None)
         executor = self.executors.pop(operation_id, None)
         if executor is not None:
             executor.shutdown(wait=True)
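The concurrency level is now per-operation state: before_request stores the parsed value keyed by operation_id, _await_elements reads it back when the coroutines are launched, and _clear_operation drops it when the request finishes. A hypothetical sketch of that keyed-state lifecycle (DEFAULT_CONCURRENCY_LEVEL stands in for the form_utils fallback; the class and method names here are illustrative only):

from typing import Optional

DEFAULT_CONCURRENCY_LEVEL = 10  # assumed fallback value


class ConcurrencyState:
    def __init__(self) -> None:
        self.concurrency_level: dict[str, int] = {}

    def before_request(self, operation_id: str, requested: Optional[int]) -> None:
        # Store the value parsed from the request form data.
        self.concurrency_level[operation_id] = requested or DEFAULT_CONCURRENCY_LEVEL

    def await_elements(self, operation_id: str) -> int:
        # Read it back when the coroutines are scheduled.
        return self.concurrency_level.get(operation_id, DEFAULT_CONCURRENCY_LEVEL)

    def clear_operation(self, operation_id: str) -> None:
        # Remove the entry so per-request state does not accumulate.
        self.concurrency_level.pop(operation_id, None)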
