Skip to content

Commit dcfb674

Browse files
authored
fix(DI): Ensure dependencies are cleaned up when exception occurrs during cleanup (#4148)
1 parent 0b867f4 commit dcfb674

File tree

8 files changed

+219
-89
lines changed

8 files changed

+219
-89
lines changed

docs/usage/dependency-injection.rst

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -153,9 +153,9 @@ and committing otherwise.
153153
.. admonition:: Best Practice
154154
:class: tip
155155

156-
You should always wrap `yield` in a `try`/`finally` block, regardless of whether you
157-
want to handle exceptions, to ensure that the cleanup code is run even when exceptions
158-
occurred:
156+
You should always wrap ``yield`` in a ``try``/``finally`` block, regardless of
157+
whether you want to handle exceptions, to ensure that the cleanup code is run even
158+
when exceptions occurred:
159159

160160
.. code-block:: python
161161
@@ -168,9 +168,19 @@ and committing otherwise.
168168
169169
.. attention::
170170

171-
Do not re-raise exceptions within the dependency. Exceptions caught within these
172-
dependencies will still be handled by the regular mechanisms without an explicit
173-
re-raise
171+
Do not re-raise exceptions within the dependency. Exceptions caught within these
172+
dependencies will still be handled by the regular mechanisms without an explicit
173+
re-raise
174+
175+
176+
.. important::
177+
178+
Exceptions raised during the cleanup step of a dependency will be re-raised in an
179+
:exc:`ExceptionGroup` (for Python versions < 3.11, the
180+
`exceptiongroup <https://github.com/agronholm/exceptiongroup>`_ will be used). This
181+
happens after all dependencies have been cleaned up, so exceptions raised during
182+
cleanup of one dependencies do not affect the cleanup of other dependencies.
183+
174184

175185

176186
Dependency keyword arguments

litestar/_kwargs/cleanup.py

Lines changed: 35 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
11
from __future__ import annotations
22

3-
from inspect import Traceback, isasyncgen
3+
import sys
4+
from contextlib import AbstractAsyncContextManager
5+
from inspect import isasyncgen
46
from typing import TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable, Generator
57

8+
if sys.version_info < (3, 11):
9+
from exceptiongroup import ExceptionGroup
10+
611
from anyio import create_task_group
712

813
from litestar.utils import ensure_async_callable
@@ -12,10 +17,12 @@
1217

1318

1419
if TYPE_CHECKING:
20+
from types import TracebackType
21+
1522
from litestar.types import AnyGenerator
1623

1724

18-
class DependencyCleanupGroup:
25+
class DependencyCleanupGroup(AbstractAsyncContextManager):
1926
"""Wrapper for generator based dependencies.
2027
2128
Simplify cleanup by wrapping :func:`next` / :func:`anext` calls and providing facilities to
@@ -24,8 +31,6 @@ class DependencyCleanupGroup:
2431
exceptions caught in this manner will be re-raised after they have been thrown in the generators.
2532
"""
2633

27-
__slots__ = ("_closed", "_generators")
28-
2934
def __init__(self, generators: list[AnyGenerator] | None = None) -> None:
3035
"""Initialize ``DependencyCleanupGroup``.
3136
@@ -45,7 +50,7 @@ def add(self, generator: Generator[Any, None, None] | AsyncGenerator[Any, None])
4550
None
4651
"""
4752
if self._closed:
48-
raise RuntimeError("Cannot call cleanup on a closed DependencyCleanupGroup")
53+
raise RuntimeError("Cannot call .add on a closed DependencyCleanupGroup")
4954
self._generators.append(generator)
5055

5156
@staticmethod
@@ -62,18 +67,25 @@ def wrapped() -> None:
6267

6368
return ensure_async_callable(wrapped)
6469

65-
async def cleanup(self) -> None:
70+
async def close(self, exc: BaseException | None = None) -> None:
71+
if self._closed:
72+
raise RuntimeError("Cannot call cleanup on a closed DependencyCleanupGroup")
73+
74+
self._closed = True
75+
76+
if exc is None:
77+
await self._cleanup()
78+
else:
79+
await self._throw(exc)
80+
81+
async def _cleanup(self) -> None:
6682
"""Execute cleanup by calling :func:`next` / :func:`anext` on all generators.
6783
6884
If there are multiple generators to be called, they will be executed in a :class:`anyio.TaskGroup`.
6985
7086
Returns:
7187
None
7288
"""
73-
if self._closed:
74-
raise RuntimeError("Cannot call cleanup on a closed DependencyCleanupGroup")
75-
76-
self._closed = True
7789

7890
if not self._generators:
7991
return
@@ -95,18 +107,18 @@ async def __aexit__(
95107
self,
96108
exc_type: type[BaseException] | None,
97109
exc_val: BaseException | None,
98-
exc_tb: Traceback | None,
110+
exc_tb: TracebackType | None,
99111
) -> None:
100112
"""If an exception was raised within the contextmanager block, throw it into all generators."""
101-
if exc_val:
102-
await self.throw(exc_val)
113+
await self.close(exc_val)
103114

104-
async def throw(self, exc: BaseException) -> None:
115+
async def _throw(self, exc: BaseException) -> None:
105116
"""Throw an exception in all generators sequentially.
106117
107118
Args:
108119
exc: Exception to throw
109120
"""
121+
exceptions = []
110122
for gen in self._generators:
111123
try:
112124
if isasyncgen(gen):
@@ -115,3 +127,12 @@ async def throw(self, exc: BaseException) -> None:
115127
gen.throw(exc) # type: ignore[union-attr]
116128
except (StopIteration, StopAsyncIteration):
117129
continue
130+
except Exception as cleanup_exc: # noqa: BLE001
131+
if cleanup_exc is not exc:
132+
exceptions.append(cleanup_exc)
133+
134+
if exceptions:
135+
raise ExceptionGroup(
136+
"Exceptions occurred during cleanup of dependencies",
137+
exceptions,
138+
) from exc

litestar/routes/http.py

Lines changed: 27 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import contextlib
34
from itertools import chain
45
from typing import TYPE_CHECKING, Any
56

@@ -17,7 +18,6 @@
1718

1819
if TYPE_CHECKING:
1920
from litestar._kwargs import KwargsModel
20-
from litestar._kwargs.cleanup import DependencyCleanupGroup
2121
from litestar.connection import Request
2222
from litestar.types import ASGIApp, HTTPScope, Method, Receive, Scope, Send
2323

@@ -134,72 +134,54 @@ async def _get_response_for_request(
134134
scope=scope, request=request, parameter_model=parameter_model, route_handler=route_handler
135135
)
136136

137-
async def _call_handler_function(
137+
async def _call_handler_function( # type: ignore[return]
138138
self, scope: Scope, request: Request, parameter_model: KwargsModel, route_handler: HTTPRouteHandler
139-
) -> ASGIApp:
139+
) -> ASGIApp: # pyright: ignore[reportGeneralTypeIssues]
140140
"""Call the before request handlers, retrieve any data required for the route handler, and call the route
141141
handler's ``to_response`` method.
142142
143143
This is wrapped in a try except block - and if an exception is raised,
144144
it tries to pass it to an appropriate exception handler - if defined.
145145
"""
146146
response_data: Any = None
147-
cleanup_group: DependencyCleanupGroup | None = None
148147

149148
if before_request_handler := route_handler.resolve_before_request():
150149
response_data = await before_request_handler(request)
151150

152-
if not response_data:
153-
response_data, cleanup_group = await self._get_response_data(
154-
route_handler=route_handler, parameter_model=parameter_model, request=request
155-
)
151+
# create and enter an AsyncExit stack as we may or may not have a
152+
# 'DependencyCleanupGroup' to enter and exit
153+
stack = contextlib.AsyncExitStack()
156154

157-
response: ASGIApp = await route_handler.to_response(
158-
app=scope["litestar_app"], data=response_data, request=request
159-
)
155+
# mypy cannot infer that 'stack' never swallows exceptions, therefore it thinks
156+
# this method is potentially missing a 'return' statement
157+
async with stack:
158+
if not response_data:
159+
parsed_kwargs: dict[str, Any] = {}
160160

161-
if cleanup_group:
162-
await cleanup_group.cleanup()
161+
if parameter_model.has_kwargs and route_handler.signature_model:
162+
try:
163+
kwargs = await parameter_model.to_kwargs(connection=request)
164+
except SerializationException as e:
165+
raise ClientException(str(e)) from e
163166

164-
return response
167+
if kwargs.get("data") is Empty:
168+
del kwargs["data"]
165169

166-
@staticmethod
167-
async def _get_response_data(
168-
route_handler: HTTPRouteHandler, parameter_model: KwargsModel, request: Request
169-
) -> tuple[Any, DependencyCleanupGroup | None]:
170-
"""Determine what kwargs are required for the given route handler's ``fn`` and calls it."""
171-
parsed_kwargs: dict[str, Any] = {}
172-
cleanup_group: DependencyCleanupGroup | None = None
173-
174-
if parameter_model.has_kwargs and route_handler.signature_model:
175-
try:
176-
kwargs = await parameter_model.to_kwargs(connection=request)
177-
except SerializationException as e:
178-
raise ClientException(str(e)) from e
179-
180-
if kwargs.get("data") is Empty:
181-
del kwargs["data"]
182-
183-
if parameter_model.dependency_batches:
184-
cleanup_group = await parameter_model.resolve_dependencies(request, kwargs)
185-
186-
parsed_kwargs = route_handler.signature_model.parse_values_from_connection_kwargs(
187-
connection=request, kwargs=kwargs
188-
)
170+
if parameter_model.dependency_batches:
171+
cleanup_group = await parameter_model.resolve_dependencies(request, kwargs)
172+
await stack.enter_async_context(cleanup_group)
173+
174+
parsed_kwargs = route_handler.signature_model.parse_values_from_connection_kwargs(
175+
connection=request, kwargs=kwargs
176+
)
189177

190-
if cleanup_group:
191-
async with cleanup_group:
192-
data = (
178+
response_data = (
193179
route_handler.fn(**parsed_kwargs)
194180
if route_handler.has_sync_callable
195181
else await route_handler.fn(**parsed_kwargs)
196182
)
197-
elif route_handler.has_sync_callable:
198-
data = route_handler.fn(**parsed_kwargs)
199-
else:
200-
data = await route_handler.fn(**parsed_kwargs)
201183

202-
return data, cleanup_group
184+
return await route_handler.to_response(app=scope["litestar_app"], data=response_data, request=request)
203185

204186
@staticmethod
205187
async def _get_cached_response(request: Request, route_handler: HTTPRouteHandler) -> ASGIApp | None:

litestar/routes/websocket.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,5 @@ async def handle(self, scope: WebSocketScope, receive: Receive, send: Send) -> N
8181
if cleanup_group:
8282
async with cleanup_group:
8383
await self.route_handler.fn(**parsed_kwargs)
84-
await cleanup_group.cleanup()
8584
else:
8685
await self.route_handler.fn(**parsed_kwargs)

pyproject.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ dependencies = [
4646
"rich>=13.0.0",
4747
"rich-click",
4848
"multipart>=1.2.0",
49+
"exceptiongroup>=1.2.2; python_version < \"3.11\"",
4950
# default litestar plugins
5051
"litestar-htmx>=0.4.0",
5152
]
@@ -397,8 +398,6 @@ lint.select = [
397398
"W", # pycodestyle - warning
398399
"YTT", # flake8-2020
399400
]
400-
401-
line-length = 120
402401
lint.ignore = [
403402
"A003", # flake8-builtins - class attribute {name} is shadowing a python builtin
404403
"B010", # flake8-bugbear - do not call setattr with a constant attribute value
@@ -419,6 +418,8 @@ lint.ignore = [
419418
"ISC001", # Ruff formatter incompatible
420419
"CPY001", # ruff - copyright notice at the top of the file
421420
]
421+
422+
line-length = 120
422423
src = ["litestar", "tests", "docs/examples"]
423424
target-version = "py38"
424425

0 commit comments

Comments
 (0)