Skip to content

Commit de7b813

Browse files
authored
fix: litestar plugin session provider (#332)
Correctly handle cases where the session was requested but had not been populated in the Litestar request state.
1 parent c90f4dc commit de7b813

File tree

3 files changed

+601
-95
lines changed

3 files changed

+601
-95
lines changed

sqlspec/extensions/litestar/plugin.py

Lines changed: 261 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -530,11 +530,87 @@ def get_config(
530530
msg = f"No database configuration found for name '{name}'. Available keys: {self._get_available_keys()}"
531531
raise KeyError(msg)
532532

533+
def _ensure_connection_sync(self, plugin_state: PluginConfigState, state: "State", scope: "Scope") -> Any:
534+
"""Ensure a connection exists in scope, creating one from the pool if needed (sync)."""
535+
connection = get_sqlspec_scope_state(scope, plugin_state.connection_key)
536+
if connection is not None:
537+
return connection
538+
539+
pool = state.get(plugin_state.pool_key)
540+
if pool is None:
541+
self._raise_missing_connection(plugin_state.connection_key)
542+
543+
cm = plugin_state.config.provide_connection(pool)
544+
connection = cm.__enter__() # type: ignore[union-attr]
545+
set_sqlspec_scope_state(scope, plugin_state.connection_key, connection)
546+
return connection
547+
548+
async def _ensure_connection_async(self, plugin_state: PluginConfigState, state: "State", scope: "Scope") -> Any:
549+
"""Ensure a connection exists in scope, creating one from the pool if needed (async)."""
550+
connection = get_sqlspec_scope_state(scope, plugin_state.connection_key)
551+
if connection is not None:
552+
return connection
553+
554+
pool = state.get(plugin_state.pool_key)
555+
if pool is None:
556+
self._raise_missing_connection(plugin_state.connection_key)
557+
558+
cm = plugin_state.config.provide_connection(pool)
559+
connection = await cm.__aenter__() # type: ignore[union-attr]
560+
set_sqlspec_scope_state(scope, plugin_state.connection_key, connection)
561+
return connection
562+
563+
def _create_session(
564+
self, plugin_state: PluginConfigState, connection: Any, scope: "Scope"
565+
) -> "SyncDriverAdapterBase | AsyncDriverAdapterBase":
566+
"""Create a session from a connection and store it in scope."""
567+
session_scope_key = f"{plugin_state.session_key}_instance"
568+
569+
session = get_sqlspec_scope_state(scope, session_scope_key)
570+
if session is not None:
571+
return cast("SyncDriverAdapterBase | AsyncDriverAdapterBase", session)
572+
573+
session = plugin_state.config.driver_type(
574+
connection=connection,
575+
statement_config=plugin_state.config.statement_config,
576+
driver_features=plugin_state.config.driver_features,
577+
)
578+
set_sqlspec_scope_state(scope, session_scope_key, session)
579+
return cast("SyncDriverAdapterBase | AsyncDriverAdapterBase", session)
580+
581+
@overload
582+
def provide_request_session(
583+
self,
584+
key: "SyncDatabaseConfig[Any, Any, DriverT] | NoPoolSyncConfig[Any, DriverT] | type[SyncDatabaseConfig[Any, Any, DriverT] | NoPoolSyncConfig[Any, DriverT]]",
585+
state: "State",
586+
scope: "Scope",
587+
) -> "DriverT": ...
588+
589+
@overload
590+
def provide_request_session(
591+
self,
592+
key: "AsyncDatabaseConfig[Any, Any, DriverT] | NoPoolAsyncConfig[Any, DriverT] | type[AsyncDatabaseConfig[Any, Any, DriverT] | NoPoolAsyncConfig[Any, DriverT]]",
593+
state: "State",
594+
scope: "Scope",
595+
) -> "DriverT": ...
596+
597+
@overload
533598
def provide_request_session(
534-
self, key: "str | SyncConfigT | AsyncConfigT | type[SyncConfigT | AsyncConfigT]", state: "State", scope: "Scope"
599+
self, key: str, state: "State", scope: "Scope"
600+
) -> "SyncDriverAdapterBase | AsyncDriverAdapterBase": ...
601+
602+
def provide_request_session(
603+
self,
604+
key: "str | SyncDatabaseConfig[Any, Any, Any] | NoPoolSyncConfig[Any, Any] | AsyncDatabaseConfig[Any, Any, Any] | NoPoolAsyncConfig[Any, Any] | type[SyncDatabaseConfig[Any, Any, Any] | NoPoolSyncConfig[Any, Any] | AsyncDatabaseConfig[Any, Any, Any] | NoPoolAsyncConfig[Any, Any]]",
605+
state: "State",
606+
scope: "Scope",
535607
) -> "SyncDriverAdapterBase | AsyncDriverAdapterBase":
536608
"""Provide a database session for the specified configuration key from request scope.
537609
610+
This method requires the connection to already exist in scope (e.g., from DI injection).
611+
For on-demand connection creation, use ``provide_request_session_sync`` or
612+
``provide_request_session_async`` instead.
613+
538614
Args:
539615
key: The configuration identifier (same as get_config).
540616
state: The Litestar application State object.
@@ -544,62 +620,132 @@ def provide_request_session(
544620
A driver session instance for the specified database configuration.
545621
"""
546622
plugin_state = self._get_plugin_state(key)
547-
session_scope_key = f"{plugin_state.session_key}_instance"
548-
549-
session = get_sqlspec_scope_state(scope, session_scope_key)
550-
if session is not None:
551-
return cast("SyncDriverAdapterBase | AsyncDriverAdapterBase", session)
552-
553623
connection = get_sqlspec_scope_state(scope, plugin_state.connection_key)
554624
if connection is None:
555625
self._raise_missing_connection(plugin_state.connection_key)
626+
return self._create_session(plugin_state, connection, scope)
556627

557-
session = plugin_state.config.driver_type(
558-
connection=connection,
559-
statement_config=plugin_state.config.statement_config,
560-
driver_features=plugin_state.config.driver_features,
561-
)
562-
set_sqlspec_scope_state(scope, session_scope_key, session)
628+
@overload
629+
def provide_request_session_sync(
630+
self,
631+
key: "SyncDatabaseConfig[Any, Any, DriverT] | NoPoolSyncConfig[Any, DriverT]",
632+
state: "State",
633+
scope: "Scope",
634+
) -> "DriverT": ...
563635

564-
return cast("SyncDriverAdapterBase | AsyncDriverAdapterBase", session)
636+
@overload
637+
def provide_request_session_sync(
638+
self,
639+
key: "type[SyncDatabaseConfig[Any, Any, DriverT] | NoPoolSyncConfig[Any, DriverT]]",
640+
state: "State",
641+
scope: "Scope",
642+
) -> "DriverT": ...
643+
644+
@overload
645+
def provide_request_session_sync(self, key: str, state: "State", scope: "Scope") -> "SyncDriverAdapterBase": ...
565646

566-
def provide_sync_request_session(
567-
self, key: "str | SyncConfigT | type[SyncConfigT]", state: "State", scope: "Scope"
568-
) -> "SyncDriverAdapterBase":
647+
def provide_request_session_sync(
648+
self,
649+
key: "str | SyncDatabaseConfig[Any, Any, Any] | NoPoolSyncConfig[Any, Any] | type[SyncDatabaseConfig[Any, Any, Any] | NoPoolSyncConfig[Any, Any]]",
650+
state: "State",
651+
scope: "Scope",
652+
) -> "SyncDriverAdapterBase | Any":
569653
"""Provide a sync database session for the specified configuration key from request scope.
570654
655+
If no connection exists in scope, one will be created from the pool and stored
656+
in scope for reuse. The connection will be cleaned up by the before_send handler.
657+
658+
For async configurations, use ``provide_request_session_async`` instead.
659+
571660
Args:
572-
key: The sync configuration identifier.
661+
key: The configuration identifier (same as get_config).
573662
state: The Litestar application State object.
574663
scope: The ASGI scope containing the request context.
575664
576665
Returns:
577666
A sync driver session instance for the specified database configuration.
578667
"""
579-
session = self.provide_request_session(key, state, scope)
580-
return cast("SyncDriverAdapterBase", session)
668+
plugin_state = self._get_plugin_state(key)
669+
connection = self._ensure_connection_sync(plugin_state, state, scope)
670+
return cast("SyncDriverAdapterBase", self._create_session(plugin_state, connection, scope))
671+
672+
@overload
673+
async def provide_request_session_async(
674+
self,
675+
key: "AsyncDatabaseConfig[Any, Any, DriverT] | NoPoolAsyncConfig[Any, DriverT]",
676+
state: "State",
677+
scope: "Scope",
678+
) -> "DriverT": ...
581679

582-
def provide_async_request_session(
583-
self, key: "str | AsyncConfigT | type[AsyncConfigT]", state: "State", scope: "Scope"
584-
) -> "AsyncDriverAdapterBase":
680+
@overload
681+
async def provide_request_session_async(
682+
self,
683+
key: "type[AsyncDatabaseConfig[Any, Any, DriverT] | NoPoolAsyncConfig[Any, DriverT]]",
684+
state: "State",
685+
scope: "Scope",
686+
) -> "DriverT": ...
687+
688+
@overload
689+
async def provide_request_session_async(
690+
self, key: str, state: "State", scope: "Scope"
691+
) -> "AsyncDriverAdapterBase": ...
692+
693+
async def provide_request_session_async(
694+
self,
695+
key: "str | AsyncDatabaseConfig[Any, Any, Any] | NoPoolAsyncConfig[Any, Any] | type[AsyncDatabaseConfig[Any, Any, Any] | NoPoolAsyncConfig[Any, Any]]",
696+
state: "State",
697+
scope: "Scope",
698+
) -> "AsyncDriverAdapterBase | Any":
585699
"""Provide an async database session for the specified configuration key from request scope.
586700
701+
If no connection exists in scope, one will be created from the pool and stored
702+
in scope for reuse. The connection will be cleaned up by the before_send handler.
703+
704+
For sync configurations, use ``provide_request_session`` instead.
705+
587706
Args:
588-
key: The async configuration identifier.
707+
key: The configuration identifier (same as get_config).
589708
state: The Litestar application State object.
590709
scope: The ASGI scope containing the request context.
591710
592711
Returns:
593712
An async driver session instance for the specified database configuration.
594713
"""
595-
session = self.provide_request_session(key, state, scope)
596-
return cast("AsyncDriverAdapterBase", session)
714+
plugin_state = self._get_plugin_state(key)
715+
connection = await self._ensure_connection_async(plugin_state, state, scope)
716+
return cast("AsyncDriverAdapterBase", self._create_session(plugin_state, connection, scope))
717+
718+
@overload
719+
def provide_request_connection(
720+
self,
721+
key: "SyncDatabaseConfig[ConnectionT, Any, Any] | NoPoolSyncConfig[ConnectionT, Any] | AsyncDatabaseConfig[ConnectionT, Any, Any] | NoPoolAsyncConfig[ConnectionT, Any]",
722+
state: "State",
723+
scope: "Scope",
724+
) -> "ConnectionT": ...
725+
726+
@overload
727+
def provide_request_connection(
728+
self,
729+
key: "type[SyncDatabaseConfig[ConnectionT, Any, Any] | NoPoolSyncConfig[ConnectionT, Any] | AsyncDatabaseConfig[ConnectionT, Any, Any] | NoPoolAsyncConfig[ConnectionT, Any]]",
730+
state: "State",
731+
scope: "Scope",
732+
) -> "ConnectionT": ...
733+
734+
@overload
735+
def provide_request_connection(self, key: str, state: "State", scope: "Scope") -> Any: ...
597736

598737
def provide_request_connection(
599-
self, key: "str | SyncConfigT | AsyncConfigT | type[SyncConfigT | AsyncConfigT]", state: "State", scope: "Scope"
600-
) -> "Any":
738+
self,
739+
key: "str | SyncDatabaseConfig[Any, Any, Any] | NoPoolSyncConfig[Any, Any] | AsyncDatabaseConfig[Any, Any, Any] | NoPoolAsyncConfig[Any, Any] | type[SyncDatabaseConfig[Any, Any, Any] | NoPoolSyncConfig[Any, Any] | AsyncDatabaseConfig[Any, Any, Any] | NoPoolAsyncConfig[Any, Any]]",
740+
state: "State",
741+
scope: "Scope",
742+
) -> Any:
601743
"""Provide a database connection for the specified configuration key from request scope.
602744
745+
This method requires the connection to already exist in scope (e.g., from DI injection).
746+
For on-demand connection creation, use ``provide_request_connection_sync`` or
747+
``provide_request_connection_async`` instead.
748+
603749
Args:
604750
key: The configuration identifier (same as get_config).
605751
state: The Litestar application State object.
@@ -612,11 +758,96 @@ def provide_request_connection(
612758
connection = get_sqlspec_scope_state(scope, plugin_state.connection_key)
613759
if connection is None:
614760
self._raise_missing_connection(plugin_state.connection_key)
615-
616761
return connection
617762

763+
@overload
764+
def provide_request_connection_sync(
765+
self,
766+
key: "SyncDatabaseConfig[ConnectionT, Any, Any] | NoPoolSyncConfig[ConnectionT, Any]",
767+
state: "State",
768+
scope: "Scope",
769+
) -> "ConnectionT": ...
770+
771+
@overload
772+
def provide_request_connection_sync(
773+
self,
774+
key: "type[SyncDatabaseConfig[ConnectionT, Any, Any] | NoPoolSyncConfig[ConnectionT, Any]]",
775+
state: "State",
776+
scope: "Scope",
777+
) -> "ConnectionT": ...
778+
779+
@overload
780+
def provide_request_connection_sync(self, key: str, state: "State", scope: "Scope") -> Any: ...
781+
782+
def provide_request_connection_sync(
783+
self,
784+
key: "str | SyncDatabaseConfig[Any, Any, Any] | NoPoolSyncConfig[Any, Any] | type[SyncDatabaseConfig[Any, Any, Any] | NoPoolSyncConfig[Any, Any]]",
785+
state: "State",
786+
scope: "Scope",
787+
) -> Any:
788+
"""Provide a sync database connection for the specified configuration key from request scope.
789+
790+
If no connection exists in scope, one will be created from the pool and stored
791+
in scope for reuse. The connection will be cleaned up by the before_send handler.
792+
793+
For async configurations, use ``provide_request_connection_async`` instead.
794+
795+
Args:
796+
key: The configuration identifier (same as get_config).
797+
state: The Litestar application State object.
798+
scope: The ASGI scope containing the request context.
799+
800+
Returns:
801+
A database connection instance for the specified database configuration.
802+
"""
803+
plugin_state = self._get_plugin_state(key)
804+
return self._ensure_connection_sync(plugin_state, state, scope)
805+
806+
@overload
807+
async def provide_request_connection_async(
808+
self,
809+
key: "AsyncDatabaseConfig[ConnectionT, Any, Any] | NoPoolAsyncConfig[ConnectionT, Any]",
810+
state: "State",
811+
scope: "Scope",
812+
) -> "ConnectionT": ...
813+
814+
@overload
815+
async def provide_request_connection_async(
816+
self,
817+
key: "type[AsyncDatabaseConfig[ConnectionT, Any, Any] | NoPoolAsyncConfig[ConnectionT, Any]]",
818+
state: "State",
819+
scope: "Scope",
820+
) -> "ConnectionT": ...
821+
822+
@overload
823+
async def provide_request_connection_async(self, key: str, state: "State", scope: "Scope") -> Any: ...
824+
825+
async def provide_request_connection_async(
826+
self,
827+
key: "str | AsyncDatabaseConfig[Any, Any, Any] | NoPoolAsyncConfig[Any, Any] | type[AsyncDatabaseConfig[Any, Any, Any] | NoPoolAsyncConfig[Any, Any]]",
828+
state: "State",
829+
scope: "Scope",
830+
) -> Any:
831+
"""Provide an async database connection for the specified configuration key from request scope.
832+
833+
If no connection exists in scope, one will be created from the pool and stored
834+
in scope for reuse. The connection will be cleaned up by the before_send handler.
835+
836+
For sync configurations, use ``provide_request_connection`` instead.
837+
838+
Args:
839+
key: The configuration identifier (same as get_config).
840+
state: The Litestar application State object.
841+
scope: The ASGI scope containing the request context.
842+
843+
Returns:
844+
A database connection instance for the specified database configuration.
845+
"""
846+
plugin_state = self._get_plugin_state(key)
847+
return await self._ensure_connection_async(plugin_state, state, scope)
848+
618849
def _get_plugin_state(
619-
self, key: "str | SyncConfigT | AsyncConfigT | type[SyncConfigT | AsyncConfigT]"
850+
self, key: "str | DatabaseConfigProtocol[Any, Any, Any] | type[DatabaseConfigProtocol[Any, Any, Any]]"
620851
) -> PluginConfigState:
621852
"""Get plugin state for a configuration by key."""
622853
if isinstance(key, str):

0 commit comments

Comments
 (0)