Commit 509fa62

NRL-1268 update flatten func to traverse schema for all nested structs
1 parent c4e4698 commit 509fa62

1 file changed: +17 -8 lines changed

terraform/account-wide-infrastructure/modules/glue/src/transformations.py

@@ -1,4 +1,4 @@
-from pyspark.sql.functions import to_timestamp
+from pyspark.sql.functions import col, to_timestamp
 from pyspark.sql.types import (
     BooleanType,
     StringType,
@@ -60,13 +60,22 @@
 
 
 def flatten_df(df):
-    cols = []
-    for c in df.dtypes:
-        if "struct" in c[1]:
-            nested_col = c[0]
-        else:
-            cols.append(c[0])
-    return df.select(*cols, f"{nested_col}.*")
+    def flatten(schema, prefix=""):
+        """
+        Recursively traverse the schema to extract all nested fields.
+        """
+        fields = []
+        for field in schema.fields:
+            name = f"{prefix}.{field.name}" if prefix else field.name
+            if isinstance(field.dataType, StructType):
+                fields += flatten(field.dataType, name)
+            else:
+                fields.append((name, field.name))
+        return fields
+
+    flat_columns = flatten(df.schema)
+
+    return df.select([col(c).alias(n) for c, n in flat_columns])
 
 
 def dtype_conversion(df):
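
For context on the change: the old flatten_df only expanded the last struct column it found, one level deep (and raised a NameError if no struct column existed), while the new version walks the full schema recursively and selects every leaf field. Below is a minimal, self-contained sketch of how the updated function behaves on a two-level nested schema; the SparkSession setup and the sample Row data are illustrative assumptions, not part of this commit (the real module pulls col and StructType from its own imports).

# Illustrative sketch only (not part of the commit): demonstrates the new
# flatten_df on a two-level nested schema using a local SparkSession.
from pyspark.sql import Row, SparkSession
from pyspark.sql.functions import col
from pyspark.sql.types import StructType


def flatten_df(df):
    # Same logic as the committed version in transformations.py.
    def flatten(schema, prefix=""):
        """Recursively traverse the schema to extract all nested fields."""
        fields = []
        for field in schema.fields:
            name = f"{prefix}.{field.name}" if prefix else field.name
            if isinstance(field.dataType, StructType):
                fields += flatten(field.dataType, name)
            else:
                # Keep the full dotted path for selection and the leaf name as the alias.
                fields.append((name, field.name))
        return fields

    flat_columns = flatten(df.schema)
    return df.select([col(c).alias(n) for c, n in flat_columns])


if __name__ == "__main__":
    spark = SparkSession.builder.master("local[1]").appName("flatten-demo").getOrCreate()

    # Nested Row objects are inferred as StructType columns.
    df = spark.createDataFrame(
        [Row(id=1, meta=Row(source="nrl", audit=Row(created_by="etl")))]
    )

    flatten_df(df).show()
    # Flattened columns: id, source, created_by
    # Note: leaf names are used as aliases, so sibling structs that share a
    # leaf field name would produce duplicate column names.
    spark.stop()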
