Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 62 additions & 52 deletions nibabel/dft.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
"""


import contextlib
import os
from os.path import join as pjoin
import tempfile
Expand Down Expand Up @@ -74,7 +75,7 @@ def __getattribute__(self, name):
val = object.__getattribute__(self, name)
if name == 'series' and val is None:
val = []
with _db_nochange() as c:
with DB.readonly_cursor() as c:
c.execute("SELECT * FROM series WHERE study = ?", (self.uid, ))
cols = [el[0] for el in c.description]
for row in c:
Expand Down Expand Up @@ -106,7 +107,7 @@ def __getattribute__(self, name):
val = object.__getattribute__(self, name)
if name == 'storage_instances' and val is None:
val = []
with _db_nochange() as c:
with DB.readonly_cursor() as c:
query = """SELECT *
FROM storage_instance
WHERE series = ?
Expand Down Expand Up @@ -227,7 +228,7 @@ def __init__(self, d):
def __getattribute__(self, name):
val = object.__getattribute__(self, name)
if name == 'files' and val is None:
with _db_nochange() as c:
with DB.readonly_cursor() as c:
query = """SELECT directory, name
FROM file
WHERE storage_instance = ?
Expand All @@ -241,34 +242,6 @@ def dicom(self):
return pydicom.read_file(self.files[0])


class _db_nochange:
    """Context guard for read-only database access.

    Yields a cursor on the module-level connection ``DB`` and rolls back
    on exit so no accidental writes are persisted.
    """

    def __enter__(self):
        # Fresh cursor on the shared module-level connection.
        self.c = DB.cursor()
        return self.c

    def __exit__(self, exc_type, exc_value, traceback):
        # Close the cursor unconditionally -- the previous version only
        # closed it on clean exit, leaking it when the body raised.
        # (Also renames the ``type`` parameter to avoid shadowing the builtin.)
        self.c.close()
        # Discard any changes made while the guard was held.
        DB.rollback()


class _db_change:
    """Context guard for database access requiring a commit.

    Yields a cursor on the module-level connection ``DB``; commits on
    clean exit and rolls back when the body raised.
    """

    def __enter__(self):
        # Fresh cursor on the shared module-level connection.
        self.c = DB.cursor()
        return self.c

    def __exit__(self, exc_type, exc_value, traceback):
        # Close the cursor unconditionally -- the previous version leaked
        # it on the exception path.
        self.c.close()
        if exc_type is None:
            DB.commit()
        else:
            DB.rollback()


def _get_subdirs(base_dir, files_dict=None, followlinks=False):
dirs = []
for (dirpath, dirnames, filenames) in os.walk(base_dir, followlinks=followlinks):
Expand All @@ -288,7 +261,7 @@ def update_cache(base_dir, followlinks=False):
for d in dirs:
os.stat(d)
mtimes[d] = os.stat(d).st_mtime
with _db_nochange() as c:
with DB.readwrite_cursor() as c:
c.execute("SELECT path, mtime FROM directory")
db_mtimes = dict(c)
c.execute("SELECT uid FROM study")
Expand All @@ -297,7 +270,6 @@ def update_cache(base_dir, followlinks=False):
series = [row[0] for row in c]
c.execute("SELECT uid FROM storage_instance")
storage_instances = [row[0] for row in c]
with _db_change() as c:
for dir in sorted(mtimes.keys()):
if dir in db_mtimes and mtimes[dir] <= db_mtimes[dir]:
continue
Expand All @@ -316,7 +288,7 @@ def get_studies(base_dir=None, followlinks=False):
if base_dir is not None:
update_cache(base_dir, followlinks)
if base_dir is None:
with _db_nochange() as c:
with DB.readonly_cursor() as c:
c.execute("SELECT * FROM study")
studies = []
cols = [el[0] for el in c.description]
Expand All @@ -331,7 +303,7 @@ def get_studies(base_dir=None, followlinks=False):
WHERE uid IN (SELECT storage_instance
FROM file
WHERE directory = ?))"""
with _db_nochange() as c:
with DB.readonly_cursor() as c:
study_uids = {}
for dir in _get_subdirs(base_dir, followlinks=followlinks):
c.execute(query, (dir, ))
Expand Down Expand Up @@ -443,7 +415,7 @@ def _update_file(c, path, fname, studies, series, storage_instances):


def clear_cache():
with _db_change() as c:
with DB.readwrite_cursor() as c:
c.execute("DELETE FROM file")
c.execute("DELETE FROM directory")
c.execute("DELETE FROM storage_instance")
Expand Down Expand Up @@ -478,26 +450,64 @@ def clear_cache():
mtime INTEGER NOT NULL,
storage_instance TEXT DEFAULT NULL REFERENCES storage_instance,
PRIMARY KEY (directory, name))""")
DB_FNAME = pjoin(tempfile.gettempdir(), f'dft.{getpass.getuser()}.sqlite')
DB = None


def _init_db(verbose=True):
    """Open the module-level SQLite connection, creating tables if absent.

    Binds the connection to the module global ``DB``.  On a brand-new
    database (no tables yet) every statement in ``CREATE_QUERIES`` is run.
    """
    global DB
    if verbose:
        logger.info('db filename: ' + DB_FNAME)
    # check_same_thread=False: the connection may be used from FUSE threads.
    DB = sqlite3.connect(DB_FNAME, check_same_thread=False)
    with _db_change() as c:
        c.execute("SELECT COUNT(*) FROM sqlite_master WHERE type = 'table'")
        table_count = c.fetchone()[0]
        if table_count == 0:
            logger.debug('create')
            for query in CREATE_QUERIES:
                c.execute(query)
class _DB:
    """Manager for the dft SQLite cache database.

    Owns a lazily-created :class:`sqlite3.Connection` and hands out
    cursors via the :meth:`readonly_cursor` / :meth:`readwrite_cursor`
    context managers, which encapsulate commit/rollback semantics.
    """

    def __init__(self, fname=None, verbose=True):
        # Default DB file lives in the temp dir, namespaced per user so
        # different users do not share one cache file.
        self.fname = fname or pjoin(tempfile.gettempdir(), f'dft.{getpass.getuser()}.sqlite')
        self.verbose = verbose

    @property
    def session(self):
        """Get sqlite3 Connection

        The connection is created on the first call of this property
        """
        try:
            return self._session
        except AttributeError:
            # First access: connect and initialize the schema.
            self._init_db()
            return self._session

    def _init_db(self):
        # Connect, then create the schema when the database has no tables.
        if self.verbose:
            logger.info('db filename: ' + self.fname)

        # EXCLUSIVE isolation level serializes concurrent writers on the file.
        self._session = sqlite3.connect(self.fname, isolation_level="EXCLUSIVE")
        with self.readwrite_cursor() as c:
            c.execute("SELECT COUNT(*) FROM sqlite_master WHERE type = 'table'")
            if c.fetchone()[0] == 0:
                logger.debug('create')
                for q in CREATE_QUERIES:
                    c.execute(q)

    def __repr__(self):
        return f"<DFT {self.fname!r}>"

    @contextlib.contextmanager
    def readonly_cursor(self):
        """Yield a cursor that rolls back on exit, so reads never persist
        accidental changes."""
        cursor = self.session.cursor()
        try:
            yield cursor
        finally:
            cursor.close()
            self.session.rollback()

    @contextlib.contextmanager
    def readwrite_cursor(self):
        """Yield a cursor that commits on clean exit; an exception raised in
        the body rolls back, propagates, and skips the commit below."""
        cursor = self.session.cursor()
        try:
            yield cursor
        except Exception:
            self.session.rollback()
            raise
        finally:
            cursor.close()
        self.session.commit()


DB = None
if os.name == 'nt':
warnings.warn('dft needs FUSE which is not available for windows')
else:
_init_db()
# eof
DB = _DB()
74 changes: 51 additions & 23 deletions nibabel/tests/test_dft.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from os.path import join as pjoin, dirname
from io import BytesIO
from ..testing import suppress_warnings
import sqlite3

with suppress_warnings():
from .. import dft
Expand All @@ -29,26 +30,57 @@ def setUpModule():
raise unittest.SkipTest('Need pydicom for dft tests, skipping')


def test_init():
class Test_DBclass:
    """Direct tests of the database manager that the public API does not reach."""

    def setup_method(self):
        # Fresh in-memory database per test; keep logging quiet.
        self._db = dft._DB(fname=":memory:", verbose=False)

    def test_repr(self):
        expected = "<DFT ':memory:'>"
        assert repr(self._db) == expected

    def test_cursor_conflict(self):
        query = "INSERT INTO directory (path, mtime) VALUES (?, ?)"
        params = ("/tmp", 0)
        # Whichever exits first will commit and make the second violate uniqueness
        with pytest.raises(sqlite3.IntegrityError):
            with self._db.readwrite_cursor() as outer, self._db.readwrite_cursor() as inner:
                outer.execute(query, params)
                inner.execute(query, params)


@pytest.fixture
def db(monkeypatch):
    """Provide an in-memory dft database, patched in as the module-level DB.

    Keeps tests from racing across processes and from touching the host
    filesystem.
    """
    in_memory = dft._DB(fname=":memory:")
    monkeypatch.setattr(dft, "DB", in_memory)
    yield in_memory


def test_init(db):
    """Smoke-test cache creation and a repeated refresh."""
    dft.clear_cache()
    # Update twice: the second pass must tolerate an already-current cache.
    for _ in range(2):
        dft.update_cache(data_dir)


def test_study():
studies = dft.get_studies(data_dir)
assert len(studies) == 1
assert (studies[0].uid ==
'1.3.12.2.1107.5.2.32.35119.30000010011408520750000000022')
assert studies[0].date == '20100114'
assert studies[0].time == '121314.000000'
assert studies[0].comments == 'dft study comments'
assert studies[0].patient_name == 'dft patient name'
assert studies[0].patient_id == '1234'
assert studies[0].patient_birth_date == '19800102'
assert studies[0].patient_sex == 'F'


def test_series():
def test_study(db):
    """Check study metadata, both when scanning and when reading the cache."""
    expected = {
        'uid': '1.3.12.2.1107.5.2.32.35119.30000010011408520750000000022',
        'date': '20100114',
        'time': '121314.000000',
        'comments': 'dft study comments',
        'patient_name': 'dft patient name',
        'patient_id': '1234',
        'patient_birth_date': '19800102',
        'patient_sex': 'F',
    }
    # First pass updates the cache, second pass reads it out
    for base_dir in (data_dir, None):
        studies = dft.get_studies(base_dir)
        assert len(studies) == 1
        for attr, value in expected.items():
            assert getattr(studies[0], attr) == value


def test_series(db):
studies = dft.get_studies(data_dir)
assert len(studies[0].series) == 1
ser = studies[0].series[0]
Expand All @@ -62,7 +94,7 @@ def test_series():
assert ser.bits_stored == 12


def test_storage_instances():
def test_storage_instances(db):
studies = dft.get_studies(data_dir)
sis = studies[0].series[0].storage_instances
assert len(sis) == 2
Expand All @@ -74,19 +106,15 @@ def test_storage_instances():
'1.3.12.2.1107.5.2.32.35119.2010011420300180088599504.1')


def test_storage_instance():
    # Placeholder: StorageInstance has no dedicated assertions yet; its
    # behavior is exercised indirectly through test_storage_instances.
    pass


@unittest.skipUnless(have_pil, 'could not import PIL.Image')
def test_png(db):
    """Render the first series as PNG and verify the image dimensions."""
    series = dft.get_studies(data_dir)[0].series[0]
    png_bytes = series.as_png()
    image = PImage.open(BytesIO(png_bytes))
    assert image.size == (256, 256)


def test_nifti():
def test_nifti(db):
studies = dft.get_studies(data_dir)
data = studies[0].series[0].as_nifti()
assert len(data) == 352 + 2 * 256 * 256 * 2
Expand Down