diff --git a/oil_library/__init__.py b/oil_library/__init__.py index 695797e..ba472d1 100644 --- a/oil_library/__init__.py +++ b/oil_library/__init__.py @@ -5,7 +5,7 @@ import os import sys -import logging +import logging; logger = logging.getLogger(__name__) from pkg_resources import get_distribution @@ -21,11 +21,16 @@ # # currently, the DB is created and located when package is installed # -_oillib_path = os.path.dirname(__file__) -__module_folder__ = __file__.split(os.sep)[-2] -_db_file = 'OilLib.db' -_db_file_path = os.path.join(_oillib_path, _db_file) +def get_oil_db_path(): + """ + Get the path to the Sqlite3 oil database + """ + from pkg_resources import resource_filename, resource_exists + + if not resource_exists(__name__, 'OilLib.db'): + logger.warning('OilLib.db does not exist') + return resource_filename(__name__, 'OilLib.db') def _get_db_session(): 'we can call this from scripts to access valid DBSession' @@ -35,12 +40,11 @@ def _get_db_session(): try: eng = session.get_bind() - if eng.url.database.split(os.path.sep)[-2:] != [__module_folder__, - _db_file]: + if os.path.realpath(eng.url.database) != os.path.realpath(get_oil_db_path()): raise UnboundExecutionError except UnboundExecutionError: - session.bind = create_engine('sqlite:///' + _db_file_path) + session.bind = create_engine('sqlite:///' + get_oil_db_path()) return session diff --git a/oil_library/tests/__init__.py b/oil_library/tests/__init__.py new file mode 100644 index 0000000..e69de29