Skip to content

Commit 321c97d

Browse files
authored
Merge pull request #26 from PySport/bugfix/connection-not-closed
Bugfix/connection not closed
2 parents 583676e + 8293995 commit 321c97d

File tree

1 file changed

+51
-39
lines changed

1 file changed

+51
-39
lines changed

ingestify/infra/store/dataset/sqlalchemy/repository.py

Lines changed: 51 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
and_,
1616
Column,
1717
or_,
18+
Dialect,
1819
)
1920
from sqlalchemy.engine import make_url
2021
from sqlalchemy.exc import NoSuchModuleError
@@ -96,6 +97,7 @@ def _init_engine(self):
9697
# Use the default isolation level, don't need SERIALIZABLE
9798
# isolation_level="SERIALIZABLE",
9899
)
100+
self.dialect = self.engine.dialect
99101
self.session = Session(bind=self.engine)
100102

101103
def __init__(self, url: str):
@@ -113,18 +115,18 @@ def __setstate__(self, state):
113115
self.url = state["url"]
114116
self._init_engine()
115117

116-
def _close_engine(self):
117-
if hasattr(self, "session"):
118-
self.session.close()
119-
self.engine.dispose()
120-
121118
def __del__(self):
122-
self._close_engine()
119+
self.close()
123120

124121
def reset(self):
125-
self._close_engine()
122+
self.close()
126123
self._init_engine()
127124

125+
def close(self):
126+
if hasattr(self, "session"):
127+
self.session.close()
128+
self.engine.dispose()
129+
128130
def get(self):
129131
return self.session
130132

@@ -141,8 +143,12 @@ def __init__(self, session_provider: SqlAlchemySessionProvider):
141143
def session(self):
142144
return self.session_provider.get()
143145

146+
@property
147+
def dialect(self) -> Dialect:
148+
return self.session_provider.dialect
149+
144150
def _upsert(self, connection: Connection, table: Table, entities: list[dict]):
145-
dialect = self.session.bind.dialect.name
151+
dialect = self.dialect.name
146152
if dialect == "mysql":
147153
from sqlalchemy.dialects.mysql import insert
148154
elif dialect == "postgresql":
@@ -186,7 +192,7 @@ def _filter_query(
186192
else:
187193
query = query.filter(dataset_table.c.dataset_id == dataset_id)
188194

189-
dialect = self.session.bind.dialect.name
195+
dialect = self.dialect.name
190196

191197
if not isinstance(selector, list):
192198
where, selector = selector.split("where")
@@ -249,7 +255,7 @@ def _filter_query(
249255

250256
return query
251257

252-
def load_datasets(self, dataset_ids: list[str]) -> list[Dataset]:
258+
def _load_datasets(self, dataset_ids: list[str]) -> list[Dataset]:
253259
if not dataset_ids:
254260
return []
255261

@@ -305,7 +311,7 @@ def load_datasets(self, dataset_ids: list[str]) -> list[Dataset]:
305311

306312
def _debug_query(self, q: Query):
307313
text_ = q.statement.compile(
308-
compile_kwargs={"literal_binds": True}, dialect=self.session.bind.dialect
314+
compile_kwargs={"literal_binds": True}, dialect=self.dialect
309315
)
310316
logger.debug(f"Running query: {text_}")
311317

@@ -328,37 +334,40 @@ def apply_query_filter(query):
328334
selector=selector,
329335
)
330336

331-
if not metadata_only:
332-
dataset_query = apply_query_filter(
333-
self.session.query(dataset_table.c.dataset_id)
334-
)
335-
self._debug_query(dataset_query)
336-
dataset_ids = [row.dataset_id for row in dataset_query]
337-
datasets = self.load_datasets(dataset_ids)
338-
339-
dataset_collection_metadata = DatasetCollectionMetadata(
340-
last_modified=max(dataset.last_modified_at for dataset in datasets)
341-
if datasets
342-
else None,
343-
row_count=len(datasets),
344-
)
345-
else:
346-
datasets = []
347-
348-
metadata_result_query = apply_query_filter(
349-
self.session.query(
350-
func.max(dataset_table.c.last_modified_at).label(
351-
"last_modified_at"
352-
),
353-
func.count().label("row_count"),
337+
with self.session:
338+
# Use a contextmanager to make sure it's closed afterwards
339+
340+
if not metadata_only:
341+
dataset_query = apply_query_filter(
342+
self.session.query(dataset_table.c.dataset_id)
343+
)
344+
self._debug_query(dataset_query)
345+
dataset_ids = [row.dataset_id for row in dataset_query]
346+
datasets = self._load_datasets(dataset_ids)
347+
348+
dataset_collection_metadata = DatasetCollectionMetadata(
349+
last_modified=max(dataset.last_modified_at for dataset in datasets)
350+
if datasets
351+
else None,
352+
row_count=len(datasets),
353+
)
354+
else:
355+
datasets = []
356+
357+
metadata_result_query = apply_query_filter(
358+
self.session.query(
359+
func.max(dataset_table.c.last_modified_at).label(
360+
"last_modified_at"
361+
),
362+
func.count().label("row_count"),
363+
)
354364
)
355-
)
356365

357-
self._debug_query(metadata_result_query)
366+
self._debug_query(metadata_result_query)
358367

359-
dataset_collection_metadata = DatasetCollectionMetadata(
360-
*metadata_result_query.first()
361-
)
368+
dataset_collection_metadata = DatasetCollectionMetadata(
369+
*metadata_result_query.first()
370+
)
362371

363372
return DatasetCollection(dataset_collection_metadata, datasets)
364373

@@ -371,6 +380,9 @@ def save(self, bucket: str, dataset: Dataset):
371380
def connect(self):
372381
return self.session_provider.engine.connect()
373382

383+
def __del__(self):
384+
self.session_provider.close()
385+
374386
def _save(self, datasets: list[Dataset]):
375387
"""Only do upserts. Never delete. Rows get only deleted when an entire Dataset is removed."""
376388
datasets_entities = []

0 commit comments

Comments
 (0)