Skip to content

Commit 8809690

Browse files
authored
Merge pull request #8 from danlooo/add-list-traits
Add list traits
2 parents e3538f7 + e7bac5c commit 8809690

File tree

2 files changed

+71
-14
lines changed

2 files changed

+71
-14
lines changed

pathtraits/db.py

Lines changed: 60 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,42 @@ class TraitsDB:
1919
cursor = None
2020
traits = []
2121

22+
@staticmethod
23+
def row_factory(cursor, row):
24+
"""
25+
Turns sqlite3 row into a dict. Only works on a single row at once.
26+
27+
:param cursor: Description
28+
:param row: Description
29+
"""
30+
fields = [column[0] for column in cursor.description]
31+
res = dict(zip(fields, row))
32+
return res
33+
34+
@staticmethod
35+
def merge_rows(rows: list):
36+
"""
37+
Merges a list of row dicts of a path into a sinle dict by pooling trait keys
38+
39+
:param res: Description
40+
"""
41+
res = {}
42+
for row in rows:
43+
for k, v in row.items():
44+
# pylint: disable=C0201
45+
if k in res.keys() and v not in res[k]:
46+
res[k].append(v)
47+
else:
48+
res[k] = [v]
49+
# simplify lists with just one element
50+
# ensure fixed order of list entries
51+
res = {k: sorted(v) if len(v) > 1 else v[0] for k, v in res.items()}
52+
return res
53+
2254
def __init__(self, db_path):
2355
db_path = os.path.join(db_path)
2456
self.cursor = sqlite3.connect(db_path, autocommit=True).cursor()
57+
self.cursor.row_factory = TraitsDB.row_factory
2558

2659
init_path_table_query = """
2760
CREATE TABLE IF NOT EXISTS path (
@@ -67,15 +100,19 @@ def get(self, table, cols="*", condition=None, **kwargs):
67100
for (k, v) in kwargs.items()
68101
}
69102
condition = " AND ".join([f"{k}={v}" for (k, v) in escaped_kwargs.items()])
70-
get_row_query = f"SELECT {cols} FROM {table} WHERE {condition} LIMIT 1;"
103+
get_row_query = f"SELECT {cols} FROM {table} WHERE {condition};"
71104
response = self.execute(get_row_query)
72-
values = response.fetchone()
73105

74-
if values is None:
106+
if response is None:
75107
return None
76108

77-
keys = map(lambda x: x[0], response.description)
78-
res = dict(zip(keys, values))
109+
res = response.fetchall()
110+
if len(res) == 1:
111+
return res[0]
112+
113+
if isinstance(res, list) and len(res) > 1:
114+
res = TraitsDB.merge_rows(res)
115+
79116
return res
80117

81118
def put_path_id(self, path):
@@ -130,13 +167,14 @@ def sql_type(value_type):
130167
sql_type = sqlite_types.get(value_type, "TEXT")
131168
return sql_type
132169

133-
def put(self, table, condition=None, **kwargs):
170+
def put(self, table, condition=None, update=True, **kwargs):
134171
"""
135172
Puts a row into a table. Creates a row if not present, updates otherwise.
173+
:param update; overwrite existing data
136174
"""
137175
escaped_kwargs = {k: TraitsDB.escape(v) for (k, v) in kwargs.items()}
138176

139-
if self.get(table, condition=condition, **kwargs):
177+
if update and self.get(table, condition=condition, **kwargs):
140178
# update
141179
values = " , ".join([f"{k}={v}" for (k, v) in escaped_kwargs.items()])
142180
if condition:
@@ -193,7 +231,7 @@ def update_traits(self):
193231
ORDER BY name;
194232
"""
195233
traits = self.execute(get_traits_query).fetchall()
196-
self.traits = [x[0] for x in traits]
234+
self.traits = [list(x.values())[0] for x in traits]
197235
self.put_data_view()
198236

199237
def create_trait_table(self, trait_name, value_type):
@@ -223,7 +261,7 @@ def create_trait_table(self, trait_name, value_type):
223261
self.execute(add_table_query)
224262
self.update_traits()
225263

226-
def put_trait(self, path_id, trait_name, value):
264+
def put_trait(self, path_id, trait_name, value, update=True):
227265
"""
228266
Put a trait to the database
229267
@@ -233,7 +271,7 @@ def put_trait(self, path_id, trait_name, value):
233271
:param value: trait value
234272
"""
235273
kwargs = {"path": path_id, trait_name: value}
236-
self.put(trait_name, condition=f"path = {path_id}", **kwargs)
274+
self.put(trait_name, condition=f"path = {path_id}", update=update, **kwargs)
237275

238276
def add_pathpair(self, pair: PathPair):
239277
"""
@@ -259,8 +297,17 @@ def add_pathpair(self, pair: PathPair):
259297
for k, v in traits.items():
260298
# same YAML key might have different value types
261299
# Therefore, add type to key
262-
k = f"{k}_{TraitsDB.sql_type(type(v))}"
300+
301+
# get element type for list
302+
# add: handle lists with mixed element type
303+
t = type(v[0]) if isinstance(v, list) else type(v)
304+
k = f"{k}_{TraitsDB.sql_type(t)}"
263305
if k not in self.traits:
264-
self.create_trait_table(k, type(v))
306+
t = type(v[0]) if isinstance(v, list) else type(v)
307+
self.create_trait_table(k, t)
265308
if k in self.traits:
266-
self.put_trait(path_id, k, v)
309+
if isinstance(v, list):
310+
for vv in v:
311+
self.put_trait(path_id, k, vv, update=False)
312+
else:
313+
self.put_trait(path_id, k, v)

test/test.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,18 @@ def test_example(self):
3131
for k, v in target.items():
3232
self.assertEqual(source[k], v)
3333

34+
source = pathtraits.access.get_dict(db, "test/example/EU")
35+
target = {
36+
"description_TEXT": "EU data",
37+
"is_example_BOOL": True,
38+
"score_REAL": 3.5,
39+
"users_TEXT": ["dloos", "fgans"],
40+
}
41+
for k, v in target.items():
42+
self.assertEqual(source[k], v)
43+
3444
source = len(db.execute("SELECT * FROM data;").fetchall())
35-
target = 4
45+
target = 6
3646
self.assertEqual(source, target)
3747
os.remove(db_path)
3848

0 commit comments

Comments
 (0)