1515 and_ ,
1616 Column ,
1717 or_ ,
18+ Dialect ,
1819)
1920from sqlalchemy .engine import make_url
2021from sqlalchemy .exc import NoSuchModuleError
@@ -96,6 +97,7 @@ def _init_engine(self):
9697 # Use the default isolation level, don't need SERIALIZABLE
9798 # isolation_level="SERIALIZABLE",
9899 )
100+ self .dialect = self .engine .dialect
99101 self .session = Session (bind = self .engine )
100102
101103 def __init__ (self , url : str ):
@@ -113,18 +115,18 @@ def __setstate__(self, state):
113115 self .url = state ["url" ]
114116 self ._init_engine ()
115117
116- def _close_engine (self ):
117- if hasattr (self , "session" ):
118- self .session .close ()
119- self .engine .dispose ()
120-
121118 def __del__ (self ):
122- self ._close_engine ()
119+ self .close ()
123120
124121 def reset (self ):
125- self ._close_engine ()
122+ self .close ()
126123 self ._init_engine ()
127124
125+ def close (self ):
126+ if hasattr (self , "session" ):
127+ self .session .close ()
128+ self .engine .dispose ()
129+
128130 def get (self ):
129131 return self .session
130132
@@ -141,8 +143,12 @@ def __init__(self, session_provider: SqlAlchemySessionProvider):
141143 def session (self ):
142144 return self .session_provider .get ()
143145
146+ @property
147+ def dialect (self ) -> Dialect :
148+ return self .session_provider .dialect
149+
144150 def _upsert (self , connection : Connection , table : Table , entities : list [dict ]):
145- dialect = self .session . bind . dialect .name
151+ dialect = self .dialect .name
146152 if dialect == "mysql" :
147153 from sqlalchemy .dialects .mysql import insert
148154 elif dialect == "postgresql" :
@@ -186,7 +192,7 @@ def _filter_query(
186192 else :
187193 query = query .filter (dataset_table .c .dataset_id == dataset_id )
188194
189- dialect = self .session . bind . dialect .name
195+ dialect = self .dialect .name
190196
191197 if not isinstance (selector , list ):
192198 where , selector = selector .split ("where" )
@@ -249,7 +255,7 @@ def _filter_query(
249255
250256 return query
251257
252- def load_datasets (self , dataset_ids : list [str ]) -> list [Dataset ]:
258+ def _load_datasets (self , dataset_ids : list [str ]) -> list [Dataset ]:
253259 if not dataset_ids :
254260 return []
255261
@@ -305,7 +311,7 @@ def load_datasets(self, dataset_ids: list[str]) -> list[Dataset]:
305311
306312 def _debug_query (self , q : Query ):
307313 text_ = q .statement .compile (
308- compile_kwargs = {"literal_binds" : True }, dialect = self .session . bind . dialect
314+ compile_kwargs = {"literal_binds" : True }, dialect = self .dialect
309315 )
310316 logger .debug (f"Running query: { text_ } " )
311317
@@ -328,37 +334,40 @@ def apply_query_filter(query):
328334 selector = selector ,
329335 )
330336
331- if not metadata_only :
332- dataset_query = apply_query_filter (
333- self .session .query (dataset_table .c .dataset_id )
334- )
335- self ._debug_query (dataset_query )
336- dataset_ids = [row .dataset_id for row in dataset_query ]
337- datasets = self .load_datasets (dataset_ids )
338-
339- dataset_collection_metadata = DatasetCollectionMetadata (
340- last_modified = max (dataset .last_modified_at for dataset in datasets )
341- if datasets
342- else None ,
343- row_count = len (datasets ),
344- )
345- else :
346- datasets = []
347-
348- metadata_result_query = apply_query_filter (
349- self .session .query (
350- func .max (dataset_table .c .last_modified_at ).label (
351- "last_modified_at"
352- ),
353- func .count ().label ("row_count" ),
337+ with self .session :
338+ # Use a contextmanager to make sure it's closed afterwards
339+
340+ if not metadata_only :
341+ dataset_query = apply_query_filter (
342+ self .session .query (dataset_table .c .dataset_id )
343+ )
344+ self ._debug_query (dataset_query )
345+ dataset_ids = [row .dataset_id for row in dataset_query ]
346+ datasets = self ._load_datasets (dataset_ids )
347+
348+ dataset_collection_metadata = DatasetCollectionMetadata (
349+ last_modified = max (dataset .last_modified_at for dataset in datasets )
350+ if datasets
351+ else None ,
352+ row_count = len (datasets ),
353+ )
354+ else :
355+ datasets = []
356+
357+ metadata_result_query = apply_query_filter (
358+ self .session .query (
359+ func .max (dataset_table .c .last_modified_at ).label (
360+ "last_modified_at"
361+ ),
362+ func .count ().label ("row_count" ),
363+ )
354364 )
355- )
356365
357- self ._debug_query (metadata_result_query )
366+ self ._debug_query (metadata_result_query )
358367
359- dataset_collection_metadata = DatasetCollectionMetadata (
360- * metadata_result_query .first ()
361- )
368+ dataset_collection_metadata = DatasetCollectionMetadata (
369+ * metadata_result_query .first ()
370+ )
362371
363372 return DatasetCollection (dataset_collection_metadata , datasets )
364373
@@ -371,6 +380,9 @@ def save(self, bucket: str, dataset: Dataset):
371380 def connect (self ):
372381 return self .session_provider .engine .connect ()
373382
383+ def __del__ (self ):
384+ self .session_provider .close ()
385+
374386 def _save (self , datasets : list [Dataset ]):
375387 """Only do upserts. Never delete. Rows get only deleted when an entire Dataset is removed."""
376388 datasets_entities = []
0 commit comments