Skip to content

Commit 4a0ba88

Browse files
authored
feat: add disable_di flag to disable built in framework DI (#222)
Adds `disable_di` boolean flag to Litestar, Starlette/FastAPI, and Flask extensions. When enabled, disables built-in dependency injection to allow users to manage database lifecycle with their own DI solutions (e.g., Dishka).
1 parent d59d561 commit 4a0ba88

File tree

17 files changed

+419
-12
lines changed

17 files changed

+419
-12
lines changed

AGENTS.md

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2616,6 +2616,138 @@ class StarletteConfig(TypedDict):
26162616
extra_rollback_statuses: NotRequired[set[int]]
26172617
```
26182618

2619+
### Disabling Built-in Dependency Injection (disable_di Pattern)
2620+
2621+
**When to Use**: When users want to integrate SQLSpec with their own dependency injection solution (e.g., Dishka, dependency-injector) and need full control over database lifecycle management.
2622+
2623+
**Pattern**: Add a `disable_di` boolean flag to framework extension configuration that conditionally skips the built-in DI setup.
2624+
2625+
**Implementation Steps**:
2626+
2627+
1. **Add to TypedDict in `sqlspec/config.py`**:
2628+
2629+
```python
2630+
class StarletteConfig(TypedDict):
2631+
# ... existing fields ...
2632+
2633+
disable_di: NotRequired[bool]
2634+
"""Disable built-in dependency injection. Default: False.
2635+
When True, the Starlette/FastAPI extension will not add middleware for managing
2636+
database connections and sessions. Users are responsible for managing the
2637+
database lifecycle manually via their own DI solution.
2638+
"""
2639+
```
2640+
2641+
2. **Add to Configuration State Dataclass**:
2642+
2643+
```python
2644+
@dataclass
2645+
class SQLSpecConfigState:
2646+
config: "DatabaseConfigProtocol[Any, Any, Any]"
2647+
connection_key: str
2648+
pool_key: str
2649+
session_key: str
2650+
commit_mode: CommitMode
2651+
extra_commit_statuses: "set[int] | None"
2652+
extra_rollback_statuses: "set[int] | None"
2653+
disable_di: bool # Add this field
2654+
```
2655+
2656+
3. **Extract from Config and Default to False**:
2657+
2658+
```python
2659+
def _extract_starlette_settings(self, config):
2660+
starlette_config = config.extension_config.get("starlette", {})
2661+
return {
2662+
# ... existing keys ...
2663+
"disable_di": starlette_config.get("disable_di", False), # Default False
2664+
}
2665+
```
2666+
2667+
4. **Conditionally Skip DI Setup**:
2668+
2669+
**Middleware-based (Starlette/FastAPI)**:
2670+
```python
2671+
def init_app(self, app):
2672+
# ... lifespan setup ...
2673+
2674+
for config_state in self._config_states:
2675+
if not config_state.disable_di: # Only add if DI enabled
2676+
self._add_middleware(app, config_state)
2677+
```
2678+
2679+
**Provider-based (Litestar)**:
2680+
```python
2681+
def on_app_init(self, app_config):
2682+
for state in self._plugin_configs:
2683+
# ... signature namespace ...
2684+
2685+
if not state.disable_di: # Only register if DI enabled
2686+
app_config.before_send.append(state.before_send_handler)
2687+
app_config.lifespan.append(state.lifespan_handler)
2688+
app_config.dependencies.update({
2689+
state.connection_key: Provide(state.connection_provider),
2690+
state.pool_key: Provide(state.pool_provider),
2691+
state.session_key: Provide(state.session_provider),
2692+
})
2693+
```
2694+
2695+
**Hook-based (Flask)**:
2696+
```python
2697+
def init_app(self, app):
2698+
# ... pool setup ...
2699+
2700+
# Only register hooks if at least one config has DI enabled
2701+
if any(not state.disable_di for state in self._config_states):
2702+
app.before_request(self._before_request_handler)
2703+
app.after_request(self._after_request_handler)
2704+
app.teardown_appcontext(self._teardown_appcontext_handler)
2705+
2706+
def _before_request_handler(self):
2707+
for config_state in self._config_states:
2708+
if config_state.disable_di: # Skip if DI disabled
2709+
continue
2710+
# ... connection setup ...
2711+
```
2712+
2713+
**Testing Requirements**:
2714+
2715+
1. **Test with `disable_di=True`**: Verify DI mechanisms are not active
2716+
2. **Test default behavior**: Verify `disable_di=False` preserves existing functionality
2717+
3. **Integration tests**: Demonstrate manual DI setup works correctly
2718+
2719+
**Example Usage**:
2720+
2721+
```python
2722+
from sqlspec.adapters.asyncpg import AsyncpgConfig
2723+
from sqlspec.base import SQLSpec
2724+
from sqlspec.extensions.starlette import SQLSpecPlugin
2725+
2726+
sql = SQLSpec()
2727+
config = AsyncpgConfig(
2728+
pool_config={"dsn": "postgresql://localhost/db"},
2729+
extension_config={"starlette": {"disable_di": True}} # Disable built-in DI
2730+
)
2731+
sql.add_config(config)
2732+
plugin = SQLSpecPlugin(sql)
2733+
2734+
# User is now responsible for manual lifecycle management
2735+
async def my_route(request):
2736+
pool = await config.create_pool()
2737+
async with config.provide_connection(pool) as connection:
2738+
session = config.driver_type(connection=connection, statement_config=config.statement_config)
2739+
result = await session.execute("SELECT 1")
2740+
await config.close_pool()
2741+
return result
2742+
```
2743+
2744+
**Key Principles**:
2745+
2746+
- **Backward Compatible**: Default `False` preserves existing behavior
2747+
- **Consistent Naming**: Use `disable_di` across all frameworks
2748+
- **Clear Documentation**: Warn users they are responsible for lifecycle management
2749+
- **Complete Control**: When disabled, extension does zero automatic DI
2750+
26192751
### Multi-Database Support
26202752

26212753
**Key validation ensures unique state keys**:

sqlspec/config.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,13 @@ class FlaskConfig(TypedDict):
142142
extra_rollback_statuses: NotRequired[set[int]]
143143
"""Additional HTTP status codes that trigger rollback. Default: None."""
144144

145+
disable_di: NotRequired[bool]
146+
"""Disable built-in dependency injection. Default: False.
147+
When True, the Flask extension will not register request hooks for managing
148+
database connections and sessions. Users are responsible for managing the
149+
database lifecycle manually via their own DI solution.
150+
"""
151+
145152

146153
class LitestarConfig(TypedDict):
147154
"""Configuration options for Litestar SQLSpec plugin.
@@ -170,6 +177,13 @@ class LitestarConfig(TypedDict):
170177
extra_rollback_statuses: NotRequired[set[int]]
171178
"""Additional HTTP status codes that trigger rollback. Default: set()"""
172179

180+
disable_di: NotRequired[bool]
181+
"""Disable built-in dependency injection. Default: False.
182+
When True, the Litestar plugin will not register dependency providers for managing
183+
database connections, pools, and sessions. Users are responsible for managing the
184+
database lifecycle manually via their own DI solution.
185+
"""
186+
173187

174188
class StarletteConfig(TypedDict):
175189
"""Configuration options for Starlette and FastAPI extensions.
@@ -225,6 +239,13 @@ class StarletteConfig(TypedDict):
225239
extra_rollback_statuses={409}
226240
"""
227241

242+
disable_di: NotRequired[bool]
243+
"""Disable built-in dependency injection. Default: False.
244+
When True, the Starlette/FastAPI extension will not add middleware for managing
245+
database connections and sessions. Users are responsible for managing the
246+
database lifecycle manually via their own DI solution.
247+
"""
248+
228249

229250
class FastAPIConfig(StarletteConfig):
230251
"""Configuration options for FastAPI SQLSpec extension.

sqlspec/extensions/flask/_state.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ class FlaskConfigState:
2727
extra_commit_statuses: "set[int] | None"
2828
extra_rollback_statuses: "set[int] | None"
2929
is_async: bool
30+
disable_di: bool
3031

3132
def should_commit(self, status_code: int) -> bool:
3233
"""Determine if HTTP status code should trigger commit.

sqlspec/extensions/flask/extension.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ def _create_config_state(self, config: Any) -> FlaskConfigState:
9696
commit_mode = flask_config.get("commit_mode", DEFAULT_COMMIT_MODE)
9797
extra_commit_statuses = flask_config.get("extra_commit_statuses")
9898
extra_rollback_statuses = flask_config.get("extra_rollback_statuses")
99+
disable_di = flask_config.get("disable_di", False)
99100

100101
is_async = isinstance(config, (AsyncDatabaseConfig, NoPoolAsyncConfig))
101102

@@ -107,6 +108,7 @@ def _create_config_state(self, config: Any) -> FlaskConfigState:
107108
extra_commit_statuses=extra_commit_statuses,
108109
extra_rollback_statuses=extra_rollback_statuses,
109110
is_async=is_async,
111+
disable_di=disable_di,
110112
)
111113

112114
def init_app(self, app: "Flask") -> None:
@@ -143,9 +145,11 @@ def init_app(self, app: "Flask") -> None:
143145

144146
app.extensions["sqlspec"] = {"plugin": self, "pools": pools}
145147

146-
app.before_request(self._before_request_handler)
147-
app.after_request(self._after_request_handler)
148-
app.teardown_appcontext(self._teardown_appcontext_handler)
148+
if any(not state.disable_di for state in self._config_states):
149+
app.before_request(self._before_request_handler)
150+
app.after_request(self._after_request_handler)
151+
app.teardown_appcontext(self._teardown_appcontext_handler)
152+
149153
self._register_shutdown_hook()
150154

151155
logger.debug("SQLSpec Flask extension initialized")
@@ -186,6 +190,9 @@ def _before_request_handler(self) -> None:
186190
from flask import current_app, g
187191

188192
for config_state in self._config_states:
193+
if config_state.disable_di:
194+
continue
195+
189196
if config_state.config.supports_connection_pooling:
190197
pool = current_app.extensions["sqlspec"]["pools"][config_state.session_key]
191198
conn_ctx = config_state.config.provide_connection(pool)
@@ -215,6 +222,9 @@ def _after_request_handler(self, response: "Response") -> "Response":
215222
from flask import g
216223

217224
for config_state in self._config_states:
225+
if config_state.disable_di:
226+
continue
227+
218228
if config_state.commit_mode == "manual":
219229
continue
220230

@@ -242,6 +252,9 @@ def _teardown_appcontext_handler(self, _exc: "BaseException | None" = None) -> N
242252
from flask import g
243253

244254
for config_state in self._config_states:
255+
if config_state.disable_di:
256+
continue
257+
245258
connection = getattr(g, config_state.connection_key, None)
246259
ctx_key = f"{config_state.connection_key}_ctx"
247260
conn_ctx = getattr(g, ctx_key, None)

sqlspec/extensions/litestar/plugin.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ class _PluginConfigState:
7272
extra_commit_statuses: "set[int] | None"
7373
extra_rollback_statuses: "set[int] | None"
7474
enable_correlation_middleware: bool
75+
disable_di: bool
7576
connection_provider: "Callable[[State, Scope], AsyncGenerator[Any, None]]" = field(init=False)
7677
pool_provider: "Callable[[State, Scope], Any]" = field(init=False)
7778
session_provider: "Callable[..., AsyncGenerator[Any, None]]" = field(init=False)
@@ -157,6 +158,7 @@ def _extract_litestar_settings(
157158
"extra_commit_statuses": litestar_config.get("extra_commit_statuses"),
158159
"extra_rollback_statuses": litestar_config.get("extra_rollback_statuses"),
159160
"enable_correlation_middleware": litestar_config.get("enable_correlation_middleware", True),
161+
"disable_di": litestar_config.get("disable_di", False),
160162
}
161163

162164
def _create_config_state(
@@ -174,9 +176,11 @@ def _create_config_state(
174176
extra_commit_statuses=settings.get("extra_commit_statuses"),
175177
extra_rollback_statuses=settings.get("extra_rollback_statuses"),
176178
enable_correlation_middleware=settings["enable_correlation_middleware"],
179+
disable_di=settings["disable_di"],
177180
)
178181

179-
self._setup_handlers(state)
182+
if not state.disable_di:
183+
self._setup_handlers(state)
180184
return state
181185

182186
def _setup_handlers(self, state: _PluginConfigState) -> None:
@@ -256,13 +260,14 @@ def store_sqlspec_in_state() -> None:
256260

257261
signature_namespace.update(state.config.get_signature_namespace()) # type: ignore[arg-type]
258262

259-
app_config.before_send.append(state.before_send_handler)
260-
app_config.lifespan.append(state.lifespan_handler)
261-
app_config.dependencies.update({
262-
state.connection_key: Provide(state.connection_provider),
263-
state.pool_key: Provide(state.pool_provider),
264-
state.session_key: Provide(state.session_provider),
265-
})
263+
if not state.disable_di:
264+
app_config.before_send.append(state.before_send_handler)
265+
app_config.lifespan.append(state.lifespan_handler)
266+
app_config.dependencies.update({
267+
state.connection_key: Provide(state.connection_provider),
268+
state.pool_key: Provide(state.pool_provider),
269+
state.session_key: Provide(state.session_provider),
270+
})
266271

267272
if signature_namespace:
268273
app_config.signature_namespace.update(signature_namespace)

sqlspec/extensions/starlette/_state.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,3 +23,4 @@ class SQLSpecConfigState:
2323
commit_mode: CommitMode
2424
extra_commit_statuses: "set[int] | None"
2525
extra_rollback_statuses: "set[int] | None"
26+
disable_di: bool

sqlspec/extensions/starlette/extension.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ def _extract_starlette_settings(self, config: Any) -> "dict[str, Any]":
104104
"commit_mode": commit_mode,
105105
"extra_commit_statuses": starlette_config.get("extra_commit_statuses"),
106106
"extra_rollback_statuses": starlette_config.get("extra_rollback_statuses"),
107+
"disable_di": starlette_config.get("disable_di", False),
107108
}
108109

109110
def _create_config_state(self, config: Any, settings: "dict[str, Any]") -> SQLSpecConfigState:
@@ -124,6 +125,7 @@ def _create_config_state(self, config: Any, settings: "dict[str, Any]") -> SQLSp
124125
commit_mode=settings["commit_mode"],
125126
extra_commit_statuses=settings["extra_commit_statuses"],
126127
extra_rollback_statuses=settings["extra_rollback_statuses"],
128+
disable_di=settings["disable_di"],
127129
)
128130

129131
def init_app(self, app: "Starlette") -> None:
@@ -146,7 +148,8 @@ async def combined_lifespan(app: "Starlette") -> "AsyncGenerator[None, None]":
146148
app.router.lifespan_context = combined_lifespan
147149

148150
for config_state in self._config_states:
149-
self._add_middleware(app, config_state)
151+
if not config_state.disable_di:
152+
self._add_middleware(app, config_state)
150153

151154
def _validate_unique_keys(self) -> None:
152155
"""Validate that all state keys are unique across configs.

tests/integration/test_extensions/test_fastapi_filters_integration.py renamed to tests/integration/test_extensions/test_fastapi/test_fastapi_filters_integration.py

File renamed without changes.

tests/integration/test_extensions/test_fastapi_integration.py renamed to tests/integration/test_extensions/test_fastapi/test_fastapi_integration.py

File renamed without changes.

0 commit comments

Comments
 (0)