1515)
1616
1717import sqlalchemy
18- from attrs import frozen
1918from sqlalchemy import MetaData , Table , UniqueConstraint , exists , select
2019from sqlalchemy .dialects import sqlite
2120from sqlalchemy .schema import CreateIndex , CreateTable , DropTable
4039
4140if TYPE_CHECKING :
4241 from sqlalchemy .dialects .sqlite import Insert
42+ from sqlalchemy .engine .base import Engine
4343 from sqlalchemy .schema import SchemaItem
4444 from sqlalchemy .sql .elements import ColumnClause , ColumnElement , TextClause
4545 from sqlalchemy .sql .selectable import Select
5252RETRY_MAX_TIMES = 10
5353RETRY_FACTOR = 2
5454
55+ DETECT_TYPES = sqlite3 .PARSE_DECLTYPES | sqlite3 .PARSE_COLNAMES
56+
5557Column = Union [str , "ColumnClause[Any]" , "TextClause" ]
5658
5759datachain .sql .sqlite .setup ()
@@ -80,26 +82,41 @@ def wrapper(*args, **kwargs):
8082 return wrapper
8183
8284
83- @frozen
8485class SQLiteDatabaseEngine (DatabaseEngine ):
8586 dialect = sqlite_dialect
8687
8788 db : sqlite3 .Connection
8889 db_file : Optional [str ]
90+ is_closed : bool
91+
92+ def __init__ (
93+ self ,
94+ engine : "Engine" ,
95+ metadata : "MetaData" ,
96+ db : sqlite3 .Connection ,
97+ db_file : Optional [str ] = None ,
98+ ):
99+ self .engine = engine
100+ self .metadata = metadata
101+ self .db = db
102+ self .db_file = db_file
103+ self .is_closed = False
89104
90105 @classmethod
91106 def from_db_file (cls , db_file : Optional [str ] = None ) -> "SQLiteDatabaseEngine" :
92- detect_types = sqlite3 . PARSE_DECLTYPES | sqlite3 . PARSE_COLNAMES
107+ return cls ( * cls . _connect ( db_file = db_file ))
93108
109+ @staticmethod
110+ def _connect (db_file : Optional [str ] = None ):
94111 try :
95112 if db_file == ":memory:" :
96113 # Enable multithreaded usage of the same in-memory db
97114 db = sqlite3 .connect (
98- "file::memory:?cache=shared" , uri = True , detect_types = detect_types
115+ "file::memory:?cache=shared" , uri = True , detect_types = DETECT_TYPES
99116 )
100117 else :
101118 db = sqlite3 .connect (
102- db_file or DataChainDir .find ().db , detect_types = detect_types
119+ db_file or DataChainDir .find ().db , detect_types = DETECT_TYPES
103120 )
104121 create_user_defined_sql_functions (db )
105122 engine = sqlalchemy .create_engine (
@@ -118,7 +135,7 @@ def from_db_file(cls, db_file: Optional[str] = None) -> "SQLiteDatabaseEngine":
118135
119136 load_usearch_extension (db )
120137
121- return cls ( engine , MetaData (), db , db_file )
138+ return engine , MetaData (), db , db_file
122139 except RuntimeError :
123140 raise DataChainError ("Can't connect to SQLite DB" ) from None
124141
@@ -138,13 +155,26 @@ def clone_params(self) -> tuple[Callable[..., Any], list[Any], dict[str, Any]]:
138155 {},
139156 )
140157
158+ def _reconnect (self ) -> None :
159+ if not self .is_closed :
160+ raise RuntimeError ("Cannot reconnect on still-open DB!" )
161+ engine , metadata , db , db_file = self ._connect (db_file = self .db_file )
162+ self .engine = engine
163+ self .metadata = metadata
164+ self .db = db
165+ self .db_file = db_file
166+ self .is_closed = False
167+
141168 @retry_sqlite_locks
142169 def execute (
143170 self ,
144171 query ,
145172 cursor : Optional [sqlite3 .Cursor ] = None ,
146173 conn = None ,
147174 ) -> sqlite3 .Cursor :
175+ if self .is_closed :
176+ # Reconnect in case of being closed previously.
177+ self ._reconnect ()
148178 if cursor is not None :
149179 result = cursor .execute (* self .compile_to_args (query ))
150180 elif conn is not None :
@@ -179,6 +209,7 @@ def cursor(self, factory=None):
179209
180210 def close (self ) -> None :
181211 self .db .close ()
212+ self .is_closed = True
182213
183214 @contextmanager
184215 def transaction (self ):
@@ -359,6 +390,10 @@ def __init__(
359390
360391 self ._init_tables ()
361392
393+ def __exit__ (self , exc_type , exc_value , traceback ) -> None :
394+ """Close connection upon exit from context manager."""
395+ self .close ()
396+
362397 def clone (
363398 self ,
364399 uri : StorageURI = StorageURI ("" ),
@@ -521,6 +556,10 @@ def __init__(
521556
522557 self .db = db or SQLiteDatabaseEngine .from_db_file (db_file )
523558
559+ def __exit__ (self , exc_type , exc_value , traceback ) -> None :
560+ """Close connection upon exit from context manager."""
561+ self .close ()
562+
524563 def clone (self , use_new_connection : bool = False ) -> "SQLiteWarehouse" :
525564 return SQLiteWarehouse (self .id_generator .clone (), db = self .db .clone ())
526565
0 commit comments