Skip to content

Commit b84e0e1

Browse files
committed
RF: Write DFT database manager as object
This adds a dft._DB class that handles the _init_db and _db_(no)change functions. The default instance remains at dft.DB, but this allows us to create new instances for testing purposes.
1 parent 8cf190d commit b84e0e1

File tree

1 file changed

+62
-52
lines changed

1 file changed

+62
-52
lines changed

nibabel/dft.py

Lines changed: 62 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
"""
1212

1313

14+
import contextlib
1415
import os
1516
from os.path import join as pjoin
1617
import tempfile
@@ -74,7 +75,7 @@ def __getattribute__(self, name):
7475
val = object.__getattribute__(self, name)
7576
if name == 'series' and val is None:
7677
val = []
77-
with _db_nochange() as c:
78+
with DB.readonly_cursor() as c:
7879
c.execute("SELECT * FROM series WHERE study = ?", (self.uid, ))
7980
cols = [el[0] for el in c.description]
8081
for row in c:
@@ -106,7 +107,7 @@ def __getattribute__(self, name):
106107
val = object.__getattribute__(self, name)
107108
if name == 'storage_instances' and val is None:
108109
val = []
109-
with _db_nochange() as c:
110+
with DB.readonly_cursor() as c:
110111
query = """SELECT *
111112
FROM storage_instance
112113
WHERE series = ?
@@ -227,7 +228,7 @@ def __init__(self, d):
227228
def __getattribute__(self, name):
228229
val = object.__getattribute__(self, name)
229230
if name == 'files' and val is None:
230-
with _db_nochange() as c:
231+
with DB.readonly_cursor() as c:
231232
query = """SELECT directory, name
232233
FROM file
233234
WHERE storage_instance = ?
@@ -241,34 +242,6 @@ def dicom(self):
241242
return pydicom.read_file(self.files[0])
242243

243244

244-
class _db_nochange:
245-
"""context guard for read-only database access"""
246-
247-
def __enter__(self):
248-
self.c = DB.cursor()
249-
return self.c
250-
251-
def __exit__(self, type, value, traceback):
252-
if type is None:
253-
self.c.close()
254-
DB.rollback()
255-
256-
257-
class _db_change:
258-
"""context guard for database access requiring a commit"""
259-
260-
def __enter__(self):
261-
self.c = DB.cursor()
262-
return self.c
263-
264-
def __exit__(self, type, value, traceback):
265-
if type is None:
266-
self.c.close()
267-
DB.commit()
268-
else:
269-
DB.rollback()
270-
271-
272245
def _get_subdirs(base_dir, files_dict=None, followlinks=False):
273246
dirs = []
274247
for (dirpath, dirnames, filenames) in os.walk(base_dir, followlinks=followlinks):
@@ -288,7 +261,7 @@ def update_cache(base_dir, followlinks=False):
288261
for d in dirs:
289262
os.stat(d)
290263
mtimes[d] = os.stat(d).st_mtime
291-
with _db_nochange() as c:
264+
with DB.readwrite_cursor() as c:
292265
c.execute("SELECT path, mtime FROM directory")
293266
db_mtimes = dict(c)
294267
c.execute("SELECT uid FROM study")
@@ -297,7 +270,6 @@ def update_cache(base_dir, followlinks=False):
297270
series = [row[0] for row in c]
298271
c.execute("SELECT uid FROM storage_instance")
299272
storage_instances = [row[0] for row in c]
300-
with _db_change() as c:
301273
for dir in sorted(mtimes.keys()):
302274
if dir in db_mtimes and mtimes[dir] <= db_mtimes[dir]:
303275
continue
@@ -316,7 +288,7 @@ def get_studies(base_dir=None, followlinks=False):
316288
if base_dir is not None:
317289
update_cache(base_dir, followlinks)
318290
if base_dir is None:
319-
with _db_nochange() as c:
291+
with DB.readonly_cursor() as c:
320292
c.execute("SELECT * FROM study")
321293
studies = []
322294
cols = [el[0] for el in c.description]
@@ -331,7 +303,7 @@ def get_studies(base_dir=None, followlinks=False):
331303
WHERE uid IN (SELECT storage_instance
332304
FROM file
333305
WHERE directory = ?))"""
334-
with _db_nochange() as c:
306+
with DB.readonly_cursor() as c:
335307
study_uids = {}
336308
for dir in _get_subdirs(base_dir, followlinks=followlinks):
337309
c.execute(query, (dir, ))
@@ -443,7 +415,7 @@ def _update_file(c, path, fname, studies, series, storage_instances):
443415

444416

445417
def clear_cache():
446-
with _db_change() as c:
418+
with DB.readwrite_cursor() as c:
447419
c.execute("DELETE FROM file")
448420
c.execute("DELETE FROM directory")
449421
c.execute("DELETE FROM storage_instance")
@@ -478,26 +450,64 @@ def clear_cache():
478450
mtime INTEGER NOT NULL,
479451
storage_instance TEXT DEFAULT NULL REFERENCES storage_instance,
480452
PRIMARY KEY (directory, name))""")
481-
DB_FNAME = pjoin(tempfile.gettempdir(), f'dft.{getpass.getuser()}.sqlite')
482-
DB = None
483453

484454

485-
def _init_db(verbose=True):
486-
""" Initialize database """
487-
if verbose:
488-
logger.info('db filename: ' + DB_FNAME)
489-
global DB
490-
DB = sqlite3.connect(DB_FNAME, check_same_thread=False)
491-
with _db_change() as c:
492-
c.execute("SELECT COUNT(*) FROM sqlite_master WHERE type = 'table'")
493-
if c.fetchone()[0] == 0:
494-
logger.debug('create')
495-
for q in CREATE_QUERIES:
496-
c.execute(q)
455+
class _DB:
456+
def __init__(self, fname=None, verbose=True):
457+
self.fname = fname or pjoin(tempfile.gettempdir(), f'dft.{getpass.getuser()}.sqlite')
458+
self.verbose = verbose
459+
460+
@property
461+
def session(self):
462+
"""Get sqlite3 Connection
463+
464+
The connection is created on the first call of this property
465+
"""
466+
try:
467+
return self._session
468+
except AttributeError:
469+
self._init_db()
470+
return self._session
471+
472+
def _init_db(self):
473+
if self.verbose:
474+
logger.info('db filename: ' + self.fname)
475+
476+
self._session = sqlite3.connect(self.fname, isolation_level="EXCLUSIVE")
477+
with self.readwrite_cursor() as c:
478+
c.execute("SELECT COUNT(*) FROM sqlite_master WHERE type = 'table'")
479+
if c.fetchone()[0] == 0:
480+
logger.debug('create')
481+
for q in CREATE_QUERIES:
482+
c.execute(q)
483+
484+
def __repr__(self):
485+
return f"<DFT {self.fname!r}>"
486+
487+
@contextlib.contextmanager
488+
def readonly_cursor(self):
489+
cursor = self.session.cursor()
490+
try:
491+
yield cursor
492+
finally:
493+
cursor.close()
494+
self.session.rollback()
495+
496+
@contextlib.contextmanager
497+
def readwrite_cursor(self):
498+
cursor = self.session.cursor()
499+
try:
500+
yield cursor
501+
except Exception:
502+
self.session.rollback()
503+
raise
504+
finally:
505+
cursor.close()
506+
self.session.commit()
497507

498508

509+
DB = None
499510
if os.name == 'nt':
500511
warnings.warn('dft needs FUSE which is not available for windows')
501512
else:
502-
_init_db()
503-
# eof
513+
DB = _DB()

0 commit comments

Comments
 (0)