Skip to content

Commit 1fda584

Browse files
authored
Merge pull request #9 from danlooo/add-nesting-keys
Add nesting keys
2 parents 855c512 + d5c3189 commit 1fda584

File tree

4 files changed

+73
-15
lines changed

4 files changed

+73
-15
lines changed

pathtraits/access.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,32 @@
1111
logger = logging.getLogger(__name__)
1212

1313

14+
def nest_dict(flat_dict, delimiter="/"):
15+
"""
16+
Transforms a flat dictionary with path-like keys into a nested dictionary.
17+
18+
:param flat_dict: The flat dictionary with path-like keys.
19+
:param delimiter: The delimiter used in the keys (default is '/').
20+
:return: A nested dictionary.
21+
"""
22+
nested_dict = {}
23+
24+
for path, value in flat_dict.items():
25+
keys = path.split(delimiter)
26+
current = nested_dict
27+
28+
for key in keys[:-1]:
29+
# If the key doesn't exist or is not a dictionary, create/overwrite it as a dictionary
30+
if key not in current or not isinstance(current[key], dict):
31+
current[key] = {}
32+
current = current[key]
33+
34+
# Set the value at the final key
35+
current[keys[-1]] = value
36+
37+
return nested_dict
38+
39+
1440
def get_dict(self, path):
1541
"""
1642
Get traits for a path as a Python dictionary
@@ -40,6 +66,7 @@ def get_dict(self, path):
4066
if not (v and k != "path"):
4167
continue
4268
res[k] = v
69+
res = nest_dict(res)
4370
return res
4471

4572

pathtraits/db.py

Lines changed: 39 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import logging
66
import sqlite3
77
import os
8+
from collections.abc import MutableMapping
89
import yaml
910
from pathtraits.pathpair import PathPair
1011

@@ -33,12 +34,12 @@ def row_factory(cursor, row):
3334
if v is None:
3435
continue
3536
# sqlite don't know bool
36-
if k.endswith("_BOOL"):
37+
if k.endswith("/BOOL"):
3738
v = v > 0
3839
if isinstance(v, float):
3940
v_int = int(v)
4041
v = v_int if v_int == v else v
41-
k = k.removesuffix("_TEXT").removesuffix("_REAL").removesuffix("_BOOL")
42+
k = k.removesuffix("/TEXT").removesuffix("/REAL").removesuffix("/BOOL")
4243
res[k] = v
4344
return res
4445

@@ -53,15 +54,35 @@ def merge_rows(rows: list):
5354
for row in rows:
5455
for k, v in row.items():
5556
# pylint: disable=C0201
56-
if k in res.keys() and v not in res[k]:
57+
if not k in res.keys():
58+
res[k] = []
59+
if not v in res[k]:
5760
res[k].append(v)
58-
else:
59-
res[k] = [v]
61+
6062
# simplify lists with just one element
6163
# ensure fixed order of list entries
6264
res = {k: sorted(v, key=str) if len(v) > 1 else v[0] for k, v in res.items()}
6365
return res
6466

67+
@staticmethod
68+
def flatten_dict(dictionary: dict, root_key: str = "", separator: str = "/"):
69+
"""
70+
Docstring for flatten_dict
71+
72+
:param d: Description
73+
:type d: dict
74+
"""
75+
items = []
76+
for key, value in dictionary.items():
77+
new_key = root_key + separator + key if root_key else key
78+
if isinstance(value, MutableMapping):
79+
items.extend(
80+
TraitsDB.flatten_dict(value, new_key, separator=separator).items()
81+
)
82+
else:
83+
items.append((new_key, value))
84+
return dict(items)
85+
6586
def __init__(self, db_path):
6687
db_path = os.path.join(db_path)
6788
self.cursor = sqlite3.connect(db_path, autocommit=True).cursor()
@@ -189,15 +210,15 @@ def put(self, table, condition=None, update=True, **kwargs):
189210
# update
190211
values = " , ".join([f"{k}={v}" for (k, v) in escaped_kwargs.items()])
191212
if condition:
192-
update_query = f"UPDATE {table} SET {values} WHERE {condition};"
213+
update_query = f"UPDATE [{table}] SET {values} WHERE {condition};"
193214
else:
194-
update_query = f"UPDATE {table} SET {values};"
215+
update_query = f"UPDATE [{table}] SET {values};"
195216
self.execute(update_query)
196217
else:
197218
# insert
198-
keys = " , ".join(escaped_kwargs.keys())
219+
keys = "[" + "], [".join(escaped_kwargs.keys()) + "]"
199220
values = " , ".join([str(x) for x in escaped_kwargs.values()])
200-
insert_query = f"INSERT INTO {table} ({keys}) VALUES ({values});"
221+
insert_query = f"INSERT INTO [{table}] ({keys}) VALUES ({values});"
201222
self.execute(insert_query)
202223

203224
def put_data_view(self):
@@ -209,15 +230,15 @@ def put_data_view(self):
209230
if self.traits:
210231
join_query = " ".join(
211232
[
212-
f"LEFT JOIN {x} ON {x}.path = path.id"
233+
f"LEFT JOIN [{x}] ON [{x}].path = path.id \n"
213234
for x in self.traits
214235
if x != "path"
215236
]
216237
)
217238

218239
create_view_query = f"""
219240
CREATE VIEW data AS
220-
SELECT path.path, {', '.join(self.traits)}
241+
SELECT path.path, [{'], ['.join(self.traits)}]
221242
FROM path
222243
{join_query};
223244
"""
@@ -263,9 +284,9 @@ def create_trait_table(self, trait_name, value_type):
263284
return
264285
sql_type = TraitsDB.sql_type(value_type)
265286
add_table_query = f"""
266-
CREATE TABLE {trait_name} (
287+
CREATE TABLE [{trait_name}] (
267288
path INTEGER,
268-
{trait_name} {sql_type},
289+
[{trait_name}] {sql_type},
269290
FOREIGN KEY(path) REFERENCES path(id)
270291
);
271292
"""
@@ -303,6 +324,8 @@ def add_pathpair(self, pair: PathPair):
303324
if not isinstance(traits, dict):
304325
return
305326

327+
traits = TraitsDB.flatten_dict(traits)
328+
306329
# put path in db only if there are traits
307330
path_id = self.put_path_id(os.path.abspath(pair.object_path))
308331
for k, v in traits.items():
@@ -312,13 +335,15 @@ def add_pathpair(self, pair: PathPair):
312335
# get element type for list
313336
# add: handle lists with mixed element type
314337
t = type(v[0]) if isinstance(v, list) else type(v)
315-
k = f"{k}_{TraitsDB.sql_type(t)}"
338+
k = f"{k}/{TraitsDB.sql_type(t)}"
316339
if k not in self.traits:
317340
t = type(v[0]) if isinstance(v, list) else type(v)
318341
self.create_trait_table(k, t)
319342
if k in self.traits:
343+
# add to list
320344
if isinstance(v, list):
321345
for vv in v:
322346
self.put_trait(path_id, k, vv, update=False)
347+
# overwrite scalar
323348
else:
324349
self.put_trait(path_id, k, v)

test/example/EU/meta.yml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,8 @@ users:
33
- dloos
44
- fgans
55
score: 3.5
6+
foo:
7+
bar:
8+
a: 1
9+
b: 2
10+
c: [1, 2, 3]

test/test.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ def test_eu(self):
4747
"is_example": True,
4848
"score": 3.5,
4949
"users": ["dloos", "fgans"],
50+
"foo": {"bar": {"a": 1, "b": 2, "c": [1, 2, 3]}},
5051
}
5152
for k, v in target.items():
5253
self.assertEqual(source[k], v)
@@ -63,7 +64,7 @@ def test_example(self):
6364

6465
def test_data_view(self):
6566
source = len(self.db.execute("SELECT * FROM data;").fetchall())
66-
target = 6
67+
target = 10
6768
self.assertEqual(source, target)
6869

6970

0 commit comments

Comments
 (0)