Skip to content

Commit f802164

Browse files
committed
Expand nested MongoDB fields before writing Parquet
1 parent 1c02580 commit f802164

File tree

1 file changed

+54
-2
lines changed

1 file changed

+54
-2
lines changed

scripts/migrate_mongo_to_parquet.py

Lines changed: 54 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from collections import defaultdict
1010
from dataclasses import dataclass
1111
from pathlib import Path
12-
from typing import DefaultDict, Dict, Iterable, List
12+
from typing import Any, DefaultDict, Dict, Iterable, List
1313

1414
import click
1515
import pandas as pd
@@ -40,10 +40,62 @@ def path(self) -> Path:
4040
return self.partition_dir / self.filename
4141

4242

43+
def _is_expandable(value: Any) -> bool:
44+
"""Return ``True`` if *value* should be expanded into scalar columns."""
45+
46+
return isinstance(value, (dict, list, tuple))
47+
48+
49+
def _flatten_nested(value: Any) -> Any:
50+
"""Recursively convert nested *value* into a dict keyed by indices."""
51+
52+
if isinstance(value, dict):
53+
return {key: _flatten_nested(val) for key, val in value.items()}
54+
55+
if isinstance(value, (list, tuple)):
56+
return {str(idx): _flatten_nested(val) for idx, val in enumerate(value)}
57+
58+
return value
59+
60+
61+
def _expand_nested_columns(frame: pd.DataFrame) -> pd.DataFrame:
    """Expand list- or dict-typed columns in *frame* into scalar columns.

    Each column that holds at least one ``dict``, ``list`` or ``tuple`` value
    is flattened with :func:`_flatten_nested` and normalised into new columns
    named ``"<column>.<key-or-index>"`` (deeper levels are joined with
    ``"."`` as well). A column whose values were all expandable is dropped;
    a column mixing containers and scalars is kept, with the expanded cells
    replaced by ``None``.

    The caller's *frame* is never modified in place. When no column needs
    expanding, the very same object is returned.

    Raises:
        ValueError: propagated from ``DataFrame.join`` when a generated
            column name already exists in *frame*.
    """
    result = frame
    # ``join``/``drop`` hand back new objects; until one of them ran (or we
    # copied explicitly) ``result`` still aliases the caller's dataframe.
    owns_result = False

    for column in list(frame.columns):
        series = result[column]
        mask = series.apply(_is_expandable)

        if not mask.any():
            continue

        # Non-expandable cells contribute an empty record so that row
        # positions stay aligned with ``series``.
        prepared_rows = [
            _flatten_nested(value) if expand else {}
            for value, expand in zip(series.tolist(), mask.tolist())
        ]
        expanded = pd.json_normalize(prepared_rows, sep=".")

        # ``expanded`` has no columns when every container was empty.
        if not expanded.empty:
            expanded.index = series.index
            result = result.join(expanded.add_prefix(f"{column}."))
            owns_result = True

        if mask.all():
            result = result.drop(columns=[column])
            owns_result = True
        else:
            if not owns_result:
                # Without this copy the assignment below would write into
                # the dataframe passed in by the caller.
                result = result.copy()
                owns_result = True
            result.loc[mask, column] = None

    return result
88+
89+
4390
def _normalize_records(records: Iterable[Dict]) -> pd.DataFrame:
    """Build a flattened dataframe from *records*.

    The records are first normalised with ``pandas.json_normalize`` using
    ``"."`` as separator; unless that result is empty, it is then handed to
    :func:`_expand_nested_columns` so remaining list/dict values are expanded
    as well.
    """
    materialized = list(records)
    flattened = pd.json_normalize(materialized, sep=".")

    # An empty frame has nothing to expand, so it is returned untouched.
    return flattened if flattened.empty else _expand_nested_columns(flattened)
4799

48100

49101
def _partition_target(

0 commit comments

Comments
 (0)