|
| 1 | +# Databricks notebook source |
| 2 | +# MAGIC %pip install trend-narrative |
| 3 | + |
| 4 | +# COMMAND ---------- |
| 5 | + |
| 6 | +from trend_narrative import InsightExtractor |
| 7 | +import pandas as pd |
| 8 | + |
# Fully-qualified location of the source table: CATALOG.SCHEMA.TABLE_NAME.
CATALOG, SCHEMA = "prd_mega", "boost"
TABLE_NAME = "expenditure_by_country_func_econ_year"

# Rows with a year earlier than this are ignored by process_country.
START_YEAR = 2010

# A yearly series needs at least this many points before insights are extracted.
MIN_DATA_POINTS = 4

# One entry per insight series computed for every country.
# A "dimension"/"dimension_filter" of None means no filter is applied; such
# rows are labelled "Total" in the output.
INSIGHT_CONFIGS = [
    dict(
        dimension="func",
        dimension_filter="Health",
        metric="real_expenditure",
        metric_name="real expenditure",
    ),
    dict(
        dimension="func",
        dimension_filter="Education",
        metric="real_expenditure",
        metric_name="real expenditure",
    ),
    dict(
        dimension=None,
        dimension_filter=None,
        metric="real_expenditure",
        metric_name="total real expenditure",
    ),
]
| 35 | + |
| 36 | +# COMMAND ---------- |
| 37 | + |
def process_country(pdf: pd.DataFrame) -> pd.DataFrame:
    """Build one row of trend insights per applicable config for one country.

    Args:
        pdf: Rows belonging to a single country. The function reads the
            ``country_name`` and ``year`` columns plus whichever metric and
            dimension columns the entries of ``INSIGHT_CONFIGS`` name.

    Returns:
        A DataFrame with one row per config whose yearly series has at least
        ``MIN_DATA_POINTS`` points. Each row is the object returned by
        ``InsightExtractor.extract_full_suite()`` updated with metadata keys
        (country, metric, dimension, source table). The frame has no rows and
        no columns when no config qualifies or when ``pdf`` itself is empty.
    """
    insights = []

    # Guard: ``.iloc[0]`` below raises IndexError on a frame with no rows.
    if pdf.empty:
        return pd.DataFrame(insights)

    country = pdf["country_name"].iloc[0]
    # Base slice shared by every config: only years from START_YEAR onwards.
    country_base_df = pdf[pdf.year >= START_YEAR]

    for config in INSIGHT_CONFIGS:
        metric = config["metric"]
        dim = config["dimension"]
        dim_val = config["dimension_filter"]

        # dropna returns a new frame, so the base slice is never mutated.
        df_temp = country_base_df.dropna(subset=[metric])

        # Apply the dimension filter if one is configured (e.g. Health).
        if dim and dim_val:
            df_temp = df_temp[df_temp[dim] == dim_val]

        # Aggregate to year level by summing the metric.
        df_plot = df_temp.groupby("year")[metric].sum().reset_index()
        df_plot = df_plot.sort_values(by="year")

        # Too few points for this config: skip it.
        if len(df_plot) < MIN_DATA_POINTS:
            continue

        extractor = InsightExtractor(df_plot["year"].values, df_plot[metric].values)
        result = extractor.extract_full_suite()

        # Attach the metadata that identifies this series in the output table.
        result.update(
            {
                "country_name": country,
                "metric": metric,
                "metric_name": config["metric_name"],
                "dimension": dim if dim else "Total",
                "dimension_filter": dim_val if dim_val else "Total",
                "table_name": TABLE_NAME,
            }
        )
        insights.append(result)

    return pd.DataFrame(insights)
| 83 | + |
| 84 | + |
| 85 | +# COMMAND ---------- |
| 86 | + |
source_df = spark.table(f"{CATALOG}.{SCHEMA}.{TABLE_NAME}")

# applyInPandas needs the output schema up front, so derive it by running
# process_country on pandas copies of up to 10 sampled countries and taking
# the first non-empty result.
# Null country names are excluded from the sample: an equality filter against
# null matches no rows, which would hand process_country an empty frame.
sample_countries = (
    source_df.filter(source_df.country_name.isNotNull())
    .select("country_name")
    .distinct()
    .limit(10)
    .collect()
)
output_schema = None
for row in sample_countries:
    sample_pdf = source_df.filter(source_df.country_name == row[0]).toPandas()
    sample_output = process_country(sample_pdf)
    if not sample_output.empty:
        output_schema = spark.createDataFrame(sample_output).schema
        break
if output_schema is None:
    raise ValueError("No sample country produced insights - cannot infer schema")

# Run process_country once per country group across the whole table.
insights_df = source_df.groupBy("country_name").applyInPandas(
    process_country, schema=output_schema
)
| 102 | + |
| 103 | +# COMMAND ---------- |
| 104 | + |
# Destination table, written to the same catalog and schema as the source.
INSIGHT_TABLE_NAME = "expenditure_insights"

# Write with mode "overwrite" and the "overwriteSchema" option set to "true".
(
    insights_df.write
    .mode("overwrite")
    .option("overwriteSchema", "true")
    .saveAsTable(f"{CATALOG}.{SCHEMA}.{INSIGHT_TABLE_NAME}")
)
0 commit comments