Skip to content

Commit 6783100

Browse files
committed
Add get list traits
1 parent a474f71 commit 6783100

File tree

2 files changed

+52
-7
lines changed

2 files changed

+52
-7
lines changed

pathtraits/db.py

Lines changed: 41 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,40 @@ class TraitsDB:
1919
cursor = None
2020
traits = []
2121

22+
@staticmethod
23+
def row_factory(cursor, row):
24+
"""
25+
Turns sqlite3 row into a dict. Only works on a single row at once.
26+
27+
:param cursor: Description
28+
:param row: Description
29+
"""
30+
fields = [column[0] for column in cursor.description]
31+
res = {key: value for key, value in zip(fields, row)}
32+
return res
33+
34+
@staticmethod
35+
def merge_rows(rows: list):
36+
"""
37+
Merges a list of row dicts of a path into a sinle dict by pooling trait keys
38+
39+
:param res: Description
40+
"""
41+
res = {}
42+
for row in rows:
43+
for k, v in row.items():
44+
if k in res.keys() and v not in res[k]:
45+
res[k].append(v)
46+
else:
47+
res[k] = [v]
48+
# simplify lists with just one element
49+
res = {k: v if len(v) > 1 else v[0] for k, v in res.items()}
50+
return res
51+
2252
def __init__(self, db_path):
2353
db_path = os.path.join(db_path)
2454
self.cursor = sqlite3.connect(db_path, autocommit=True).cursor()
55+
self.cursor.row_factory = TraitsDB.row_factory
2556

2657
init_path_table_query = """
2758
CREATE TABLE IF NOT EXISTS path (
@@ -67,15 +98,19 @@ def get(self, table, cols="*", condition=None, **kwargs):
6798
for (k, v) in kwargs.items()
6899
}
69100
condition = " AND ".join([f"{k}={v}" for (k, v) in escaped_kwargs.items()])
70-
get_row_query = f"SELECT {cols} FROM {table} WHERE {condition} LIMIT 1;"
101+
get_row_query = f"SELECT {cols} FROM {table} WHERE {condition};"
71102
response = self.execute(get_row_query)
72-
values = response.fetchone()
73103

74-
if values is None:
104+
if response is None:
75105
return None
76106

77-
keys = map(lambda x: x[0], response.description)
78-
res = dict(zip(keys, values))
107+
res = response.fetchall()
108+
if len(res) == 1:
109+
return res[0]
110+
111+
if isinstance(res, list) and len(res) > 1:
112+
res = TraitsDB.merge_rows(res)
113+
79114
return res
80115

81116
def put_path_id(self, path):
@@ -194,7 +229,7 @@ def update_traits(self):
194229
ORDER BY name;
195230
"""
196231
traits = self.execute(get_traits_query).fetchall()
197-
self.traits = [x[0] for x in traits]
232+
self.traits = [list(x.values())[0] for x in traits]
198233
self.put_data_view()
199234

200235
def create_trait_table(self, trait_name, value_type):

test/test.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,18 @@ def test_example(self):
3131
for k, v in target.items():
3232
self.assertEqual(source[k], v)
3333

34+
source = pathtraits.access.get_dict(db, "test/example/EU")
35+
target = {
36+
"description_TEXT": "EU data",
37+
"is_example_BOOL": True,
38+
"score_REAL": 3.5,
39+
"users_TEXT": ["dloos", "fgans"],
40+
}
41+
for k, v in target.items():
42+
self.assertEqual(source[k], v)
43+
3444
source = len(db.execute("SELECT * FROM data;").fetchall())
35-
target = 4
45+
target = 6
3646
self.assertEqual(source, target)
3747
os.remove(db_path)
3848

0 commit comments

Comments
 (0)