Skip to content

Commit 454c18a

Browse files
authored
fix: make unflatten_dict symmetric to flatten_dict (#461)
1 parent 624b6e7 commit 454c18a

File tree

4 files changed

+270
-90
lines changed

4 files changed

+270
-90
lines changed

packages/ragbits-core/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
## Unreleased
44

55
- Add router option to LiteLLMEmbedder (#440)
6+
- Fix: make unflatten_dict symmetric to flatten_dict (#461)
67

78
## 0.12.0 (2025-03-25)
89
- Allow Prompt class to accept the asynchronous response_parser. Change the signature of parse_response method.

packages/ragbits-core/src/ragbits/core/utils/dict_transformations.py

Lines changed: 196 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -43,21 +43,193 @@ def flatten_dict(input_dict: dict[str, Any], parent_key: str = "", sep: str = ".
4343
return items
4444

4545

46-
def unflatten_dict(input_dict: dict[str, Any]) -> dict[str, Any] | list:
46+
def _parse_key(key: str) -> list[tuple[str, bool]]:
47+
"""Parse a key into parts, each part being (name, is_array_index)."""
48+
parts = []
49+
current = ""
50+
i = 0
51+
while i < len(key):
52+
if key[i] == "[":
53+
if current:
54+
parts.append((current, False)) # Parent is not an array
55+
current = ""
56+
i += 1 # Skip [
57+
start = i
58+
while i < len(key) and key[i] != "]":
59+
i += 1
60+
parts.append((key[start:i], True))
61+
i += 1 # Skip ]
62+
if i < len(key) and key[i] == ".":
63+
i += 1 # Skip .
64+
elif key[i] == ".":
65+
if current:
66+
parts.append((current, False))
67+
current = ""
68+
i += 1
69+
else:
70+
current += key[i]
71+
i += 1
72+
if current:
73+
parts.append((current, False))
74+
return parts
75+
76+
77+
def _ensure_array(obj: dict[str, Any] | list[Any], key: str) -> list[Any]:
78+
"""Ensure that obj[key] is a list, creating it if necessary."""
79+
if isinstance(obj, list):
80+
return obj
81+
if key not in obj or not isinstance(obj[key], list):
82+
obj[key] = []
83+
return obj[key]
84+
85+
86+
def _ensure_dict(obj: dict[str, Any] | list[Any], key: str) -> dict[str, Any]:
87+
"""Ensure that obj[key] is a dict, creating it if necessary."""
88+
if isinstance(obj, list):
89+
# Lists should be handled by the caller
90+
raise TypeError("Cannot ensure dict in a list")
91+
if key not in obj or not isinstance(obj[key], dict):
92+
obj[key] = {}
93+
return obj[key]
94+
95+
96+
DictOrList = dict[str, Any] | list[Any]
97+
98+
99+
def _handle_array_part(
100+
current: DictOrList,
101+
part: str,
102+
parent_key: str | None = None,
103+
) -> DictOrList:
104+
"""Handle an array part in the key."""
105+
idx = int(part)
106+
if isinstance(current, list):
107+
while len(current) <= idx:
108+
current.append({})
109+
return current[idx]
110+
if parent_key is None:
111+
raise ValueError(f"Array part '{part}' without parent key")
112+
current_list = _ensure_array(current, parent_key)
113+
while len(current_list) <= idx:
114+
current_list.append({})
115+
return current_list[idx]
116+
117+
118+
def _handle_dict_part(
119+
current: DictOrList,
120+
part: str,
121+
next_is_array: bool,
122+
array_idx: int | None = None,
123+
) -> DictOrList:
124+
"""Handle a dictionary part in the key."""
125+
if isinstance(current, list):
126+
if array_idx is None:
127+
raise ValueError("Array index is required when current is a list")
128+
while len(current) <= array_idx:
129+
current.append({})
130+
current = current[array_idx]
131+
if not isinstance(current, dict):
132+
current = {}
133+
current[str(array_idx)] = current
134+
if next_is_array:
135+
return _ensure_array(current, part)
136+
return _ensure_dict(current, part)
137+
138+
139+
def _handle_single_part(
140+
new_dict: dict[str, Any],
141+
first_part: str,
142+
is_array: bool,
143+
value: SimpleTypes,
144+
) -> None:
145+
"""Handle a single-part key."""
146+
if is_array:
147+
idx = int(first_part)
148+
current = _ensure_array(new_dict, first_part)
149+
while len(current) <= idx:
150+
current.append(None)
151+
current[idx] = value
152+
else:
153+
new_dict[first_part] = value
154+
155+
156+
def _handle_last_array_part(
157+
current_obj: DictOrList,
158+
last_part: str,
159+
value: SimpleTypes,
160+
parts: list[tuple[str, bool]],
161+
) -> None:
162+
"""Handle the last part of the key when it's an array index."""
163+
idx = int(last_part)
164+
if len(parts) == 1:
165+
# Direct array access like "users[0]"
166+
parent_key = parts[0][0]
167+
current_obj = _ensure_array(current_obj, parent_key)
168+
if isinstance(current_obj, list):
169+
while len(current_obj) <= idx:
170+
current_obj.append(None)
171+
current_obj[idx] = value
172+
else:
173+
raise TypeError("Expected list but got dict")
174+
175+
176+
def _handle_last_dict_part(
177+
current_obj: DictOrList,
178+
last_part: str,
179+
value: SimpleTypes,
180+
parts: list[tuple[str, bool]],
181+
) -> None:
182+
"""Handle the last part of the key when it's a dictionary key."""
183+
if isinstance(current_obj, list):
184+
# We're in a list, so we need to ensure the current index has a dict
185+
idx = int(parts[-2][0]) # Get the index from the previous part
186+
while len(current_obj) <= idx:
187+
current_obj.append({})
188+
current_obj = current_obj[idx]
189+
if not isinstance(current_obj, dict):
190+
current_obj = {}
191+
current_obj[str(idx)] = current_obj
192+
if isinstance(current_obj, dict):
193+
current_obj[last_part] = value
194+
else:
195+
raise TypeError("Expected dict but got list")
196+
197+
198+
def _set_value(current: dict[str, Any], parts: list[tuple[str, bool]], value: SimpleTypes) -> None:
199+
"""Set a value in the dictionary based on the parsed key parts."""
200+
current_obj: DictOrList = current
201+
202+
# Handle all parts except the last one
203+
for i, (part, is_array) in enumerate(parts[:-1]):
204+
if is_array:
205+
current_obj = _handle_array_part(current_obj, part, parts[i - 1][0] if i > 0 else None)
206+
else:
207+
next_is_array = i + 1 < len(parts) and parts[i + 1][1]
208+
array_idx = int(parts[i][0]) if isinstance(current_obj, list) else None
209+
current_obj = _handle_dict_part(current_obj, part, next_is_array, array_idx)
210+
211+
# Handle the last part
212+
last_part, is_array = parts[-1]
213+
if is_array:
214+
_handle_last_array_part(current_obj, last_part, value, parts)
215+
else:
216+
_handle_last_dict_part(current_obj, last_part, value, parts)
217+
218+
219+
def unflatten_dict(input_dict: dict[str, Any]) -> dict[str, Any]:
47220
"""
48221
Converts a flattened dictionary with dot notation and array notation into a nested structure.
49222
50223
This function transforms a dictionary with flattened keys (using dot notation for nested objects
51-
and bracket notation for arrays) into a nested dictionary or list structure. It handles both
52-
object-like nesting (using dots) and array-like nesting (using brackets).
224+
and bracket notation for arrays) into a nested dictionary structure. It uses the notation to determine
225+
whether a value should be a dictionary or list.
53226
54227
Args:
55-
input_dict (dict[Any, Any]): A dictionary with flattened keys. Keys can use dot notation
228+
input_dict (dict[str, Any]): A dictionary with flattened keys. Keys can use dot notation
56229
(e.g., "person.name") or array notation (e.g., "addresses[0].street").
57230
58231
Returns:
59-
Union[dict[Any, Any], list]: A nested dictionary or list structure. Returns a list if all
60-
top-level keys are consecutive integer strings starting from 0.
232+
dict[str, Any]: A nested dictionary structure. Lists are created only when using array notation.
61233
62234
Examples:
63235
>>> unflatten_dict({"person.name": "John", "person.age": 30})
@@ -67,82 +239,29 @@ def unflatten_dict(input_dict: dict[str, Any]) -> dict[str, Any] | list:
67239
{'addresses': [{'street': 'Main St'}, {'street': 'Broadway'}]}
68240
69241
>>> unflatten_dict({"0": "first", "1": "second"})
70-
['first', 'second']
71-
72-
Notes:
73-
- The function recursively processes nested structures
74-
- If all keys at any level are consecutive integers starting from 0, that level will be
75-
converted to a list
76-
- The function preserves the original values for non-nested keys
77-
- Keys are sorted before processing to ensure consistent results
78-
79-
Attribution:
80-
- This function is based on the answer by user "djtubig-malicex" on Stack Overflow: https://stackoverflow.com/a/67905359/27947364
242+
{'0': 'first', '1': 'second'}
81243
"""
82244
if not input_dict:
83245
return {}
84246

85-
new_dict: dict[Any, Any] = {}
86-
field_keys = sorted(input_dict.keys())
87-
88-
def _decompose_key(key: str) -> tuple[str | int | None, str | int | None]:
89-
_key = str(key)
90-
_current_key: str | int | None = None
91-
_current_subkey: str | int | None = None
92-
93-
for idx, char in enumerate(_key):
94-
if char == "[":
95-
_current_key = _key[:idx]
96-
start_subscript_index = idx + 1
97-
end_subscript_index = _key.index("]")
98-
_current_subkey = int(_key[start_subscript_index:end_subscript_index])
99-
100-
if len(_key[end_subscript_index:]) > 1:
101-
_current_subkey = f"{_current_subkey}.{_key[end_subscript_index + 2 :]}"
102-
break
103-
elif char == ".":
104-
split_work = _key.split(".", 1)
105-
if len(split_work) > 1:
106-
_current_key, _current_subkey = split_work
107-
else:
108-
_current_key = split_work[0]
109-
break
247+
new_dict: dict[str, Any] = {}
110248

111-
return _current_key, _current_subkey
112-
113-
for each_key in field_keys:
114-
field_value = input_dict[each_key]
115-
current_key, current_subkey = _decompose_key(each_key)
116-
117-
if current_key is not None and current_subkey is not None:
118-
if isinstance(current_key, str) and current_key.isdigit():
119-
current_key = int(current_key)
120-
if current_key not in new_dict:
121-
new_dict[current_key] = {}
122-
new_dict[current_key][current_subkey] = field_value
249+
# Sort keys to ensure we process parents before children
250+
field_keys = sorted(input_dict.keys())
251+
for key in field_keys:
252+
parts = _parse_key(key)
253+
if not parts:
254+
continue
255+
256+
# Handle the first part specially to ensure it's created in new_dict
257+
first_part, is_array = parts[0]
258+
if first_part not in new_dict:
259+
new_dict[first_part] = {} if not is_array else []
260+
261+
# Set the value
262+
if len(parts) == 1:
263+
_handle_single_part(new_dict, first_part, is_array, input_dict[key])
123264
else:
124-
new_dict[each_key] = field_value
125-
126-
all_digits = True
127-
highest_digit = -1
128-
129-
for each_key, each_item in new_dict.items():
130-
if isinstance(each_item, dict):
131-
new_dict[each_key] = unflatten_dict(each_item)
132-
133-
all_digits &= str(each_key).isdigit()
134-
if all_digits:
135-
next_digit = int(each_key)
136-
highest_digit = max(next_digit, highest_digit)
137-
138-
if all_digits and highest_digit == (len(new_dict) - 1):
139-
digit_keys = sorted(new_dict.keys(), key=int)
140-
new_list: list = [None] * (highest_digit + 1)
141-
142-
for k in digit_keys:
143-
i = int(k)
144-
new_list[i] = new_dict[k]
145-
146-
return new_list
265+
_set_value(new_dict, parts, input_dict[key])
147266

148267
return new_dict

packages/ragbits-core/src/ragbits/core/vector_stores/chroma.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -184,8 +184,8 @@ async def retrieve(
184184

185185
metadatas: Sequence = [dict(metadata) for batch in results.get("metadatas") or [] for metadata in batch]
186186

187-
# Remove the `# type: ignore` comment when https://github.com/deepsense-ai/ragbits/pull/379/files resolved
188-
unflattened_metadatas: list[dict] = [unflatten_dict(metadata) if metadata else {} for metadata in metadatas] # type: ignore[misc]
187+
# Convert metadata back to nested structure
188+
unflattened_metadatas: list[dict] = [unflatten_dict(metadata) if metadata else {} for metadata in metadatas]
189189

190190
images: list[bytes | None] = [metadata.pop("__image", None) for metadata in unflattened_metadatas]
191191

@@ -253,8 +253,8 @@ async def list(
253253
documents = results.get("documents") or []
254254
metadatas: Sequence = results.get("metadatas") or []
255255

256-
# Remove the `# type: ignore` comment when https://github.com/deepsense-ai/ragbits/pull/379/files resolved
257-
unflattened_metadatas: list[dict] = [unflatten_dict(metadata) if metadata else {} for metadata in metadatas] # type: ignore[misc]
256+
# Convert metadata back to nested structure
257+
unflattened_metadatas: list[dict] = [unflatten_dict(metadata) if metadata else {} for metadata in metadatas]
258258

259259
images: list[bytes | None] = [metadata.pop("__image", None) for metadata in unflattened_metadatas]
260260

0 commit comments

Comments
 (0)