11
11
"""
12
12
13
13
14
+ import contextlib
14
15
import os
15
16
from os .path import join as pjoin
16
17
import tempfile
@@ -74,7 +75,7 @@ def __getattribute__(self, name):
74
75
val = object .__getattribute__ (self , name )
75
76
if name == 'series' and val is None :
76
77
val = []
77
- with _db_nochange () as c :
78
+ with DB . readonly_cursor () as c :
78
79
c .execute ("SELECT * FROM series WHERE study = ?" , (self .uid , ))
79
80
cols = [el [0 ] for el in c .description ]
80
81
for row in c :
@@ -106,7 +107,7 @@ def __getattribute__(self, name):
106
107
val = object .__getattribute__ (self , name )
107
108
if name == 'storage_instances' and val is None :
108
109
val = []
109
- with _db_nochange () as c :
110
+ with DB . readonly_cursor () as c :
110
111
query = """SELECT *
111
112
FROM storage_instance
112
113
WHERE series = ?
@@ -227,7 +228,7 @@ def __init__(self, d):
227
228
def __getattribute__ (self , name ):
228
229
val = object .__getattribute__ (self , name )
229
230
if name == 'files' and val is None :
230
- with _db_nochange () as c :
231
+ with DB . readonly_cursor () as c :
231
232
query = """SELECT directory, name
232
233
FROM file
233
234
WHERE storage_instance = ?
@@ -241,34 +242,6 @@ def dicom(self):
241
242
return pydicom .read_file (self .files [0 ])
242
243
243
244
244
- class _db_nochange :
245
- """context guard for read-only database access"""
246
-
247
- def __enter__ (self ):
248
- self .c = DB .cursor ()
249
- return self .c
250
-
251
- def __exit__ (self , type , value , traceback ):
252
- if type is None :
253
- self .c .close ()
254
- DB .rollback ()
255
-
256
-
257
- class _db_change :
258
- """context guard for database access requiring a commit"""
259
-
260
- def __enter__ (self ):
261
- self .c = DB .cursor ()
262
- return self .c
263
-
264
- def __exit__ (self , type , value , traceback ):
265
- if type is None :
266
- self .c .close ()
267
- DB .commit ()
268
- else :
269
- DB .rollback ()
270
-
271
-
272
245
def _get_subdirs (base_dir , files_dict = None , followlinks = False ):
273
246
dirs = []
274
247
for (dirpath , dirnames , filenames ) in os .walk (base_dir , followlinks = followlinks ):
@@ -288,7 +261,7 @@ def update_cache(base_dir, followlinks=False):
288
261
for d in dirs :
289
262
os .stat (d )
290
263
mtimes [d ] = os .stat (d ).st_mtime
291
- with _db_nochange () as c :
264
+ with DB . readwrite_cursor () as c :
292
265
c .execute ("SELECT path, mtime FROM directory" )
293
266
db_mtimes = dict (c )
294
267
c .execute ("SELECT uid FROM study" )
@@ -297,7 +270,6 @@ def update_cache(base_dir, followlinks=False):
297
270
series = [row [0 ] for row in c ]
298
271
c .execute ("SELECT uid FROM storage_instance" )
299
272
storage_instances = [row [0 ] for row in c ]
300
- with _db_change () as c :
301
273
for dir in sorted (mtimes .keys ()):
302
274
if dir in db_mtimes and mtimes [dir ] <= db_mtimes [dir ]:
303
275
continue
@@ -316,7 +288,7 @@ def get_studies(base_dir=None, followlinks=False):
316
288
if base_dir is not None :
317
289
update_cache (base_dir , followlinks )
318
290
if base_dir is None :
319
- with _db_nochange () as c :
291
+ with DB . readonly_cursor () as c :
320
292
c .execute ("SELECT * FROM study" )
321
293
studies = []
322
294
cols = [el [0 ] for el in c .description ]
@@ -331,7 +303,7 @@ def get_studies(base_dir=None, followlinks=False):
331
303
WHERE uid IN (SELECT storage_instance
332
304
FROM file
333
305
WHERE directory = ?))"""
334
- with _db_nochange () as c :
306
+ with DB . readonly_cursor () as c :
335
307
study_uids = {}
336
308
for dir in _get_subdirs (base_dir , followlinks = followlinks ):
337
309
c .execute (query , (dir , ))
@@ -443,7 +415,7 @@ def _update_file(c, path, fname, studies, series, storage_instances):
443
415
444
416
445
417
def clear_cache ():
446
- with _db_change () as c :
418
+ with DB . readwrite_cursor () as c :
447
419
c .execute ("DELETE FROM file" )
448
420
c .execute ("DELETE FROM directory" )
449
421
c .execute ("DELETE FROM storage_instance" )
@@ -478,26 +450,56 @@ def clear_cache():
478
450
mtime INTEGER NOT NULL,
479
451
storage_instance TEXT DEFAULT NULL REFERENCES storage_instance,
480
452
PRIMARY KEY (directory, name))""" )
481
- DB_FNAME = pjoin (tempfile .gettempdir (), f'dft.{ getpass .getuser ()} .sqlite' )
482
- DB = None
483
453
484
454
485
- def _init_db (verbose = True ):
486
- """ Initialize database """
487
- if verbose :
488
- logger .info ('db filename: ' + DB_FNAME )
489
- global DB
490
- DB = sqlite3 .connect (DB_FNAME , check_same_thread = False )
491
- with _db_change () as c :
492
- c .execute ("SELECT COUNT(*) FROM sqlite_master WHERE type = 'table'" )
493
- if c .fetchone ()[0 ] == 0 :
494
- logger .debug ('create' )
495
- for q in CREATE_QUERIES :
496
- c .execute (q )
455
+ class _DB :
456
+ def __init__ (self , verbose = True ):
457
+ self .fname = pjoin (tempfile .gettempdir (), f'dft.{ getpass .getuser ()} .sqlite' )
458
+ if verbose :
459
+ logger .info ('db filename: ' + self .fname )
460
+ self .session = sqlite3 .connect (self .fname , isolation_level = "EXCLUSIVE" )
461
+ try :
462
+ with self .readwrite_cursor () as c :
463
+ c .execute ("SELECT COUNT(*) FROM sqlite_master WHERE type = 'table'" )
464
+ if c .fetchone ()[0 ] == 0 :
465
+ logger .debug ('create' )
466
+ for q in CREATE_QUERIES :
467
+ c .execute (q )
468
+ except sqlite3 .OperationalError as e :
469
+ # Race condition, but verify that tables actually got created
470
+ with self .readonly_cursor () as c :
471
+ c .execute ("SELECT COUNT(*) FROM sqlite_master WHERE type = 'table'" )
472
+ if c .fetchone ()[0 ] != 5 :
473
+ raise DFTError ("Could not construct DFT database" ) from e
474
+
475
+ def __repr__ (self ):
476
+ return f"<DFT { self .fname !r} >"
477
+
478
+ @contextlib .contextmanager
479
+ def readonly_cursor (self ):
480
+ cursor = self .session .cursor ()
481
+ try :
482
+ yield cursor
483
+ finally :
484
+ cursor .close ()
485
+ self .session .rollback ()
486
+
487
+ @contextlib .contextmanager
488
+ def readwrite_cursor (self ):
489
+ cursor = self .session .cursor ()
490
+ try :
491
+ yield cursor
492
+ except Exception :
493
+ self .session .rollback ()
494
+ raise
495
+ finally :
496
+ cursor .close ()
497
+ self .session .commit ()
497
498
498
499
500
+ DB = None
499
501
if os .name == 'nt' :
500
502
warnings .warn ('dft needs FUSE which is not available for windows' )
501
503
else :
502
- _init_db ()
504
+ DB = _DB ()
503
505
# eof
0 commit comments