Commit 97e5bc9

first stab at demo of optimization

1 parent a058d93

File tree

1 file changed: +229 −0 lines changed

docs/source/notebooks/graded_intervention_time_series_single_channel_ols.ipynb

Lines changed: 229 additions & 0 deletions
@@ -1762,6 +1762,235 @@
"**Bottom Line:** HAC is the recommended default for most applications. Use ARIMAX only when you have strong evidence for a specific ARIMA structure and are comfortable with the added complexity and assumptions.\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Continuous Optimization for Parameter Estimation\n",
"\n",
"So far we've used **grid search** (`estimation_method=\"grid\"`) to estimate transform parameters by evaluating discrete parameter combinations. CausalPy also supports **continuous optimization** (`estimation_method=\"optimize\"`), which explores the full continuous parameter space using gradient-based methods.\n",
"\n",
"**Advantages of optimization:**\n",
"- Explores the continuous parameter space (not limited to grid points)\n",
"- Can find more precise parameter estimates\n",
"- Often faster than a fine-grained grid (doesn't evaluate all combinations)\n",
"- Better suited when you have good initial guesses\n",
"\n",
"**Tradeoffs:**\n",
"- May converge to local optima (depends on the starting point)\n",
"- Less exhaustive than grid search (might miss the global optimum if poorly initialized)\n",
"- Uses `scipy.optimize.minimize` with the L-BFGS-B method\n",
"\n",
"We'll demonstrate optimization using the ARIMAX error model and compare parameter recovery against grid search.\n"
]
},
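The bounded L-BFGS-B search named above can be sketched in isolation. This is a hypothetical, self-contained illustration, not CausalPy's internal code: a toy RMSE objective with a known minimum at `half_life = 1.5` (in reality the objective refits the model at each candidate value), minimized over the same bounds used for `adstock_bounds`.

```python
import numpy as np
from scipy.optimize import minimize

# Toy stand-in for the RMSE objective: minimized at half_life = 1.5.
def rmse(params):
    (half_life,) = params
    return 2.0 + (half_life - 1.5) ** 2

res = minimize(
    rmse,
    x0=np.array([1.0]),   # initial guess inside the bounds
    method="L-BFGS-B",    # bounded quasi-Newton method
    bounds=[(0.5, 3.0)],  # same range as adstock_bounds["half_life"]
)
print(f"estimated half_life: {res.x[0]:.3f}")  # → 1.500
```

Unlike grid search, the answer is not restricted to pre-specified grid points, but a poor `x0` on a multi-modal objective can leave the search stuck in a local optimum.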
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model_arimax_opt = cp.skl_models.TransferFunctionOLS(\n",
"    saturation_type=None,  # No saturation - adstock only\n",
"    adstock_bounds={\n",
"        \"half_life\": (0.5, 3.0),  # Continuous range (same as grid: 0.5 to 3.0)\n",
"    },\n",
"    estimation_method=\"optimize\",  # Continuous optimization\n",
"    error_model=\"arimax\",\n",
"    arima_order=(1, 0, 0),\n",
")\n",
"\n",
"result_arimax_opt = cp.GradedInterventionTimeSeries(\n",
"    data=df,\n",
"    y_column=\"water_consumption\",\n",
"    treatment_names=[\"comm_intensity\"],\n",
"    base_formula=\"1 + t + temperature + rainfall\",\n",
"    model=model_arimax_opt,\n",
")\n",
"\n",
"print(\"Optimization complete!\")\n",
"print(f\"Best RMSE: {result_arimax_opt.transform_estimation_results['best_score']:.2f}\")\n",
"print(\n",
"    f\"Estimated parameters: {result_arimax_opt.transform_estimation_results['best_params']}\"\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Compare Transform Parameter Recovery\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Extract estimated parameters\n",
"# Grid search\n",
"half_life_grid = result_arimax.transform_estimation_results[\"best_params\"][\"half_life\"]\n",
"rmse_grid = result_arimax.transform_estimation_results[\"best_score\"]\n",
"\n",
"# Optimization\n",
"half_life_opt = result_arimax_opt.transform_estimation_results[\"best_params\"][\n",
"    \"half_life\"\n",
"]\n",
"rmse_opt = result_arimax_opt.transform_estimation_results[\"best_score\"]\n",
"\n",
"# True value\n",
"half_life_true = 1.5\n",
"\n",
"# Create comparison table\n",
"comparison_data = {\n",
"    \"Method\": [\"True Value\", \"ARIMAX Grid\", \"ARIMAX Optimize\"],\n",
"    \"Half-life\": [\n",
"        f\"{half_life_true:.3f}\",\n",
"        f\"{half_life_grid:.3f}\",\n",
"        f\"{half_life_opt:.3f}\",\n",
"    ],\n",
"    \"Error\": [\n",
"        \"-\",\n",
"        f\"{abs(half_life_grid - half_life_true):.3f}\",\n",
"        f\"{abs(half_life_opt - half_life_true):.3f}\",\n",
"    ],\n",
"    \"RMSE\": [\"-\", f\"{rmse_grid:.2f}\", f\"{rmse_opt:.2f}\"],\n",
"}\n",
"\n",
"param_comparison_df = pd.DataFrame(comparison_data)\n",
"\n",
"print(\"=\" * 70)\n",
"print(\"PARAMETER RECOVERY: GRID vs OPTIMIZATION\")\n",
"print(\"=\" * 70)\n",
"print(param_comparison_df.to_string(index=False))\n",
"print(\"=\" * 70)\n",
"print()\n",
"print(\"KEY OBSERVATIONS:\")\n",
"print(f\"• True half-life: {half_life_true:.3f} weeks\")\n",
"print(\n",
"    f\"• Grid search estimate: {half_life_grid:.3f} (error: {abs(half_life_grid - half_life_true):.3f})\"\n",
")\n",
"print(\n",
"    f\"• Optimization estimate: {half_life_opt:.3f} (error: {abs(half_life_opt - half_life_true):.3f})\"\n",
")\n",
"print(\n",
"    f\"• RMSE improvement: {rmse_grid - rmse_opt:.2f} ({(1 - rmse_opt / rmse_grid) * 100:.2f}%)\"\n",
")\n",
"if abs(half_life_opt - half_life_true) < abs(half_life_grid - half_life_true):\n",
"    print(\"✓ Optimization achieved better parameter recovery\")\n",
"else:\n",
"    print(\"• Grid search achieved comparable or better parameter recovery\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Visualize adstock function comparison\n",
"fig, ax = plt.subplots(1, 1, figsize=(10, 5))\n",
"\n",
"# Calculate weights\n",
"l_max = 8\n",
"lags = np.arange(l_max + 1)\n",
"\n",
"# True weights\n",
"alpha_true = np.power(0.5, 1 / half_life_true)\n",
"weights_true = alpha_true**lags\n",
"weights_true = weights_true / weights_true.sum()\n",
"\n",
"# Grid weights\n",
"alpha_grid = np.power(0.5, 1 / half_life_grid)\n",
"weights_grid = alpha_grid**lags\n",
"weights_grid = weights_grid / weights_grid.sum()\n",
"\n",
"# Optimize weights\n",
"alpha_opt = np.power(0.5, 1 / half_life_opt)\n",
"weights_opt = alpha_opt**lags\n",
"weights_opt = weights_opt / weights_opt.sum()\n",
"\n",
"# Plot\n",
"width = 0.25\n",
"ax.bar(\n",
"    lags - width,\n",
"    weights_true,\n",
"    width,\n",
"    alpha=0.8,\n",
"    label=f\"True (half-life={half_life_true:.2f})\",\n",
"    color=\"black\",\n",
")\n",
"ax.bar(\n",
"    lags,\n",
"    weights_grid,\n",
"    width,\n",
"    alpha=0.8,\n",
"    label=f\"Grid (half-life={half_life_grid:.2f})\",\n",
"    color=\"C0\",\n",
")\n",
"ax.bar(\n",
"    lags + width,\n",
"    weights_opt,\n",
"    width,\n",
"    alpha=0.8,\n",
"    label=f\"Optimize (half-life={half_life_opt:.2f})\",\n",
"    color=\"C2\",\n",
")\n",
"\n",
"ax.set_xlabel(\"Lag (periods)\", fontsize=11)\n",
"ax.set_ylabel(\"Adstock Weight\", fontsize=11)\n",
"ax.set_title(\n",
"    \"Adstock Parameter Recovery: Grid vs Optimization\", fontsize=12, fontweight=\"bold\"\n",
")\n",
"ax.legend(fontsize=10, framealpha=0.9)\n",
"ax.grid(True, alpha=0.3, axis=\"y\")\n",
"\n",
"plt.tight_layout()\n",
"plt.show()\n",
"\n",
"print(\"\\n📊 INTERPRETATION:\")\n",
"print(\n",
"    \"Continuous optimization can find parameter values between grid points, potentially\"\n",
")\n",
"print(\n",
"    \"achieving better fit (lower RMSE) and more accurate parameter recovery. The tradeoff\"\n",
")\n",
"print(\n",
"    \"is that optimization may find local optima, while grid search exhaustively evaluates\"\n",
")\n",
"print(\n",
"    \"all specified combinations. For this example, optimization explores the continuous\"\n",
")\n",
"print(\"range [0.5, 3.0] rather than being limited to 30 discrete grid points.\")"
]
},
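The geometric adstock weights plotted above satisfy the half-life property by construction: with per-lag retention `alpha = 0.5 ** (1 / half_life)`, the unnormalized weight halves every `half_life` lags. A quick standalone check of that identity, using the same formulas as the plotting cell:

```python
import numpy as np

half_life = 1.5
alpha = np.power(0.5, 1 / half_life)  # per-lag retention factor
lags = np.arange(9)                   # lags 0..l_max with l_max = 8
weights = alpha**lags
weights = weights / weights.sum()     # normalize, as in the plot above

# Raising the retention factor to the half-life recovers 0.5 exactly:
print(alpha**half_life)  # → 0.5
```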
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Summary: Grid Search vs Optimization\n",
"\n",
"**When to use grid search:**\n",
"- You want exhaustive evaluation of discrete parameter combinations\n",
"- Parameter space is small enough to evaluate densely\n",
"- You want to visualize the full search landscape\n",
"- Robustness to local optima is critical\n",
"\n",
"**When to use optimization:**\n",
"- You want fine-grained continuous parameter estimates\n",
"- Parameter space is large (many parameters or wide ranges)\n",
"- You have good intuition for reasonable parameter ranges\n",
"- Computational efficiency matters for large datasets\n",
"\n",
"**Best practice:** Start with coarse grid search to understand the landscape, then use optimization to refine estimates if needed.\n"
]
},
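The "coarse grid, then refine" practice can be sketched with a toy objective (hypothetical, not CausalPy internals): take the best coarse grid point, then hand it to L-BFGS-B as the starting value so the refinement begins in the right basin.

```python
import numpy as np
from scipy.optimize import minimize

# Toy RMSE surface with its minimum at half_life = 1.37 (deliberately off-grid).
def rmse(half_life):
    return 2.0 + (half_life - 1.37) ** 2

# Step 1: coarse grid search to map the landscape.
grid = np.linspace(0.5, 3.0, 6)
best_grid = grid[np.argmin([rmse(h) for h in grid])]

# Step 2: refine from the best grid point with bounded optimization.
res = minimize(
    lambda p: rmse(p[0]),
    x0=np.array([best_grid]),
    method="L-BFGS-B",
    bounds=[(0.5, 3.0)],
)
print(f"grid: {best_grid:.3f} -> refined: {res.x[0]:.3f}")  # → grid: 1.500 -> refined: 1.370
```

The grid identifies the right neighborhood cheaply (robust to local optima), and the optimizer recovers the off-grid value the grid alone could never report.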
{
"cell_type": "markdown",
"metadata": {},