Skip to content

Commit c94dc94

Browse files
Add DB API using SQLAlchemy
Change-Id: I3a5b0bc8888ffcb4faab01b1f4d80c032a70f54c
1 parent 2a3e1f0 commit c94dc94

File tree

6 files changed

+233
-77
lines changed

6 files changed

+233
-77
lines changed

bluepyparallel/database.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
"""Module"""
2+
import re
3+
4+
import pandas as pd
5+
from sqlalchemy import MetaData
6+
from sqlalchemy import Table
7+
from sqlalchemy import create_engine
8+
from sqlalchemy import insert
9+
from sqlalchemy import schema
10+
from sqlalchemy import select
11+
from sqlalchemy.engine.reflection import Inspector
12+
from sqlalchemy_utils import create_database
13+
from sqlalchemy_utils import database_exists
14+
15+
16+
class DataBase:
17+
"""A database API using SQLAlchemy."""
18+
19+
index_col = "df_index"
20+
_url_pattern = r"[a-zA-Z0-9_\-\+]+://.*"
21+
22+
def __init__(self, url, *args, create=False, **kwargs):
23+
if not re.match(self._url_pattern, str(url)):
24+
url = "sqlite:///" + str(url)
25+
26+
self.engine = create_engine(url, *args, **kwargs)
27+
28+
if create and not database_exists(self.engine.url):
29+
create_database(self.engine.url)
30+
31+
self.connection = self.engine.connect()
32+
self.metadata = None
33+
self.table = None
34+
35+
def get_url(self):
36+
return self.engine.url
37+
38+
def create(self, df, table_name=None, schema_name=None):
39+
if table_name is None:
40+
table_name = "df"
41+
if schema_name is not None and schema_name not in self.connection.dialect.get_schema_names(
42+
self.connection
43+
):
44+
self.connection.execute(schema.CreateSchema(schema_name))
45+
new_df = df.loc[[]]
46+
new_df.to_sql(
47+
name=table_name,
48+
con=self.connection,
49+
schema=schema_name,
50+
if_exists="replace",
51+
index_label=self.index_col,
52+
)
53+
self.reflect(table_name, schema_name)
54+
55+
def exists(self, table_name, schema_name=None):
56+
inspector = Inspector.from_engine(self.engine)
57+
return table_name in inspector.get_table_names(schema=schema_name)
58+
59+
def reflect(self, table_name, schema_name=None):
60+
self.metadata = MetaData()
61+
self.table = Table(
62+
table_name,
63+
self.metadata,
64+
schema=schema_name,
65+
autoload=True,
66+
autoload_with=self.engine,
67+
)
68+
69+
def load(self):
70+
query = select([self.table])
71+
return pd.read_sql(query, self.connection, index_col=self.index_col)
72+
73+
def write(self, row_id, result=None, exception=None, **input_values):
74+
if result is not None:
75+
vals = result
76+
elif exception is not None:
77+
vals = {"exception": exception}
78+
else:
79+
return
80+
81+
query = insert(self.table).values(dict(**{self.index_col: row_id}, **vals, **input_values))
82+
self.connection.execute(query)

bluepyparallel/evaluator.py

Lines changed: 28 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,18 @@
11
"""Module to evaluate generic functions on rows of dataframe."""
22
import logging
3-
import sqlite3
43
import sys
54
import traceback
65
from functools import partial
7-
from pathlib import Path
86

9-
import pandas as pd
107
from tqdm import tqdm
118

9+
from bluepyparallel.database import DataBase
1210
from bluepyparallel.parallel import init_parallel_factory
1311

1412
logger = logging.getLogger(__name__)
1513

1614

17-
def _try_evaluation(task, evaluation_function, db_filename, func_args, func_kwargs):
15+
def _try_evaluation(task, evaluation_function, func_args, func_kwargs):
1816
"""Encapsulate the evaluation function into a try/except and isolate to record exceptions."""
1917
task_id, task_args = task
2018
try:
@@ -24,47 +22,16 @@ def _try_evaluation(task, evaluation_function, db_filename, func_args, func_kwar
2422
result = None
2523
exception = "".join(traceback.format_exception(*sys.exc_info()))
2624
logger.exception("Exception for ID=%s: %s", task_id, exception)
27-
28-
# Save the results into the DB
29-
if db_filename is not None:
30-
_write_to_sql(db_filename, task_id, result, exception)
3125
return task_id, result, exception
3226

3327

34-
def _create_database(df, db_filename="db.sql"):
35-
"""Create a sqlite database from dataframe."""
36-
with sqlite3.connect(str(db_filename)) as db:
37-
df.to_sql("df", db, if_exists="replace", index_label="df_index")
38-
39-
40-
def _load_database_to_dataframe(db_filename="db.sql"):
41-
"""Load an SQL database and construct the dataframe."""
42-
with sqlite3.connect(str(db_filename)) as db:
43-
return pd.read_sql("SELECT * FROM df", db, index_col="df_index")
44-
45-
46-
def _write_to_sql(db_filename, task_id, results, exception):
47-
"""Write row data to SQL."""
48-
with sqlite3.connect(str(db_filename)) as db:
49-
if results is not None:
50-
keys, vals = zip(*results.items())
51-
query_keys = ", ".join([f"{k}=?" for k in keys])
52-
else:
53-
query_keys = "exception=?"
54-
vals = [exception]
55-
db.execute(
56-
"UPDATE df SET " + query_keys + " WHERE df_index=?",
57-
list(vals) + [task_id],
58-
)
59-
60-
6128
def evaluate(
6229
df,
6330
evaluation_function,
6431
new_columns=None,
6532
resume=False,
6633
parallel_factory=None,
67-
db_filename=None,
34+
db_url=None,
6835
func_args=None,
6936
func_kwargs=None,
7037
):
@@ -80,9 +47,11 @@ def evaluate(
8047
resume (bool): if True, it will use only compute the empty rows of the database,
8148
if False, it will ecrase or generate the database.
8249
parallel_factory (ParallelFactory): parallel factory instance.
83-
db_filename (str): if a file path is given, SQL backend will be enabled and will use this
84-
path for the SQLite database. Should not be used when evaluations are numerous and
85-
fast, in order to avoid the overhead of communication with SQL database.
50+
db_url (str): should be DB URL that can be interpreted by SQLAlchemy or can be a file path
51+
that is interpreted as a SQLite database. If an URL is given, the SQL backend will be
52+
enabled to store results and allowing future resume. Should not be used when
53+
evaluations are numerous and fast, in order to avoid the overhead of communication with
54+
SQL database.
8655
func_args (list): the arguments to pass to the evaluation_function.
8756
func_kwargs (dict): the keyword arguments to pass to the evaluation_function.
8857
@@ -115,12 +84,16 @@ def evaluate(
11584
to_evaluate[new_column[0]] = new_column[1]
11685

11786
# Create the database if required and get the task ids to run
118-
if db_filename is None:
87+
if db_url is None:
11988
logger.info("Not using SQL backend to save iterations")
120-
elif resume:
121-
logger.info("Load data from SQL database")
122-
if Path(db_filename).exists():
123-
previous_results = _load_database_to_dataframe(db_filename=db_filename)
89+
db = None
90+
else:
91+
db = DataBase(db_url)
92+
93+
if resume and db.exists("df"):
94+
logger.info("Load data from SQL database")
95+
db.reflect("df")
96+
previous_results = db.load()
12497
previous_idx = previous_results.index
12598
bad_cols = [
12699
col
@@ -134,10 +107,10 @@ def evaluate(
134107
to_evaluate.loc[previous_results.index] = previous_results.loc[previous_results.index]
135108
task_ids = task_ids.difference(previous_results.index)
136109
else:
137-
_create_database(to_evaluate, db_filename=db_filename)
138-
else:
139-
logger.info("Create SQL database")
140-
_create_database(to_evaluate, db_filename=db_filename)
110+
logger.info("Create SQL database")
111+
db.create(to_evaluate)
112+
113+
db_url = db.get_url()
141114

142115
# Log the number of tasks to run
143116
if len(task_ids) > 0:
@@ -153,16 +126,21 @@ def evaluate(
153126
eval_func = partial(
154127
_try_evaluation,
155128
evaluation_function=evaluation_function,
156-
db_filename=db_filename,
157129
func_args=func_args,
158130
func_kwargs=func_kwargs,
159131
)
160132

161133
# Split the data into rows
162-
arg_list = list(to_evaluate.loc[task_ids].to_dict("index").items())
134+
arg_list = list(to_evaluate.loc[task_ids, df.columns].to_dict("index").items())
163135

164136
try:
165137
for task_id, results, exception in tqdm(mapper(eval_func, arg_list), total=len(task_ids)):
138+
# Save the results into the DB
139+
if db is not None:
140+
db.write(
141+
task_id, results, exception, **to_evaluate.loc[task_id, df.columns].to_dict()
142+
)
143+
166144
# Save the results into the DataFrame
167145
if results is not None:
168146
to_evaluate.loc[task_id, results.keys()] = list(results.values())

setup.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,16 @@
1212
with open("README.rst", encoding="utf-8") as f:
1313
README = f.read()
1414

15+
reqs = [
16+
"pandas",
17+
"ipyparallel",
18+
"dask[distributed]>=2.30",
19+
"dask-mpi>=2.20",
20+
"sqlalchemy<1.4",
21+
"sqlalchemy-utils",
22+
"tqdm",
23+
]
24+
1525
doc_reqs = [
1626
"sphinx-bluebrain-theme",
1727
]
@@ -32,13 +42,7 @@
3242
"Source": "ssh://bbpcode.epfl.ch/cells/BluePyParallel",
3343
},
3444
license="BBP-internal-confidential",
35-
install_requires=[
36-
"pandas",
37-
"ipyparallel",
38-
"dask[distributed]>=2.30",
39-
"dask-mpi>=2.20",
40-
"tqdm",
41-
],
45+
install_requires=reqs,
4246
extras_require={
4347
"docs": doc_reqs,
4448
},

tests/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,5 @@
33

44

55
@pytest.fixture
6-
def db_filename(tmpdir):
6+
def db_url(tmpdir):
77
return tmpdir / "db.sql"

tests/test_database.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
"""Test the bluepyparallel.evaluator module"""
2+
# pylint: disable=redefined-outer-name
3+
import pandas as pd
4+
import pytest
5+
from sqlalchemy import MetaData
6+
from sqlalchemy import Table
7+
from sqlalchemy import create_engine
8+
from sqlalchemy import select
9+
10+
from bluepyparallel import database
11+
12+
URLS = ["/tmpdir/test.db", "sqlite:////tmpdir/test.db"]
13+
14+
15+
@pytest.fixture(params=URLS)
16+
def url(request, tmpdir):
17+
return request.param.replace("/tmpdir", str(tmpdir))
18+
19+
20+
@pytest.fixture
21+
def small_df():
22+
data = {"a": list(range(6)), "b": [str(i * 10) for i in range(6)], "exception": [None] * 6}
23+
idx = [f"idx_{(i + 1) * 2}" for i in range(6)]
24+
return pd.DataFrame(data, index=idx)
25+
26+
27+
@pytest.fixture()
28+
def small_db(url, small_df):
29+
db = database.DataBase(url)
30+
db.create(small_df)
31+
small_df.to_sql(
32+
name=db.table.name,
33+
con=db.connection,
34+
schema=db.table.schema,
35+
if_exists="replace",
36+
index_label=db.index_col,
37+
)
38+
return db
39+
40+
41+
class TestDataBase:
42+
"""Test the DataBase class."""
43+
44+
@pytest.mark.parametrize("table_name", [None, "df", "df_name"])
45+
@pytest.mark.parametrize("schema_name", [None])
46+
def test_create(self, url, small_df, table_name, schema_name):
47+
db = database.DataBase(url)
48+
db.create(small_df, table_name, schema_name)
49+
50+
# Check DB
51+
if url.startswith("/"):
52+
url = "sqlite:///" + url
53+
engine = create_engine(url)
54+
conn = engine.connect()
55+
metadata = MetaData()
56+
table = Table(
57+
table_name or "df",
58+
metadata,
59+
schema=schema_name,
60+
autoload=True,
61+
autoload_with=engine,
62+
)
63+
64+
# Check reflected table
65+
assert str(table.c.items()) == str(db.table.c.items())
66+
67+
# Check elements inserted into the DB
68+
query = select([table])
69+
res = conn.execute(query).fetchall()
70+
assert res == []
71+
72+
def test_exists(self, small_db):
73+
assert small_db.exists("df")
74+
assert not small_db.exists("UNKNOWN TABLE")
75+
76+
def test_load(self, small_df, small_db):
77+
res = small_db.load()
78+
79+
# Check DB
80+
assert res.equals(small_df)
81+
82+
def test_write(self, small_df, small_db):
83+
small_db.write("idx_100", result={"a": 1, "b": "test_1"})
84+
small_db.write("idx_101", exception="test exception")
85+
small_db.write("idx_102") # Should write nothing
86+
87+
# Check DB after write
88+
res = small_db.load()
89+
small_df.loc["idx_100", ["a", "b", "exception"]] = [1, "test_1", None]
90+
small_df.loc["idx_101", ["a", "b", "exception"]] = [None, None, "test exception"]
91+
assert res.equals(small_df)
92+
93+
def test_get_url(self, url, small_db):
94+
if url.startswith("/"):
95+
url = "sqlite:///" + url
96+
assert str(small_db.get_url()) == url

0 commit comments

Comments
 (0)