55# file, You can obtain one at http://mozilla.org/MPL/2.0/.
66from __future__ import annotations
77
8- from typing import Any
8+ from typing import Any , AsyncContextManager
99from contextlib import asynccontextmanager
1010import asyncio
1111import logging
@@ -81,6 +81,7 @@ async def fetchval(
8181
8282class SQLiteDatabase (Database ):
8383 scheme = Scheme .SQLITE
84+ _parent : SQLiteDatabase | None
8485 _pool : asyncio .Queue [TxnConnection ]
8586 _stopped : bool
8687 _conns : int
@@ -103,6 +104,7 @@ def __init__(
103104 owner_name = owner_name ,
104105 ignore_foreign_tables = ignore_foreign_tables ,
105106 )
107+ self ._parent = None
106108 self ._path = url .path
107109 if self ._path .startswith ("/" ):
108110 self ._path = self ._path [1 :]
@@ -134,7 +136,14 @@ def _add_missing_pragmas(init_commands: list[str]) -> list[str]:
134136 init_commands .append ("PRAGMA busy_timeout = 5000" )
135137 return init_commands
136138
139+ def override_pool (self , db : Database ) -> None :
140+ assert isinstance (db , SQLiteDatabase )
141+ self ._parent = db
142+
137143 async def start (self ) -> None :
144+ if self ._parent :
145+ await super ().start ()
146+ return
138147 if self ._conns :
139148 raise RuntimeError ("database pool has already been started" )
140149 elif self ._stopped :
@@ -155,14 +164,21 @@ async def start(self) -> None:
155164 await super ().start ()
156165
157166 async def stop (self ) -> None :
167+ if self ._parent :
168+ return
158169 self ._stopped = True
159170 while self ._conns > 0 :
160171 conn = await self ._pool .get ()
161172 self ._conns -= 1
162173 await conn .close ()
163174
175+ def acquire (self ) -> AsyncContextManager [LoggingConnection ]:
176+ if self ._parent :
177+ return self ._parent .acquire ()
178+ return self ._acquire ()
179+
164180 @asynccontextmanager
165- async def acquire (self ) -> LoggingConnection :
181+ async def _acquire (self ) -> LoggingConnection :
166182 if self ._stopped :
167183 raise RuntimeError ("database pool has been stopped" )
168184 conn = await self ._pool .get ()
0 commit comments