Skip to content

Commit cc4b9af

Browse files
committed
Allow Spark to parallelize the processing across countries
TODO: test this on Databricks
1 parent 6f33c86 commit cc4b9af

File tree

1 file changed

+29
-12
lines changed

1 file changed

+29
-12
lines changed

analytics/extract_insight.py

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -33,16 +33,14 @@
3333
]
3434
MIN_DATA_POINTS = 4
3535

36-
source_table = spark.table(f"{CATALOG}.{SCHEMA}.{TABLE_NAME}").toPandas()
37-
countries = source_table.country_name.unique()
38-
39-
insights = []
36+
# COMMAND ----------
4037

41-
for country in countries:
38+
def process_country(pdf: pd.DataFrame) -> pd.DataFrame:
39+
"""Process all insight configs for a single country."""
40+
country = pdf["country_name"].iloc[0]
4241
# Get a fresh "base" slice for the country
43-
country_base_df = source_table[
44-
(source_table.country_name == country) & (source_table.year >= START_YEAR)
45-
]
42+
country_base_df = pdf[pdf.year >= START_YEAR]
43+
insights = []
4644

4745
for config in INSIGHT_CONFIGS:
4846
metric = config["metric"]
@@ -81,9 +79,28 @@
8179
)
8280
insights.append(result)
8381

82+
return pd.DataFrame(insights)
83+
84+
85+
# COMMAND ----------

source_df = spark.table(f"{CATALOG}.{SCHEMA}.{TABLE_NAME}")

# applyInPandas needs an explicit output schema up front, so we infer it by
# running process_country on the driver for sample countries until one
# actually yields insights.
# NOTE(review): if none of the sampled countries produce output the job
# aborts, even though an unsampled country might have produced insights —
# raise the sample size (or declare the schema explicitly) if that happens.
SCHEMA_SAMPLE_COUNTRIES = 10  # was a hard-coded limit(10)

sample_countries = (
    source_df.select("country_name")
    .distinct()
    .limit(SCHEMA_SAMPLE_COUNTRIES)
    .collect()
)
output_schema = None
for row in sample_countries:
    # Pull one country's rows to the driver and run the pandas logic locally.
    # This is only for schema inference; the real work runs distributed below.
    country = row["country_name"]  # access by name, not positional row[0]
    sample_pdf = source_df.filter(source_df.country_name == country).toPandas()
    sample_output = process_country(sample_pdf)
    if not sample_output.empty:
        # NOTE(review): inference from one country assumes its columns are
        # representative (no all-null columns Spark cannot type) — confirm.
        output_schema = spark.createDataFrame(sample_output).schema
        break
if output_schema is None:
    raise ValueError("No sample country produced insights - cannot infer schema")

# One pandas call per country group, executed in parallel across the cluster.
insights_df = source_df.groupBy("country_name").applyInPandas(process_country, schema=output_schema)
84103
# COMMAND ----------

# Persist the distributed insights, replacing any previous version of the
# table. overwriteSchema lets the write succeed even when the set of insight
# columns has changed since the last run.
INSIGHT_TABLE_NAME = "expenditure_insights"

(
    insights_df.write
    .mode("overwrite")
    .option("overwriteSchema", "true")
    .saveAsTable(f"{CATALOG}.{SCHEMA}.{INSIGHT_TABLE_NAME}")
)

0 commit comments

Comments
 (0)