|
1762 | 1762 | "**Bottom Line:** HAC is the recommended default for most applications. Use ARIMAX only when you have strong evidence for a specific ARIMA structure and are comfortable with the added complexity and assumptions.\n" |
1763 | 1763 | ] |
1764 | 1764 | }, |
| 1765 | + { |
| 1766 | + "cell_type": "markdown", |
| 1767 | + "metadata": {}, |
| 1768 | + "source": [ |
| 1769 | + "## Continuous Optimization for Parameter Estimation\n", |
| 1770 | + "\n", |
| 1771 | + "So far we've used **grid search** (`estimation_method=\"grid\"`) to estimate transform parameters by evaluating a discrete set of parameter combinations. CausalPy also supports **continuous optimization** (`estimation_method=\"optimize\"`), which searches the continuous parameter space within the bounds you supply using a gradient-based optimizer.\n", |
| 1772 | + "\n", |
| 1773 | + "**Advantages of optimization:**\n", |
| 1774 | + "- Explores continuous parameter space (not limited to grid points)\n", |
| 1775 | + "- Can find more precise parameter estimates\n", |
| 1776 | + "- Often faster for fine-grained search (doesn't evaluate all combinations)\n", |
| 1777 | + "- Better suited when you have good initial guesses\n", |
| 1778 | + "\n", |
| 1779 | + "**Tradeoffs:**\n", |
| 1780 | + "- May converge to a local optimum, depending on the starting point\n", |
| 1781 | + "- Less exhaustive than grid search, so it can miss the global optimum if poorly initialized\n", |
| 1782 | + "- Runs `scipy.optimize.minimize` with the L-BFGS-B method, so you don't get the full search landscape that a grid provides\n", |
| 1783 | + "\n", |
| 1784 | + "We'll demonstrate optimization using the ARIMAX error model and compare parameter recovery against grid search.\n" |
| 1785 | + ] |
| 1786 | + }, |
| 1787 | + { |
| 1788 | + "cell_type": "code", |
| 1789 | + "execution_count": null, |
| 1790 | + "metadata": {}, |
| 1791 | + "outputs": [], |
| 1792 | + "source": [ |
| 1793 | + "model_arimax_opt = cp.skl_models.TransferFunctionOLS(\n", |
| 1794 | + " saturation_type=None, # No saturation - adstock only\n", |
| 1795 | + " adstock_bounds={\n", |
| 1796 | + " \"half_life\": (0.5, 3.0), # Continuous range (same as grid: 0.5 to 3.0)\n", |
| 1797 | + " },\n", |
| 1798 | + " estimation_method=\"optimize\", # Continuous optimization\n", |
| 1799 | + " error_model=\"arimax\",\n", |
| 1800 | + " arima_order=(1, 0, 0),\n", |
| 1801 | + ")\n", |
| 1802 | + "\n", |
| 1803 | + "result_arimax_opt = cp.GradedInterventionTimeSeries(\n", |
| 1804 | + " data=df,\n", |
| 1805 | + " y_column=\"water_consumption\",\n", |
| 1806 | + " treatment_names=[\"comm_intensity\"],\n", |
| 1807 | + " base_formula=\"1 + t + temperature + rainfall\",\n", |
| 1808 | + " model=model_arimax_opt,\n", |
| 1809 | + ")\n", |
| 1810 | + "\n", |
| 1811 | + "print(\"Optimization complete!\")\n", |
| 1812 | + "print(f\"Best RMSE: {result_arimax_opt.transform_estimation_results['best_score']:.2f}\")\n", |
| 1813 | + "print(\n", |
| 1814 | + " f\"Estimated parameters: {result_arimax_opt.transform_estimation_results['best_params']}\"\n", |
| 1815 | + ")" |
| 1816 | + ] |
| 1817 | + }, |
| 1818 | + { |
| 1819 | + "cell_type": "markdown", |
| 1820 | + "metadata": {}, |
| 1821 | + "source": [ |
| 1822 | + "### Compare Transform Parameter Recovery\n" |
| 1823 | + ] |
| 1824 | + }, |
| 1825 | + { |
| 1826 | + "cell_type": "code", |
| 1827 | + "execution_count": null, |
| 1828 | + "metadata": {}, |
| 1829 | + "outputs": [], |
| 1830 | + "source": [ |
| 1831 | + "# Extract estimated parameters\n", |
| 1832 | + "# Grid search\n", |
| 1833 | + "half_life_grid = result_arimax.transform_estimation_results[\"best_params\"][\"half_life\"]\n", |
| 1834 | + "rmse_grid = result_arimax.transform_estimation_results[\"best_score\"]\n", |
| 1835 | + "\n", |
| 1836 | + "# Optimization\n", |
| 1837 | + "half_life_opt = result_arimax_opt.transform_estimation_results[\"best_params\"][\n", |
| 1838 | + " \"half_life\"\n", |
| 1839 | + "]\n", |
| 1840 | + "rmse_opt = result_arimax_opt.transform_estimation_results[\"best_score\"]\n", |
| 1841 | + "\n", |
| 1842 | + "# True value\n", |
| 1843 | + "half_life_true = 1.5\n", |
| 1844 | + "\n", |
| 1845 | + "# Create comparison table\n", |
| 1846 | + "comparison_data = {\n", |
| 1847 | + " \"Method\": [\"True Value\", \"ARIMAX Grid\", \"ARIMAX Optimize\"],\n", |
| 1848 | + " \"Half-life\": [\n", |
| 1849 | + " f\"{half_life_true:.3f}\",\n", |
| 1850 | + " f\"{half_life_grid:.3f}\",\n", |
| 1851 | + " f\"{half_life_opt:.3f}\",\n", |
| 1852 | + " ],\n", |
| 1853 | + " \"Error\": [\n", |
| 1854 | + " \"-\",\n", |
| 1855 | + " f\"{abs(half_life_grid - half_life_true):.3f}\",\n", |
| 1856 | + " f\"{abs(half_life_opt - half_life_true):.3f}\",\n", |
| 1857 | + " ],\n", |
| 1858 | + " \"RMSE\": [\"-\", f\"{rmse_grid:.2f}\", f\"{rmse_opt:.2f}\"],\n", |
| 1859 | + "}\n", |
| 1860 | + "\n", |
| 1861 | + "param_comparison_df = pd.DataFrame(comparison_data)\n", |
| 1862 | + "\n", |
| 1863 | + "print(\"=\" * 70)\n", |
| 1864 | + "print(\"PARAMETER RECOVERY: GRID vs OPTIMIZATION\")\n", |
| 1865 | + "print(\"=\" * 70)\n", |
| 1866 | + "print(param_comparison_df.to_string(index=False))\n", |
| 1867 | + "print(\"=\" * 70)\n", |
| 1868 | + "print()\n", |
| 1869 | + "print(\"KEY OBSERVATIONS:\")\n", |
| 1870 | + "print(f\"• True half-life: {half_life_true:.3f} weeks\")\n", |
| 1871 | + "print(\n", |
| 1872 | + " f\"• Grid search estimate: {half_life_grid:.3f} (error: {abs(half_life_grid - half_life_true):.3f})\"\n", |
| 1873 | + ")\n", |
| 1874 | + "print(\n", |
| 1875 | + " f\"• Optimization estimate: {half_life_opt:.3f} (error: {abs(half_life_opt - half_life_true):.3f})\"\n", |
| 1876 | + ")\n", |
| 1877 | + "print(\n", |
| 1878 | + " f\"• RMSE improvement: {rmse_grid - rmse_opt:.2f} ({(1 - rmse_opt / rmse_grid) * 100:.2f}%)\"\n", |
| 1879 | + ")\n", |
| 1880 | + "if abs(half_life_opt - half_life_true) < abs(half_life_grid - half_life_true):\n", |
| 1881 | + " print(\"✓ Optimization achieved better parameter recovery\")\n", |
| 1882 | + "else:\n", |
| 1883 | + " print(\"• Grid search achieved comparable or better parameter recovery\")" |
| 1884 | + ] |
| 1885 | + }, |
| 1886 | + { |
| 1887 | + "cell_type": "code", |
| 1888 | + "execution_count": null, |
| 1889 | + "metadata": {}, |
| 1890 | + "outputs": [], |
| 1891 | + "source": [ |
| 1892 | + "# Visualize adstock function comparison\n", |
| 1893 | + "fig, ax = plt.subplots(1, 1, figsize=(10, 5))\n", |
| 1894 | + "\n", |
| 1895 | + "# Rebuild the adstock weights directly from the half-life estimates\n", |
| 1896 | + "# (geometric decay: alpha = 0.5 ** (1 / half_life), normalized to sum to 1)\n", |
| 1901 | + "l_max = 8\n", |
| 1902 | + "lags = np.arange(l_max + 1)\n", |
| 1903 | + "\n", |
| 1904 | + "# True weights\n", |
| 1905 | + "alpha_true = np.power(0.5, 1 / half_life_true)\n", |
| 1906 | + "weights_true = alpha_true**lags\n", |
| 1907 | + "weights_true = weights_true / weights_true.sum()\n", |
| 1908 | + "\n", |
| 1909 | + "# Grid weights\n", |
| 1910 | + "alpha_grid = np.power(0.5, 1 / half_life_grid)\n", |
| 1911 | + "weights_grid = alpha_grid**lags\n", |
| 1912 | + "weights_grid = weights_grid / weights_grid.sum()\n", |
| 1913 | + "\n", |
| 1914 | + "# Optimize weights\n", |
| 1915 | + "alpha_opt = np.power(0.5, 1 / half_life_opt)\n", |
| 1916 | + "weights_opt = alpha_opt**lags\n", |
| 1917 | + "weights_opt = weights_opt / weights_opt.sum()\n", |
| 1918 | + "\n", |
| 1919 | + "# Plot\n", |
| 1920 | + "width = 0.25\n", |
| 1921 | + "ax.bar(\n", |
| 1922 | + " lags - width,\n", |
| 1923 | + " weights_true,\n", |
| 1924 | + " width,\n", |
| 1925 | + " alpha=0.8,\n", |
| 1926 | + " label=f\"True (half-life={half_life_true:.2f})\",\n", |
| 1927 | + " color=\"black\",\n", |
| 1928 | + ")\n", |
| 1929 | + "ax.bar(\n", |
| 1930 | + " lags,\n", |
| 1931 | + " weights_grid,\n", |
| 1932 | + " width,\n", |
| 1933 | + " alpha=0.8,\n", |
| 1934 | + " label=f\"Grid (half-life={half_life_grid:.2f})\",\n", |
| 1935 | + " color=\"C0\",\n", |
| 1936 | + ")\n", |
| 1937 | + "ax.bar(\n", |
| 1938 | + " lags + width,\n", |
| 1939 | + " weights_opt,\n", |
| 1940 | + " width,\n", |
| 1941 | + " alpha=0.8,\n", |
| 1942 | + " label=f\"Optimize (half-life={half_life_opt:.2f})\",\n", |
| 1943 | + " color=\"C2\",\n", |
| 1944 | + ")\n", |
| 1945 | + "\n", |
| 1946 | + "ax.set_xlabel(\"Lag (periods)\", fontsize=11)\n", |
| 1947 | + "ax.set_ylabel(\"Adstock Weight\", fontsize=11)\n", |
| 1948 | + "ax.set_title(\n", |
| 1949 | + " \"Adstock Parameter Recovery: Grid vs Optimization\", fontsize=12, fontweight=\"bold\"\n", |
| 1950 | + ")\n", |
| 1951 | + "ax.legend(fontsize=10, framealpha=0.9)\n", |
| 1952 | + "ax.grid(True, alpha=0.3, axis=\"y\")\n", |
| 1953 | + "\n", |
| 1954 | + "plt.tight_layout()\n", |
| 1955 | + "plt.show()\n", |
| 1956 | + "\n", |
| 1957 | + "print(\"\\n📊 INTERPRETATION:\")\n", |
| 1958 | + "print(\n", |
| 1959 | + " \"Continuous optimization can find parameter values between grid points, potentially\"\n", |
| 1960 | + ")\n", |
| 1961 | + "print(\n", |
| 1962 | + " \"achieving better fit (lower RMSE) and more accurate parameter recovery. The tradeoff\"\n", |
| 1963 | + ")\n", |
| 1964 | + "print(\n", |
| 1965 | + " \"is that optimization may find local optima, while grid search exhaustively evaluates\"\n", |
| 1966 | + ")\n", |
| 1967 | + "print(\n", |
| 1968 | + " \"all specified combinations. For this example, optimization explores the continuous\"\n", |
| 1969 | + ")\n", |
| 1970 | + "print(\"range [0.5, 3.0] rather than being limited to 30 discrete grid points.\")" |
| 1971 | + ] |
| 1972 | + }, |
| 1973 | + { |
| 1974 | + "cell_type": "markdown", |
| 1975 | + "metadata": {}, |
| 1976 | + "source": [ |
| 1977 | + "### Summary: Grid Search vs Optimization\n", |
| 1978 | + "\n", |
| 1979 | + "**When to use grid search:**\n", |
| 1980 | + "- You want exhaustive evaluation of discrete parameter combinations\n", |
| 1981 | + "- Parameter space is small enough to evaluate densely\n", |
| 1982 | + "- You want to visualize the full search landscape\n", |
| 1983 | + "- Robustness to local optima is critical\n", |
| 1984 | + "\n", |
| 1985 | + "**When to use optimization:**\n", |
| 1986 | + "- You want fine-grained continuous parameter estimates\n", |
| 1987 | + "- Parameter space is large (many parameters or wide ranges)\n", |
| 1988 | + "- You have good intuition for reasonable parameter ranges\n", |
| 1989 | + "- Computational efficiency matters for large datasets\n", |
| 1990 | + "\n", |
| 1991 | + "**Best practice:** Start with a coarse grid search to understand the landscape, then use optimization to refine the estimates if needed.\n" |
| 1992 | + ] |
| 1993 | + }, |
1765 | 1994 | { |
1766 | 1995 | "cell_type": "markdown", |
1767 | 1996 | "metadata": {}, |
|
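The "coarse grid, then refine" practice recommended in the summary cell above can be sketched without CausalPy. This is a minimal, dependency-free illustration under stated assumptions, not the library's implementation: the data are synthetic, every helper name (`adstock_weights`, `apply_adstock`, `rmse`, `golden_section`) is hypothetical, and a golden-section search stands in for the L-BFGS-B optimizer the notebook mentions, which is reasonable only because there is a single parameter. The adstock weights use the same formula as the plotting cell, `alpha = 0.5 ** (1 / half_life)`.

```python
import random

L_MAX = 8  # maximum lag, matching the l_max used in the plotting cell


def adstock_weights(half_life, l_max=L_MAX):
    """Normalized geometric adstock weights with alpha = 0.5 ** (1 / half_life)."""
    alpha = 0.5 ** (1.0 / half_life)
    raw = [alpha**lag for lag in range(l_max + 1)]
    total = sum(raw)
    return [w / total for w in raw]


def apply_adstock(x, half_life):
    """Convolve x with the adstock weights (zero-padded at the start)."""
    w = adstock_weights(half_life)
    return [
        sum(w[k] * x[t - k] for k in range(min(t, L_MAX) + 1))
        for t in range(len(x))
    ]


def rmse(half_life, x, y):
    """RMSE of an OLS fit (intercept + slope) of y on the adstocked x."""
    z = apply_adstock(x, half_life)
    n = len(y)
    z_bar, y_bar = sum(z) / n, sum(y) / n
    slope = sum((zi - z_bar) * (yi - y_bar) for zi, yi in zip(z, y)) / sum(
        (zi - z_bar) ** 2 for zi in z
    )
    intercept = y_bar - slope * z_bar
    return (sum((yi - intercept - slope * zi) ** 2 for zi, yi in zip(z, y)) / n) ** 0.5


def golden_section(f, lo, hi, tol=1e-4):
    """Minimize a 1-D function on [lo, hi]; assumes f is unimodal there."""
    inv_phi = (5**0.5 - 1) / 2
    c, d = hi - inv_phi * (hi - lo), lo + inv_phi * (hi - lo)
    fc, fd = f(c), f(d)
    while hi - lo > tol:
        if fc < fd:  # minimum lies in [lo, d]
            hi, d, fd = d, c, fc
            c = hi - inv_phi * (hi - lo)
            fc = f(c)
        else:  # minimum lies in [c, hi]
            lo, c, fc = c, d, fd
            d = lo + inv_phi * (hi - lo)
            fd = f(d)
    return (lo + hi) / 2


# Synthetic data (assumption): sparse treatment pulses with a known half-life
rng = random.Random(0)
TRUE_HALF_LIFE = 1.5
x = [1.0 if rng.random() < 0.2 else 0.0 for _ in range(300)]
y = [10.0 - 5.0 * s + rng.gauss(0.0, 0.05) for s in apply_adstock(x, TRUE_HALF_LIFE)]


def objective(half_life):
    return rmse(half_life, x, y)


# Step 1: coarse grid over the same range as the notebook, [0.5, 3.0]
lo, hi, n_grid = 0.5, 3.0, 5
grid = [lo + i * (hi - lo) / (n_grid - 1) for i in range(n_grid)]
scores = [objective(h) for h in grid]
i_best = scores.index(min(scores))
half_life_grid = grid[i_best]

# Step 2: continuous refinement between the neighbours of the best grid point
bracket_lo = grid[max(i_best - 1, 0)]
bracket_hi = grid[min(i_best + 1, n_grid - 1)]
candidate = golden_section(objective, bracket_lo, bracket_hi)

# Keep the grid point if refinement did not improve the fit
half_life_refined = min((half_life_grid, candidate), key=objective)

print(f"grid:    half_life={half_life_grid:.3f}  rmse={objective(half_life_grid):.4f}")
print(f"refined: half_life={half_life_refined:.3f}  rmse={objective(half_life_refined):.4f}")
```

Bracketing the refinement with the neighbours of the best grid point limits the risk of landing in a distant local optimum, and the final `min` ensures the refined estimate never fits worse than the grid estimate. With several transform parameters, a bounded multivariate optimizer such as L-BFGS-B would replace the 1-D search.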