|
1 | 1 | import builtins
|
2 | 2 | import contextlib
|
3 | 3 | import functools
|
| 4 | +import inspect |
4 | 5 | import logging
|
5 | 6 | from asyncio import iscoroutinefunction
|
6 | 7 | from collections.abc import AsyncGenerator, Callable, Iterator
|
@@ -60,30 +61,59 @@ def omit_exception(method: Callable | None = None, return_value: Any | None = No
|
60 | 61 | if method is None:
|
61 | 62 | return functools.partial(omit_exception, return_value=return_value)
|
62 | 63 |
|
| 64 | + def __handle_error(self, e) -> Any | None: |
| 65 | + if getattr(self, "_ignore_exceptions", None): |
| 66 | + if getattr(self, "_log_ignored_exceptions", None): |
| 67 | + self.logger.exception("Exception ignored") |
| 68 | + |
| 69 | + return return_value |
| 70 | + raise e.__cause__ |
| 71 | + |
63 | 72 | @functools.wraps(method)
|
64 | 73 | def _decorator(self, *args, **kwargs):
|
65 | 74 | try:
|
66 | 75 | return method(self, *args, **kwargs)
|
67 | 76 | except ConnectionInterrupted as e:
|
68 |
| - if self._ignore_exceptions: |
69 |
| - if self._log_ignored_exceptions: |
70 |
| - self.logger.exception("Exception ignored") |
| 77 | + return __handle_error(self, e) |
71 | 78 |
|
72 |
| - return return_value |
73 |
| - raise e.__cause__ # noqa: B904 |
| 79 | + @functools.wraps(method) |
| 80 | + def _generator_decorator(self, *args, **kwargs): |
| 81 | + try: |
| 82 | + for item in method(self, *args, **kwargs): |
| 83 | + yield item |
| 84 | + except ConnectionInterrupted as e: |
| 85 | + yield __handle_error(self, e) |
74 | 86 |
|
75 | 87 | @functools.wraps(method)
|
76 | 88 | async def _async_decorator(self, *args, **kwargs):
|
77 | 89 | try:
|
78 | 90 | return await method(self, *args, **kwargs)
|
79 | 91 | except ConnectionInterrupted as e:
|
80 |
| - if self._ignore_exceptions: |
81 |
| - if self._log_ignored_exceptions: |
82 |
| - self.logger.exception("Exception ignored") |
83 |
| - return return_value |
84 |
| - raise e.__cause__ |
| 92 | + return __handle_error(self, e) |
| 93 | + |
| 94 | + @functools.wraps(method) |
| 95 | + async def _async_generator_decorator(self, *args, **kwargs): |
| 96 | + try: |
| 97 | + async for item in method(self, *args, **kwargs): |
| 98 | + yield item |
| 99 | + except ConnectionInterrupted as e: |
| 100 | + yield __handle_error(self, e) |
| 101 | + |
| 102 | + # inspect.isfunction returns true for generators, so can't use that to check this |
| 103 | + if not inspect.isasyncgenfunction(method) and not inspect.isgeneratorfunction( |
| 104 | + method |
| 105 | + ): |
| 106 | + wrapper = _async_decorator if iscoroutinefunction(method) else _decorator |
| 107 | + |
| 108 | + # if method is a generator or async generator, it should be iterated over by this decorator |
| 109 | + # generators don't error by simply being called, they need to be iterated over. |
| 110 | + else: |
| 111 | + wrapper = ( |
| 112 | + _async_generator_decorator |
| 113 | + if inspect.isasyncgenfunction(method) |
| 114 | + else _generator_decorator |
| 115 | + ) |
85 | 116 |
|
86 |
| - wrapper = _async_decorator if iscoroutinefunction(method) else _decorator |
87 | 117 | wrapper.original = method
|
88 | 118 | return wrapper
|
89 | 119 |
|
|
0 commit comments