Skip to content

Commit cc09104

Browse files
use context manager to run the timeout background tasks
1 parent 12762fb commit cc09104

File tree

3 files changed

+76
-46
lines changed

3 files changed

+76
-46
lines changed

elasticsearch/_async/helpers.py

Lines changed: 17 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
Union,
3434
)
3535

36+
from ..compat import safe_task
3637
from ..exceptions import ApiError, NotFoundError, TransportError
3738
from ..helpers.actions import (
3839
_TYPE_BULK_ACTION,
@@ -94,28 +95,24 @@ async def get_items() -> None:
9495
try:
9596
async for item in actions:
9697
await item_queue.put(item)
97-
except Exception:
98+
finally:
9899
await item_queue.put((BulkMeta.done, None))
99-
raise
100-
await item_queue.put((BulkMeta.done, None))
101100

102-
item_getter_job = asyncio.create_task(get_items())
103-
104-
timeout: Optional[float] = flush_after_seconds
105-
while True:
106-
try:
107-
action, data = await asyncio.wait_for(item_queue.get(), timeout=timeout)
108-
timeout = flush_after_seconds
109-
except asyncio.TimeoutError:
110-
action, data = BulkMeta.flush, None
111-
timeout = None
112-
113-
if action is BulkMeta.done:
114-
break
115-
ret = chunker.feed(action, data)
116-
if ret:
117-
yield ret
118-
await item_getter_job
101+
async with safe_task(get_items()):
102+
timeout: Optional[float] = flush_after_seconds
103+
while True:
104+
try:
105+
action, data = await asyncio.wait_for(item_queue.get(), timeout=timeout)
106+
timeout = flush_after_seconds
107+
except asyncio.TimeoutError:
108+
action, data = BulkMeta.flush, None
109+
timeout = None
110+
111+
if action is BulkMeta.done:
112+
break
113+
ret = chunker.feed(action, data)
114+
if ret:
115+
yield ret
119116

120117
ret = chunker.flush()
121118
if ret:

elasticsearch/compat.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,13 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717

18+
import asyncio
19+
from contextlib import contextmanager, asynccontextmanager
1820
import inspect
1921
import os
2022
import sys
2123
from pathlib import Path
24+
from threading import Thread
2225
from typing import Tuple, Type, Union
2326

2427
string_types: Tuple[Type[str], Type[bytes]] = (str, bytes)
@@ -76,9 +79,46 @@ def warn_stacklevel() -> int:
7679
return 0
7780

7881

82+
@contextmanager
83+
def safe_thread(target, *args, **kwargs):
84+
"""Run a thread within a context manager block.
85+
86+
The thread is automatically joined when the block ends. If the thread raised
87+
an exception, it is raised in the caller's context.
88+
"""
89+
captured_exception = None
90+
91+
def run():
92+
try:
93+
target(*args, **kwargs)
94+
except BaseException as exc:
95+
nonlocal captured_exception
96+
captured_exception = exc
97+
98+
thread = Thread(target=run)
99+
thread.start()
100+
yield
101+
thread.join()
102+
if captured_exception:
103+
raise captured_exception
104+
105+
106+
@asynccontextmanager
107+
async def safe_task(coro):
108+
"""Run a background task within a context manager block.
109+
110+
The task is awaited when the block ends.
111+
"""
112+
task = asyncio.create_task(coro)
113+
yield task
114+
await task
115+
116+
79117
__all__ = [
80118
"string_types",
81119
"to_str",
82120
"to_bytes",
83121
"warn_stacklevel",
122+
"safe_thread",
123+
"safe_task",
84124
]

elasticsearch/helpers/actions.py

Lines changed: 19 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
from elastic_transport import OpenTelemetrySpan
4040

4141
from .. import Elasticsearch
42-
from ..compat import to_bytes
42+
from ..compat import to_bytes, safe_thread
4343
from ..exceptions import ApiError, NotFoundError, TransportError
4444
from ..serializer import Serializer
4545
from .errors import BulkIndexError, ScanError
@@ -266,35 +266,28 @@ def _chunk_actions(
266266
)
267267

268268
def get_items() -> None:
269-
ret = None
270269
try:
271270
for item in actions:
272271
item_queue.put(item)
273-
except BaseException as exc:
274-
ret = exc
275-
item_queue.put((BulkMeta.done, ret))
272+
finally:
273+
# make sure we signal the end even if there is an exception
274+
item_queue.put((BulkMeta.done, None))
276275

277-
item_getter_job = Thread(target=get_items)
278-
item_getter_job.start()
279-
280-
timeout: Optional[float] = flush_after_seconds
281-
while True:
282-
try:
283-
action, data = item_queue.get(timeout=timeout)
284-
timeout = flush_after_seconds
285-
except queue.Empty:
286-
action, data = BulkMeta.flush, None
287-
timeout = None
288-
289-
if action is BulkMeta.done:
290-
if isinstance(data, BaseException):
291-
raise data
292-
break
293-
ret = chunker.feed(action, data)
294-
if ret:
295-
yield ret
296-
297-
item_getter_job.join()
276+
with safe_thread(get_items):
277+
timeout: Optional[float] = flush_after_seconds
278+
while True:
279+
try:
280+
action, data = item_queue.get(timeout=timeout)
281+
timeout = flush_after_seconds
282+
except queue.Empty:
283+
action, data = BulkMeta.flush, None
284+
timeout = None
285+
286+
if action is BulkMeta.done:
287+
break
288+
ret = chunker.feed(action, data)
289+
if ret:
290+
yield ret
298291

299292
ret = chunker.flush()
300293
if ret:

0 commit comments

Comments
 (0)