Skip to content

Commit 8701ba5

Browse files
authored
Fixing "too many open files" error and adding reconnect (#229)
1 parent 1e5178b commit 8701ba5

21 files changed

+616
-391
lines changed

src/datachain/catalog/catalog.py

Lines changed: 47 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -236,36 +236,36 @@ def do_task(self, urls):
236236
import lz4.frame
237237
import pandas as pd
238238

239-
metastore = self.metastore.clone() # metastore is not thread safe
240-
warehouse = self.warehouse.clone() # warehouse is not thread safe
241-
dataset = metastore.get_dataset(self.dataset_name)
242-
243-
urls = list(urls)
244-
while urls:
245-
for url in urls:
246-
if self.should_check_for_status():
247-
self.check_for_status()
248-
249-
r = requests.get(url, timeout=PULL_DATASET_CHUNK_TIMEOUT)
250-
if r.status_code == 404:
251-
time.sleep(PULL_DATASET_SLEEP_INTERVAL)
252-
# moving to the next url
253-
continue
239+
# metastore and warehouse are not thread safe
240+
with self.metastore.clone() as metastore, self.warehouse.clone() as warehouse:
241+
dataset = metastore.get_dataset(self.dataset_name)
254242

255-
r.raise_for_status()
243+
urls = list(urls)
244+
while urls:
245+
for url in urls:
246+
if self.should_check_for_status():
247+
self.check_for_status()
256248

257-
df = pd.read_parquet(io.BytesIO(lz4.frame.decompress(r.content)))
249+
r = requests.get(url, timeout=PULL_DATASET_CHUNK_TIMEOUT)
250+
if r.status_code == 404:
251+
time.sleep(PULL_DATASET_SLEEP_INTERVAL)
252+
# moving to the next url
253+
continue
258254

259-
self.fix_columns(df)
255+
r.raise_for_status()
260256

261-
# id will be autogenerated in DB
262-
df = df.drop("sys__id", axis=1)
257+
df = pd.read_parquet(io.BytesIO(lz4.frame.decompress(r.content)))
263258

264-
inserted = warehouse.insert_dataset_rows(
265-
df, dataset, self.dataset_version
266-
)
267-
self.increase_counter(inserted) # type: ignore [arg-type]
268-
urls.remove(url)
259+
self.fix_columns(df)
260+
261+
# id will be autogenerated in DB
262+
df = df.drop("sys__id", axis=1)
263+
264+
inserted = warehouse.insert_dataset_rows(
265+
df, dataset, self.dataset_version
266+
)
267+
self.increase_counter(inserted) # type: ignore [arg-type]
268+
urls.remove(url)
269269

270270

271271
@dataclass
@@ -720,7 +720,6 @@ def enlist_source(
720720
client.uri, posixpath.join(prefix, "")
721721
)
722722
source_metastore = self.metastore.clone(client.uri)
723-
source_warehouse = self.warehouse.clone()
724723

725724
columns = [
726725
Column("vtype", String),
@@ -1835,25 +1834,29 @@ def _instantiate_dataset():
18351834
if signed_urls:
18361835
shuffle(signed_urls)
18371836

1838-
rows_fetcher = DatasetRowsFetcher(
1839-
self.metastore.clone(),
1840-
self.warehouse.clone(),
1841-
remote_config,
1842-
dataset.name,
1843-
version,
1844-
schema,
1845-
)
1846-
try:
1847-
rows_fetcher.run(
1848-
batched(
1849-
signed_urls,
1850-
math.ceil(len(signed_urls) / PULL_DATASET_MAX_THREADS),
1851-
),
1852-
dataset_save_progress_bar,
1837+
with (
1838+
self.metastore.clone() as metastore,
1839+
self.warehouse.clone() as warehouse,
1840+
):
1841+
rows_fetcher = DatasetRowsFetcher(
1842+
metastore,
1843+
warehouse,
1844+
remote_config,
1845+
dataset.name,
1846+
version,
1847+
schema,
18531848
)
1854-
except:
1855-
self.remove_dataset(dataset.name, version)
1856-
raise
1849+
try:
1850+
rows_fetcher.run(
1851+
batched(
1852+
signed_urls,
1853+
math.ceil(len(signed_urls) / PULL_DATASET_MAX_THREADS),
1854+
),
1855+
dataset_save_progress_bar,
1856+
)
1857+
except:
1858+
self.remove_dataset(dataset.name, version)
1859+
raise
18571860

18581861
dataset = self.metastore.update_dataset_status(
18591862
dataset,

src/datachain/data_storage/db_engine.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union
55

66
import sqlalchemy as sa
7-
from attrs import frozen
87
from sqlalchemy.sql import FROM_LINTING
98
from sqlalchemy.sql.roles import DDLRole
109

@@ -23,13 +22,18 @@
2322
SELECT_BATCH_SIZE = 100_000 # number of rows to fetch at a time
2423

2524

26-
@frozen
2725
class DatabaseEngine(ABC, Serializable):
2826
dialect: ClassVar["Dialect"]
2927

3028
engine: "Engine"
3129
metadata: "MetaData"
3230

31+
def __enter__(self) -> "DatabaseEngine":
32+
return self
33+
34+
def __exit__(self, exc_type, exc_value, traceback) -> None:
35+
self.close()
36+
3337
@abstractmethod
3438
def clone(self) -> "DatabaseEngine":
3539
"""Clones DatabaseEngine implementation."""

src/datachain/data_storage/id_generator.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,16 @@ def init(self) -> None:
3333
def cleanup_for_tests(self):
3434
"""Cleanup for tests."""
3535

36+
def close(self) -> None:
37+
"""Closes any active database connections."""
38+
39+
def close_on_exit(self) -> None:
40+
"""Closes any active database or HTTP connections. Called on Session exit or
41+
for test cleanup only, as some ID Generator implementations may handle this
42+
differently.
43+
"""
44+
self.close()
45+
3646
@abstractmethod
3747
def init_id(self, uri: str) -> None:
3848
"""Initializes the ID generator for the given URI with zero last_id."""
@@ -83,6 +93,10 @@ def __init__(
8393
def clone(self) -> "AbstractDBIDGenerator":
8494
"""Clones AbstractIDGenerator implementation."""
8595

96+
def close(self) -> None:
97+
"""Closes any active database connections."""
98+
self.db.close()
99+
86100
@property
87101
def db(self) -> "DatabaseEngine":
88102
return self._db

src/datachain/data_storage/metastore.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,13 @@ def __init__(
7878
self.uri = uri
7979
self.partial_id: Optional[int] = partial_id
8080

81+
def __enter__(self) -> "AbstractMetastore":
82+
"""Returns self upon entering context manager."""
83+
return self
84+
85+
def __exit__(self, exc_type, exc_value, traceback) -> None:
86+
"""Default behavior is to do nothing, as connections may be shared."""
87+
8188
@abstractmethod
8289
def clone(
8390
self,
@@ -97,6 +104,12 @@ def init(self, uri: StorageURI) -> None:
97104
def close(self) -> None:
98105
"""Closes any active database or HTTP connections."""
99106

107+
def close_on_exit(self) -> None:
108+
"""Closes any active database or HTTP connections. Called on Session exit or
109+
for test cleanup only, as some Metastore implementations may handle this
110+
differently."""
111+
self.close()
112+
100113
def cleanup_tables(self, temp_table_names: list[str]) -> None:
101114
"""Cleanup temp tables."""
102115

src/datachain/data_storage/sqlite.py

Lines changed: 45 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
)
1616

1717
import sqlalchemy
18-
from attrs import frozen
1918
from sqlalchemy import MetaData, Table, UniqueConstraint, exists, select
2019
from sqlalchemy.dialects import sqlite
2120
from sqlalchemy.schema import CreateIndex, CreateTable, DropTable
@@ -40,6 +39,7 @@
4039

4140
if TYPE_CHECKING:
4241
from sqlalchemy.dialects.sqlite import Insert
42+
from sqlalchemy.engine.base import Engine
4343
from sqlalchemy.schema import SchemaItem
4444
from sqlalchemy.sql.elements import ColumnClause, ColumnElement, TextClause
4545
from sqlalchemy.sql.selectable import Select
@@ -52,6 +52,8 @@
5252
RETRY_MAX_TIMES = 10
5353
RETRY_FACTOR = 2
5454

55+
DETECT_TYPES = sqlite3.PARSE_DECLTYPES | sqlite3.PARSE_COLNAMES
56+
5557
Column = Union[str, "ColumnClause[Any]", "TextClause"]
5658

5759
datachain.sql.sqlite.setup()
@@ -80,26 +82,41 @@ def wrapper(*args, **kwargs):
8082
return wrapper
8183

8284

83-
@frozen
8485
class SQLiteDatabaseEngine(DatabaseEngine):
8586
dialect = sqlite_dialect
8687

8788
db: sqlite3.Connection
8889
db_file: Optional[str]
90+
is_closed: bool
91+
92+
def __init__(
93+
self,
94+
engine: "Engine",
95+
metadata: "MetaData",
96+
db: sqlite3.Connection,
97+
db_file: Optional[str] = None,
98+
):
99+
self.engine = engine
100+
self.metadata = metadata
101+
self.db = db
102+
self.db_file = db_file
103+
self.is_closed = False
89104

90105
@classmethod
91106
def from_db_file(cls, db_file: Optional[str] = None) -> "SQLiteDatabaseEngine":
92-
detect_types = sqlite3.PARSE_DECLTYPES | sqlite3.PARSE_COLNAMES
107+
return cls(*cls._connect(db_file=db_file))
93108

109+
@staticmethod
110+
def _connect(db_file: Optional[str] = None):
94111
try:
95112
if db_file == ":memory:":
96113
# Enable multithreaded usage of the same in-memory db
97114
db = sqlite3.connect(
98-
"file::memory:?cache=shared", uri=True, detect_types=detect_types
115+
"file::memory:?cache=shared", uri=True, detect_types=DETECT_TYPES
99116
)
100117
else:
101118
db = sqlite3.connect(
102-
db_file or DataChainDir.find().db, detect_types=detect_types
119+
db_file or DataChainDir.find().db, detect_types=DETECT_TYPES
103120
)
104121
create_user_defined_sql_functions(db)
105122
engine = sqlalchemy.create_engine(
@@ -118,7 +135,7 @@ def from_db_file(cls, db_file: Optional[str] = None) -> "SQLiteDatabaseEngine":
118135

119136
load_usearch_extension(db)
120137

121-
return cls(engine, MetaData(), db, db_file)
138+
return engine, MetaData(), db, db_file
122139
except RuntimeError:
123140
raise DataChainError("Can't connect to SQLite DB") from None
124141

@@ -138,13 +155,26 @@ def clone_params(self) -> tuple[Callable[..., Any], list[Any], dict[str, Any]]:
138155
{},
139156
)
140157

158+
def _reconnect(self) -> None:
159+
if not self.is_closed:
160+
raise RuntimeError("Cannot reconnect on still-open DB!")
161+
engine, metadata, db, db_file = self._connect(db_file=self.db_file)
162+
self.engine = engine
163+
self.metadata = metadata
164+
self.db = db
165+
self.db_file = db_file
166+
self.is_closed = False
167+
141168
@retry_sqlite_locks
142169
def execute(
143170
self,
144171
query,
145172
cursor: Optional[sqlite3.Cursor] = None,
146173
conn=None,
147174
) -> sqlite3.Cursor:
175+
if self.is_closed:
176+
# Reconnect in case of being closed previously.
177+
self._reconnect()
148178
if cursor is not None:
149179
result = cursor.execute(*self.compile_to_args(query))
150180
elif conn is not None:
@@ -179,6 +209,7 @@ def cursor(self, factory=None):
179209

180210
def close(self) -> None:
181211
self.db.close()
212+
self.is_closed = True
182213

183214
@contextmanager
184215
def transaction(self):
@@ -359,6 +390,10 @@ def __init__(
359390

360391
self._init_tables()
361392

393+
def __exit__(self, exc_type, exc_value, traceback) -> None:
394+
"""Close connection upon exit from context manager."""
395+
self.close()
396+
362397
def clone(
363398
self,
364399
uri: StorageURI = StorageURI(""),
@@ -521,6 +556,10 @@ def __init__(
521556

522557
self.db = db or SQLiteDatabaseEngine.from_db_file(db_file)
523558

559+
def __exit__(self, exc_type, exc_value, traceback) -> None:
560+
"""Close connection upon exit from context manager."""
561+
self.close()
562+
524563
def clone(self, use_new_connection: bool = False) -> "SQLiteWarehouse":
525564
return SQLiteWarehouse(self.id_generator.clone(), db=self.db.clone())
526565

src/datachain/data_storage/warehouse.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,13 @@ class AbstractWarehouse(ABC, Serializable):
7070
def __init__(self, id_generator: "AbstractIDGenerator"):
7171
self.id_generator = id_generator
7272

73+
def __enter__(self) -> "AbstractWarehouse":
74+
return self
75+
76+
def __exit__(self, exc_type, exc_value, traceback) -> None:
77+
# Default behavior is to do nothing, as connections may be shared.
78+
pass
79+
7380
def cleanup_for_tests(self):
7481
"""Cleanup for tests."""
7582

@@ -158,6 +165,12 @@ def close(self) -> None:
158165
"""Closes any active database connections."""
159166
self.db.close()
160167

168+
def close_on_exit(self) -> None:
169+
"""Closes any active database or HTTP connections. Called on Session exit or
170+
for test cleanup only, as some Warehouse implementations may handle this
171+
differently."""
172+
self.close()
173+
161174
#
162175
# Query Tables
163176
#

0 commit comments

Comments
 (0)