From 93d8e4865698c977f578849ebb7984d2660d6e6d Mon Sep 17 00:00:00 2001 From: Jason DeBacker Date: Sun, 28 Sep 2025 15:24:39 -0400 Subject: [PATCH 01/23] ignore invalid value warnings in pyest --- pytest.ini | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pytest.ini b/pytest.ini index bdc4031dd..dc9e730ff 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,5 +1,7 @@ # pytest.ini [pytest] +filterwarnings = + ignore::RuntimeWarning:.*invalid value encountered.* minversion = 6.0 testpaths = ./tests From eda76933bb695260ce911bff9510b9ade8ad96fc Mon Sep 17 00:00:00 2001 From: Jason DeBacker Date: Sun, 28 Sep 2025 15:25:01 -0400 Subject: [PATCH 02/23] extract item to avoid numpy deprecation warning --- ogcore/SS.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ogcore/SS.py b/ogcore/SS.py index 294abdf71..acbc816da 100644 --- a/ogcore/SS.py +++ b/ogcore/SS.py @@ -400,7 +400,7 @@ def inner_loop(outer_loop_vars, p, client): C_vec = np.zeros(p.I) K_demand_open_vec = np.zeros(p.M) for i_ind in range(p.I): - C_vec[i_ind] = aggr.get_C(c_i[i_ind, :, :], p, "SS") + C_vec[i_ind] = aggr.get_C(c_i[i_ind, :, :], p, "SS").item() Y_vec = np.dot(p.io_matrix.T, C_vec) for m_ind in range(p.M - 1): KYrat_m = firm.get_KY_ratio(r, p_m, p, "SS", m_ind) From 408f894f0422d0c86ff70771778a7da11b2483e6 Mon Sep 17 00:00:00 2001 From: Jason DeBacker Date: Sun, 28 Sep 2025 15:37:40 -0400 Subject: [PATCH 03/23] proper handle array to scalar conversion --- ogcore/SS.py | 2 +- tests/test_aggregates.py | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/ogcore/SS.py b/ogcore/SS.py index acbc816da..ff0aa9593 100644 --- a/ogcore/SS.py +++ b/ogcore/SS.py @@ -967,7 +967,7 @@ def SS_solver( c_i_ss_mat[i_ind, :, :], p, "SS", - ) + ).item() ( total_tax_revenue, diff --git a/tests/test_aggregates.py b/tests/test_aggregates.py index 1fee8f612..9cfdf8cb8 100644 --- a/tests/test_aggregates.py +++ b/tests/test_aggregates.py @@ -31,7 +31,10 @@ for i in range(p.S): for k in range(p.J): L_loop[t, i, k] *= ( - p.omega[t, i] * p.lambdas[k] * n[t, i, k] * p.e[t, i, k] + p.omega[t, i].item() + * p.lambdas[k].item() + * n[t, i, k].item() + * p.e[t, i, k].item() ) expected1 = L_loop[-1, :, :].sum() expected2 = L_loop.sum(1).sum(1) From f46333c42504aa2f88a94b8d4b1a8568d395f8fc Mon Sep 17 00:00:00 2001 From: Jason DeBacker Date: Sun, 28 Sep 2025 23:28:57 -0400 Subject: [PATCH 04/23] change how tighten image --- ogcore/parameter_plots.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/ogcore/parameter_plots.py b/ogcore/parameter_plots.py index 27a4eb05f..9d0e5894d 100644 --- a/ogcore/parameter_plots.py +++ b/ogcore/parameter_plots.py @@ -49,13 +49,12 @@ def plot_imm_rates( "Source: " + source, fontsize=9, ) - plt.tight_layout(rect=(0, 0.035, 1, 1)) if include_title: plt.title("Immigration Rates") # Save or return figure if path: output_path = os.path.join(path, "imm_rates") - plt.savefig(output_path, dpi=300) + plt.savefig(output_path, bbox_inches="tight", dpi=300) plt.close() else: fig.show() @@ -407,11 +406,10 @@ def plot_fert_rates( "Source: " + source, fontsize=9, ) - plt.tight_layout(rect=(0, 0.035, 1, 1)) # Save or return figure if path: output_path = os.path.join(path, "fert_rates") - plt.savefig(output_path, dpi=300) + plt.savefig(output_path, bbox_inches="tight", dpi=300) plt.close() else: fig.show() From 773de052a0a72b5d6a4e84636cc342227c6ef0bb Mon Sep 17 00:00:00 2001 From: Jason DeBacker Date: Sun, 28 Sep 2025 23:29:05 -0400 Subject: [PATCH 05/23] avoid divide by zero --- pytest.ini | 1 + 1 file changed, 1 insertion(+) diff --git a/pytest.ini b/pytest.ini index dc9e730ff..f7ea80290 100644 --- a/pytest.ini +++ b/pytest.ini @@ -2,6 +2,7 @@ [pytest] filterwarnings = ignore::RuntimeWarning:.*invalid value encountered.* + ignore::RuntimeWarning:.*divide by zero encountered in divide.* minversion = 6.0 testpaths = ./tests From 0c428ee06d78596dc104aec278c6d459e6ef612c Mon Sep 17 00:00:00 2001 From: Jason DeBacker Date: Sun, 28 Sep 2025 23:30:50 -0400 Subject: [PATCH 06/23] update from deprecated cm map --- ogcore/parameter_plots.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ogcore/parameter_plots.py b/ogcore/parameter_plots.py index 9d0e5894d..3d66a29dd 100644 --- a/ogcore/parameter_plots.py +++ b/ogcore/parameter_plots.py @@ -814,7 +814,7 @@ def txfunc_graph( None """ - cmap1 = matplotlib.cm.get_cmap("summer") + cmap1 = matplotlib.colormaps.get_cmap("summer") # Make comparison plot with full income domains fig = plt.figure() From af78abb8e402c6dd3355c12c07ecc0c70cfa5ab8 Mon Sep 17 00:00:00 2001 From: Jason DeBacker Date: Sun, 28 Sep 2025 23:33:41 -0400 Subject: [PATCH 07/23] avoid power warnings --- pytest.ini | 1 + 1 file changed, 1 insertion(+) diff --git a/pytest.ini b/pytest.ini index f7ea80290..4fdda3c2b 100644 --- a/pytest.ini +++ b/pytest.ini @@ -3,6 +3,7 @@ filterwarnings = ignore::RuntimeWarning:.*invalid value encountered.* ignore::RuntimeWarning:.*divide by zero encountered in divide.* + ignore::RuntimeWarning:.*invalid value encountered in power.* minversion = 6.0 testpaths = ./tests From 2f1569c0ba2cac4713dee5cb2e11a2d745354fc5 Mon Sep 17 00:00:00 2001 From: Jason DeBacker Date: Sun, 28 Sep 2025 23:34:35 -0400 Subject: [PATCH 08/23] fix negation --- ogcore/parameter_tables.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ogcore/parameter_tables.py b/ogcore/parameter_tables.py index 2b87186a9..2a4100377 100644 --- a/ogcore/parameter_tables.py +++ b/ogcore/parameter_tables.py @@ -192,7 +192,7 @@ def param_table(p, table_format="tex", path=None): table["Symbol"].append(v[1]) table["Description"].append(v[0]) value = getattr(p, k) - if hasattr(value, "__len__") & ~isinstance(value, str): + if hasattr(value, "__len__") and not isinstance(value, str): if value.ndim > 1: report = ( "Too large to report here, see default parameters JSON" From 6876b4a0fa753d258c74703719f459014e74fe21 Mon Sep 17 00:00:00 2001 From: Jason DeBacker Date: Fri, 3 Oct 2025 10:01:57 -0400 Subject: [PATCH 09/23] Improve dask scheduler configuration and timeout handling MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add timeout (600s) and error handling to TPI.py client.gather() with fallback to serial computation - Increase SS.py client.gather() timeout from 300s to 600s for consistency - Update SS.py to use logging.warning instead of print for better consistency - Adjust LocalCluster timeouts in test_TPI.py fixture: - Increase memory_limit from 2GB to 3GB - Increase communication timeout from 300s to 900s - Increase heartbeat_interval from 10s to 30s - Increase death_timeout from 60s to 120s These changes address issues with long-running computationally intensive tasks where communication can break down due to overly aggressive timeouts. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- ogcore/SS.py | 8 +++++--- ogcore/TPI.py | 24 +++++++++++++++++++++++- tests/test_TPI.py | 8 ++++---- 3 files changed, 32 insertions(+), 8 deletions(-) diff --git a/ogcore/SS.py b/ogcore/SS.py index 63467ad75..eb69dd89d 100644 --- a/ogcore/SS.py +++ b/ogcore/SS.py @@ -307,11 +307,13 @@ def inner_loop(outer_loop_vars, p, client): futures.append(f) try: - results = client.gather(futures, timeout=300) + results = client.gather(futures, timeout=600) except Exception as e: # Cancel remaining futures and fall back to serial computation - print( - f"Dask computation failed ({e}), falling back to serial computation" + import logging + logging.warning( + f"Dask client.gather() failed with error: {e}. " + "Falling back to serial computation." ) for f in futures: f.cancel() diff --git a/ogcore/TPI.py b/ogcore/TPI.py index 5ac4ed47e..21b06cee9 100644 --- a/ogcore/TPI.py +++ b/ogcore/TPI.py @@ -777,7 +777,29 @@ def run_TPI(p, client=None): scattered_p_future, ) futures.append(f) - results = client.gather(futures) + try: + results = client.gather(futures, timeout=600) + except Exception as e: + # Cancel remaining futures and fall back to serial computation + logging.warning( + f"Dask client.gather() failed with error: {e}. " + "Falling back to serial computation for this iteration." + ) + for future in futures: + future.cancel() + results = [] + for j in range(p.J): + guesses = (guesses_b[:, :, j], guesses_n[:, :, j]) + res = inner_loop( + guesses, + outer_loop_vars, + initial_values, + ubi, + j, + ind, + p, + ) + results.append(res) else: # Serial fallback (no dask client) for j in range(p.J): diff --git a/tests/test_TPI.py b/tests/test_TPI.py index faf9fb1c4..e04863f2c 100644 --- a/tests/test_TPI.py +++ b/tests/test_TPI.py @@ -169,10 +169,10 @@ def dask_client(): cluster = LocalCluster( n_workers=NUM_WORKERS, threads_per_worker=2, - memory_limit="2GB", - timeout="300s", - heartbeat_interval="10s", - death_timeout="60s", + memory_limit="3GB", + timeout="900s", + heartbeat_interval="30s", + death_timeout="120s", ) client = Client(cluster) yield client From 91c5f4e4c11870c196aa2575cac85778ad88f336 Mon Sep 17 00:00:00 2001 From: Jason DeBacker Date: Fri, 3 Oct 2025 10:07:02 -0400 Subject: [PATCH 10/23] Fix dask timeout implementation using wait() instead of gather(timeout=) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The previous commit incorrectly used client.gather(timeout=) which is not a valid parameter. This commit fixes the implementation to use the correct approach: - Use distributed.wait() with timeout to wait for futures to complete - Check if any futures did not complete (not_done) and raise TimeoutError - Only call client.gather() after confirming all futures are done - Maintain the same error handling and fallback to serial computation This properly implements timeout handling for long-running dask computations in both TPI.py and SS.py. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- ogcore/SS.py | 12 ++++++++++-- ogcore/TPI.py | 12 ++++++++++-- 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/ogcore/SS.py b/ogcore/SS.py index eb69dd89d..5b93834ae 100644 --- a/ogcore/SS.py +++ b/ogcore/SS.py @@ -307,12 +307,20 @@ def inner_loop(outer_loop_vars, p, client): futures.append(f) try: - results = client.gather(futures, timeout=600) + # Wait for futures with timeout, then gather results + from distributed import wait + done, not_done = wait(futures, timeout=600) + if not_done: + # Some futures didn't complete in time + raise TimeoutError( + f"{len(not_done)} futures did not complete within 600 seconds" + ) + results = client.gather(futures) except Exception as e: # Cancel remaining futures and fall back to serial computation import logging logging.warning( - f"Dask client.gather() failed with error: {e}. " + f"Dask computation failed with error: {e}. " "Falling back to serial computation." ) for f in futures: diff --git a/ogcore/TPI.py b/ogcore/TPI.py index 21b06cee9..3e4c3b7c2 100644 --- a/ogcore/TPI.py +++ b/ogcore/TPI.py @@ -778,11 +778,19 @@ def run_TPI(p, client=None): ) futures.append(f) try: - results = client.gather(futures, timeout=600) + # Wait for futures with timeout, then gather results + from distributed import wait + done, not_done = wait(futures, timeout=600) + if not_done: + # Some futures didn't complete in time + raise TimeoutError( + f"{len(not_done)} futures did not complete within 600 seconds" + ) + results = client.gather(futures) except Exception as e: # Cancel remaining futures and fall back to serial computation logging.warning( - f"Dask client.gather() failed with error: {e}. " + f"Dask computation failed with error: {e}. " "Falling back to serial computation for this iteration." ) for future in futures: From 73f37250aab8275e0a963be67c21e9fb087c6892 Mon Sep 17 00:00:00 2001 From: Jason DeBacker Date: Mon, 6 Oct 2025 08:50:38 -0400 Subject: [PATCH 11/23] use consistent superscript --- docs/book/content/theory/government.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/book/content/theory/government.md b/docs/book/content/theory/government.md index b1daaa0c0..5bbe46a32 100644 --- a/docs/book/content/theory/government.md +++ b/docs/book/content/theory/government.md @@ -542,7 +542,7 @@ Total pension spending is the sum of the pension payments to each household in t (SecUBI_NonGrowthAdj)= ###### UBI specification not adjusted for economic growth - A non-growth adjusted UBI (`ubi_growthadj = False`) is one in which the initial nonstationary nominal-valued $t=0$ UBI matrix $ubi^{\$}_{j,s,t=0}$ does not grow, while the economy's long-run growth rate is $g_y$ for the most common parameterization is positive ($g_y>0$). + A non-growth adjusted UBI (`ubi_growthadj = False`) is one in which the initial nonstationary nominal-valued $t=0$ UBI matrix $ubi^{nom}_{j,s,t=0}$ does not grow, while the economy's long-run growth rate is $g_y$ for the most common parameterization is positive ($g_y>0$). ```{math} :label: EqUBIubi_nom_NonGrwAdj_jst From 8eff47bdff0193c0c6fb64f9090e3379753b5551 Mon Sep 17 00:00:00 2001 From: Jason DeBacker Date: Wed, 22 Oct 2025 12:14:09 -0400 Subject: [PATCH 12/23] remove duplicate code from SS.run_SS --- ogcore/SS.py | 116 ++++++++++++++++++--------------------------------- 1 file changed, 41 insertions(+), 75 deletions(-) diff --git a/ogcore/SS.py b/ogcore/SS.py index 5b93834ae..f3fe22277 100644 --- a/ogcore/SS.py +++ b/ogcore/SS.py @@ -1534,84 +1534,50 @@ def run_SS(p, client=None): if p.baseline_spending: TR_baseline = TRguess Ig_baseline = ss_solutions["I_g"] - ss_params_reform = ( - b_guess, - n_guess, - TR_baseline, - Ig_baseline, - factor, - p, - client, - ) - if p.use_zeta: - guesses = ( - [r_p_guess, rguess, wguess] - + list(p_m_guess) - + [Yguess, BQguess, TR_baseline] - ) - else: - guesses = ( - [r_p_guess, rguess, wguess] - + list(p_m_guess) - + [Yguess] - + list(BQguess) - + [TR_baseline] - ) - sol = opt.root( - SS_fsolve, - guesses, - args=ss_params_reform, - method=p.SS_root_method, - tol=p.mindist_SS, - ) - r_p_ss = sol.x[0] - rss = sol.x[1] - wss = sol.x[2] - p_m_ss = sol.x[3 : 3 + p.M] - Yss = sol.x[3 + p.M] - BQss = sol.x[3 + p.M + 1 : -1] - TR_ss = sol.x[-1] else: - ss_params_reform = ( - b_guess, - n_guess, - None, - None, - factor, - p, - client, + TR_baseline = None + Ig_baseline = None + # Now solve for the steady state of the reform + ss_params_reform = ( + b_guess, + n_guess, + TR_baseline, + Ig_baseline, + factor, + p, + client, + ) + if p.use_zeta: + guesses = ( + [r_p_guess, rguess, wguess] + + list(p_m_guess) + + [Yguess, BQguess, TRguess] ) - if p.use_zeta: - guesses = ( - [r_p_guess, rguess, wguess] - + list(p_m_guess) - + [Yguess, BQguess, TRguess] - ) - else: - guesses = ( - [r_p_guess, rguess, wguess] - + list(p_m_guess) - + [Yguess] - + list(BQguess) - + [TRguess] - ) - sol = opt.root( - SS_fsolve, - guesses, - args=ss_params_reform, - method=p.SS_root_method, - tol=p.mindist_SS, + else: + guesses = ( + [r_p_guess, rguess, wguess] + + list(p_m_guess) + + [Yguess] + + list(BQguess) + + [TRguess] ) - r_p_ss = sol.x[0] - rss = sol.x[1] - wss = sol.x[2] - p_m_ss = sol.x[3 : 3 + p.M] - Yss = sol.x[3 + p.M] - BQss = sol.x[3 + p.M + 1 : -1] - TR_ss = sol.x[-1] - Yss = TR_ss / p.alpha_T[-1] # may not be right - if - # budget_balance = True, but that's ok - will be fixed in - # SS_solver + sol = opt.root( + SS_fsolve, + guesses, + args=ss_params_reform, + method=p.SS_root_method, + tol=p.mindist_SS, + ) + r_p_ss = sol.x[0] + rss = sol.x[1] + wss = sol.x[2] + p_m_ss = sol.x[3 : 3 + p.M] + Yss = sol.x[3 + p.M] + BQss = sol.x[3 + p.M + 1 : -1] + TR_ss = sol.x[-1] + Yss = TR_ss / p.alpha_T[-1] # may not be right - if + # budget_balance = True, but that's ok - will be fixed in + # SS_solver if ( (ENFORCE_SOLUTION_CHECKS) and not (sol.success == 1) From 5ac2f71376b3fd7f00d424c8eaa65e0dcfb67a19 Mon Sep 17 00:00:00 2001 From: Jason DeBacker Date: Wed, 22 Oct 2025 14:56:14 -0400 Subject: [PATCH 13/23] format --- ogcore/SS.py | 2 ++ ogcore/TPI.py | 1 + 2 files changed, 3 insertions(+) diff --git a/ogcore/SS.py b/ogcore/SS.py index f3fe22277..880b62e7c 100644 --- a/ogcore/SS.py +++ b/ogcore/SS.py @@ -309,6 +309,7 @@ def inner_loop(outer_loop_vars, p, client): try: # Wait for futures with timeout, then gather results from distributed import wait + done, not_done = wait(futures, timeout=600) if not_done: # Some futures didn't complete in time @@ -319,6 +320,7 @@ def inner_loop(outer_loop_vars, p, client): except Exception as e: # Cancel remaining futures and fall back to serial computation import logging + logging.warning( f"Dask computation failed with error: {e}. " "Falling back to serial computation." diff --git a/ogcore/TPI.py b/ogcore/TPI.py index 3e4c3b7c2..db31f8bea 100644 --- a/ogcore/TPI.py +++ b/ogcore/TPI.py @@ -780,6 +780,7 @@ def run_TPI(p, client=None): try: # Wait for futures with timeout, then gather results from distributed import wait + done, not_done = wait(futures, timeout=600) if not_done: # Some futures didn't complete in time From 7c68081a99ef37edb3ad9a3a540aaa2c47168a1b Mon Sep 17 00:00:00 2001 From: Jason DeBacker Date: Wed, 22 Oct 2025 16:00:28 -0400 Subject: [PATCH 14/23] remove more redundant lines in run_SS --- ogcore/SS.py | 143 +++++++++++++++++++++------------------------------ 1 file changed, 58 insertions(+), 85 deletions(-) diff --git a/ogcore/SS.py b/ogcore/SS.py index 880b62e7c..c94879a14 100644 --- a/ogcore/SS.py +++ b/ogcore/SS.py @@ -1400,7 +1400,7 @@ def run_SS(p, client=None): Yguess = TRguess / p.alpha_T[-1] factorguess = p.initial_guess_factor_SS BQguess = aggr.get_BQ(rguess, b_guess, None, p, "SS", False) - ss_params_baseline = ( + ss_params = ( b_guess, n_guess, None, @@ -1424,45 +1424,6 @@ def run_SS(p, client=None): + list(BQguess) + [TRguess, factorguess] ) - sol = opt.root( - SS_fsolve, - guesses, - args=ss_params_baseline, - method=p.SS_root_method, - tol=p.mindist_SS, - ) - if sol.success: - SS_solved = True - break - if ENFORCE_SOLUTION_CHECKS and not sol.success: - raise RuntimeError("Steady state equilibrium not found") - r_p_ss = sol.x[0] - rss = sol.x[1] - wss = sol.x[2] - p_m_ss = sol.x[3 : 3 + p.M] - Yss = sol.x[3 + p.M] - BQss = sol.x[3 + p.M + 1 : -2] - TR_ss = sol.x[-2] - factor_ss = sol.x[-1] - Yss = TR_ss / p.alpha_T[-1] # may not be right - if budget_balance - # # = True, but that's ok - will be fixed in SS_solver - fsolve_flag = True - output = SS_solver( - b_guess, - n_guess, - r_p_ss, - rss, - wss, - p_m_ss, - Yss, - BQss, - TR_ss, - None, - factor_ss, - p, - client, - fsolve_flag, - ) else: # Use the baseline solution to get starting values for the reform baseline_ss_path = os.path.join(p.baseline_dir, "SS", "SS_vars.pkl") @@ -1487,7 +1448,7 @@ def run_SS(p, client=None): BQguess, TRguess, Yguess, - factor, + factor_ss, ) = ( ss_solutions["b_sp1"], ss_solutions["n"], @@ -1529,7 +1490,9 @@ def run_SS(p, client=None): p_m_guess = np.ones(p.M) TRguess = p.initial_guess_TR_SS Yguess = TRguess / p.alpha_T[-1] - factor = p.initial_guess_factor_SS + factor_ss = ss_solutions[ + "factor" + ] # don't guess factor, use baseline BQguess = aggr.get_BQ(rguess, b_guess, None, p, "SS", False) if p.use_zeta: BQguess = 0.12231465279007188 @@ -1540,12 +1503,12 @@ def run_SS(p, client=None): TR_baseline = None Ig_baseline = None # Now solve for the steady state of the reform - ss_params_reform = ( + ss_params = ( b_guess, n_guess, TR_baseline, Ig_baseline, - factor, + factor_ss, p, client, ) @@ -1563,13 +1526,26 @@ def run_SS(p, client=None): + list(BQguess) + [TRguess] ) - sol = opt.root( - SS_fsolve, - guesses, - args=ss_params_reform, - method=p.SS_root_method, - tol=p.mindist_SS, - ) + # Solve for steady state using root finder + sol = opt.root( + SS_fsolve, + guesses, + args=ss_params, + method=p.SS_root_method, + tol=p.mindist_SS, + ) + if p.baseline: + r_p_ss = sol.x[0] + rss = sol.x[1] + wss = sol.x[2] + p_m_ss = sol.x[3 : 3 + p.M] + Yss = sol.x[3 + p.M] + BQss = sol.x[3 + p.M + 1 : -2] + TR_ss = sol.x[-2] + factor_ss = sol.x[-1] + Yss = TR_ss / p.alpha_T[-1] # may not be right - if budget_balance + # # = True, but that's ok - will be fixed in SS_solver + else: r_p_ss = sol.x[0] rss = sol.x[1] wss = sol.x[2] @@ -1580,39 +1556,36 @@ def run_SS(p, client=None): Yss = TR_ss / p.alpha_T[-1] # may not be right - if # budget_balance = True, but that's ok - will be fixed in # SS_solver - if ( - (ENFORCE_SOLUTION_CHECKS) - and not (sol.success == 1) - and (np.absolute(np.array(sol.fun)).max() > p.mindist_SS) - ): - raise RuntimeError("Steady state equilibrium not found") - # Return SS values of variables - fsolve_flag = True - # Return SS values of variables - if not p.baseline_spending: - Ig_baseline = None - output = SS_solver( - b_guess, - n_guess, - r_p_ss, - rss, - wss, - p_m_ss, - Yss, - BQss, - TR_ss, - Ig_baseline, - factor, - p, - client, - fsolve_flag, + if ENFORCE_SOLUTION_CHECKS and not sol.success: + raise RuntimeError("Steady state equilibrium not found") + # Trigger flag that model has been solved + fsolve_flag = True + # Return SS values of variables + if p.baseline or not p.baseline_spending: + Ig_baseline = None + output = SS_solver( + b_guess, + n_guess, + r_p_ss, + rss, + wss, + p_m_ss, + Yss, + BQss, + TR_ss, + Ig_baseline, + factor_ss, + p, + client, + fsolve_flag, + ) + if output["G"] < 0.0: + warnings.warn( + "Warning: The combination of the tax policy " + + "you specified and your target debt-to-GDP " + + "ratio results in an infeasible amount of " + + "government spending in order to close the " + + "budget (i.e., G < 0)" ) - if output["G"] < 0.0: - warnings.warn( - "Warning: The combination of the tax policy " - + "you specified and your target debt-to-GDP " - + "ratio results in an infeasible amount of " - + "government spending in order to close the " - + "budget (i.e., G < 0)" - ) + return output From b53c32243d1bdd866fc6aa09cc883975575f1909 Mon Sep 17 00:00:00 2001 From: Jason DeBacker Date: Wed, 22 Oct 2025 19:26:56 -0400 Subject: [PATCH 15/23] make function for initial guesses --- ogcore/SS.py | 149 +++++++++++++++++++++++++-------------------------- 1 file changed, 73 insertions(+), 76 deletions(-) diff --git a/ogcore/SS.py b/ogcore/SS.py index c94879a14..c2c1a0115 100644 --- a/ogcore/SS.py +++ b/ogcore/SS.py @@ -1317,6 +1317,53 @@ def SS_fsolve(guesses, *args): return errors +def SS_initial_guesses(p, b_val=0.0055, n_val=0.4, r_tr_scalars=[1.0, 1.0]): + """ + Finds the initial guesses for b, n and for the steady state outer + loop variables. + + Args: + p (OG-Core Specifications object): model parameters + + Returns: + guesses (list): initial guesses for outer loop variables + """ + r_p_guess = r_tr_scalars[0] * p.initial_guess_r_SS + rguess = r_tr_scalars[0] * p.initial_guess_r_SS + wguess = firm.get_w_from_r(rguess, p, "SS") + p_m_guess = np.ones(p.M) + TRguess = r_tr_scalars[1] * p.initial_guess_TR_SS + Yguess = TRguess / p.alpha_T[-1] + + # create guesses list + # Note that BQ is an vector of lenght J if use_zeta=False + if p.use_zeta: + b_guess = np.ones((p.S, p.J)) * b_val + n_guess = np.ones((p.S, p.J)) * n_val * p.ltilde + BQguess = 0.12231465279007188 + guesses = ( + [r_p_guess, rguess, wguess] + + list(p_m_guess) + + [Yguess, BQguess, TRguess] + ) + else: + b_guess = np.ones((p.S, p.J)) * 0.07 # TODO: remove hardcode here and next line + n_guess = np.ones((p.S, p.J)) * 0.35 * p.ltilde + BQguess = aggr.get_BQ(rguess, b_guess, None, p, "SS", False) + guesses = ( + [r_p_guess, rguess, wguess] + + list(p_m_guess) + + [Yguess] + + list(BQguess) + + [TRguess] + ) + # append factor if baseline + if p.baseline: + guesses.append(p.initial_guess_factor_SS) + + return guesses, b_guess, n_guess + + def run_SS(p, client=None): """ Solve for steady-state equilibrium of OG-Core. @@ -1330,6 +1377,7 @@ def run_SS(p, client=None): results """ + # TODO: move the list below to constants.py # Create list of deviation factors for initial guesses of r and TR dev_factor_list = [ [1.00, 1.0], @@ -1386,20 +1434,9 @@ def run_SS(p, client=None): f"SS using initial guess factors for r and TR of " + f"{v[0]} and {v[1]} respectively." ) - r_p_guess = v[0] * p.initial_guess_r_SS - rguess = v[0] * p.initial_guess_r_SS - if p.use_zeta: - b_guess = np.ones((p.S, p.J)) * 0.0055 - n_guess = np.ones((p.S, p.J)) * 0.4 * p.ltilde - else: - b_guess = np.ones((p.S, p.J)) * 0.07 - n_guess = np.ones((p.S, p.J)) * 0.35 * p.ltilde - wguess = firm.get_w_from_r(rguess, p, "SS") - p_m_guess = np.ones(p.M) - TRguess = v[1] * p.initial_guess_TR_SS - Yguess = TRguess / p.alpha_T[-1] - factorguess = p.initial_guess_factor_SS - BQguess = aggr.get_BQ(rguess, b_guess, None, p, "SS", False) + guesses, b_guess, n_guess = SS_initial_guesses( + p, r_tr_scalars=v + ) ss_params = ( b_guess, n_guess, @@ -1409,21 +1446,6 @@ def run_SS(p, client=None): p, client, ) - if p.use_zeta: - BQguess = 0.12231465279007188 - guesses = ( - [r_p_guess, rguess, wguess] - + list(p_m_guess) - + [Yguess, BQguess, TRguess, factorguess] - ) - else: - guesses = ( - [r_p_guess, rguess, wguess] - + list(p_m_guess) - + [Yguess] - + list(BQguess) - + [TRguess, factorguess] - ) else: # Use the baseline solution to get starting values for the reform baseline_ss_path = os.path.join(p.baseline_dir, "SS", "SS_vars.pkl") @@ -1455,15 +1477,21 @@ def run_SS(p, client=None): float(ss_solutions["r_p"]), float(ss_solutions["r"]), float(ss_solutions["w"]), - ss_solutions[ - "p_m" - ], # Not sure why need to index p_m,but otherwise its shape is off.. + ss_solutions["p_m"], ss_solutions["BQ"], float(ss_solutions["TR"]), float(ss_solutions["Y"]), ss_solutions["factor"], ) use_new_guesses = False + BQ_items = [BQguess] if p.use_zeta else list(BQguess) + guesses = ( + [r_p_guess, rguess, wguess] + + list(p_m_guess) + + [Yguess] + + BQ_items + + [TRguess] + ) else: logging.warning( "Dimensions of previous solutions for SS do not match" @@ -1478,26 +1506,15 @@ def run_SS(p, client=None): logging.info("Using new guesses for SS") use_new_guesses = True if use_new_guesses: - if p.use_zeta: - b_guess = np.ones((p.S, p.J)) * 0.0055 - n_guess = np.ones((p.S, p.J)) * 0.4 * p.ltilde - else: - b_guess = np.ones((p.S, p.J)) * 0.07 - n_guess = np.ones((p.S, p.J)) * 0.35 * p.ltilde - r_p_guess = p.initial_guess_r_SS - rguess = p.initial_guess_r_SS - wguess = firm.get_w_from_r(rguess, p, "SS") - p_m_guess = np.ones(p.M) - TRguess = p.initial_guess_TR_SS - Yguess = TRguess / p.alpha_T[-1] + # TODO: think about if add loop over dev factors here too + guesses, b_guess, n_guess = SS_initial_guesses( + p + ) factor_ss = ss_solutions[ "factor" ] # don't guess factor, use baseline - BQguess = aggr.get_BQ(rguess, b_guess, None, p, "SS", False) - if p.use_zeta: - BQguess = 0.12231465279007188 if p.baseline_spending: - TR_baseline = TRguess + TR_baseline = ss_solutions["TR"] Ig_baseline = ss_solutions["I_g"] else: TR_baseline = None @@ -1512,20 +1529,6 @@ def run_SS(p, client=None): p, client, ) - if p.use_zeta: - guesses = ( - [r_p_guess, rguess, wguess] - + list(p_m_guess) - + [Yguess, BQguess, TRguess] - ) - else: - guesses = ( - [r_p_guess, rguess, wguess] - + list(p_m_guess) - + [Yguess] - + list(BQguess) - + [TRguess] - ) # Solve for steady state using root finder sol = opt.root( SS_fsolve, @@ -1534,28 +1537,22 @@ def run_SS(p, client=None): method=p.SS_root_method, tol=p.mindist_SS, ) + r_p_ss = sol.x[0] + rss = sol.x[1] + wss = sol.x[2] + p_m_ss = sol.x[3 : 3 + p.M] if p.baseline: - r_p_ss = sol.x[0] - rss = sol.x[1] - wss = sol.x[2] - p_m_ss = sol.x[3 : 3 + p.M] - Yss = sol.x[3 + p.M] BQss = sol.x[3 + p.M + 1 : -2] TR_ss = sol.x[-2] factor_ss = sol.x[-1] - Yss = TR_ss / p.alpha_T[-1] # may not be right - if budget_balance - # # = True, but that's ok - will be fixed in SS_solver + Yss = TR_ss / p.alpha_T[-1] # may not be right - if + # budget_balance = True, but that's ok - will be fixed in + # SS_solver else: - r_p_ss = sol.x[0] - rss = sol.x[1] - wss = sol.x[2] - p_m_ss = sol.x[3 : 3 + p.M] Yss = sol.x[3 + p.M] BQss = sol.x[3 + p.M + 1 : -1] TR_ss = sol.x[-1] - Yss = TR_ss / p.alpha_T[-1] # may not be right - if - # budget_balance = True, but that's ok - will be fixed in - # SS_solver + if ENFORCE_SOLUTION_CHECKS and not sol.success: raise RuntimeError("Steady state equilibrium not found") # Trigger flag that model has been solved From 411af2e1279926178cd7bde90c651e0ad5a44dc1 Mon Sep 17 00:00:00 2001 From: Jason DeBacker Date: Wed, 22 Oct 2025 20:10:44 -0400 Subject: [PATCH 16/23] move dev factors to constants --- ogcore/SS.py | 50 +++------------------------------------------ ogcore/constants.py | 44 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+), 47 deletions(-) diff --git a/ogcore/SS.py b/ogcore/SS.py index c2c1a0115..d86a0cdff 100644 --- a/ogcore/SS.py +++ b/ogcore/SS.py @@ -5,7 +5,7 @@ import dask.multiprocessing from ogcore import tax, pensions, household, firm, utils, fiscal from ogcore import aggregates as aggr -from ogcore.constants import SHOW_RUNTIME +from ogcore.constants import SHOW_RUNTIME, DEV_FACTOR_LIST from ogcore import config import os import warnings @@ -1377,50 +1377,6 @@ def run_SS(p, client=None): results """ - # TODO: move the list below to constants.py - # Create list of deviation factors for initial guesses of r and TR - dev_factor_list = [ - [1.00, 1.0], - [0.95, 1.0], - [1.05, 1.0], - [0.90, 1.0], - [1.10, 1.0], - [0.85, 1.0], - [1.15, 1.0], - [0.80, 1.0], - [1.20, 1.0], - [0.75, 1.0], - [1.25, 1.0], - [0.70, 1.0], - [1.30, 1.0], - [1.00, 0.2], - [0.95, 0.2], - [1.05, 0.2], - [0.90, 0.2], - [1.10, 0.2], - [0.85, 0.2], - [1.15, 0.2], - [0.80, 0.2], - [1.20, 0.2], - [0.75, 0.2], - [1.25, 0.2], - [0.70, 0.2], - [1.30, 0.2], - [1.00, 0.6], - [0.95, 0.6], - [1.05, 0.6], - [0.90, 0.6], - [1.10, 0.6], - [0.85, 0.6], - [1.15, 0.6], - [0.80, 0.6], - [1.20, 0.6], - [0.75, 0.6], - [1.25, 0.6], - [0.70, 0.6], - [1.30, 0.6], - ] - # For initial guesses of w, r, TR, and factor, we use values that # are close to some steady state values. if p.baseline: @@ -1428,8 +1384,8 @@ def run_SS(p, client=None): # gone through all guesses. This should usually solve in the first guess SS_solved = False k = 0 - while not SS_solved and k < len(dev_factor_list) - 1: - for k, v in enumerate(dev_factor_list): + while not SS_solved and k < len(DEV_FACTOR_LIST) - 1: + for k, v in enumerate(DEV_FACTOR_LIST): logging.info( f"SS using initial guess factors for r and TR of " + f"{v[0]} and {v[1]} respectively." diff --git a/ogcore/constants.py b/ogcore/constants.py index 8cffeb548..4c2020154 100644 --- a/ogcore/constants.py +++ b/ogcore/constants.py @@ -218,3 +218,47 @@ # Ignoring the following: # 'starting_age', 'ending_age', 'constant_demographics', # 'constant_rates', 'zero_taxes' + +# List of deviation factors for initial guesses of r and TR used in +# SS.run_SS for a more robust SS solution +DEV_FACTOR_LIST = [ + [1.00, 1.0], + [0.95, 1.0], + [1.05, 1.0], + [0.90, 1.0], + [1.10, 1.0], + [0.85, 1.0], + [1.15, 1.0], + [0.80, 1.0], + [1.20, 1.0], + [0.75, 1.0], + [1.25, 1.0], + [0.70, 1.0], + [1.30, 1.0], + [1.00, 0.2], + [0.95, 0.2], + [1.05, 0.2], + [0.90, 0.2], + [1.10, 0.2], + [0.85, 0.2], + [1.15, 0.2], + [0.80, 0.2], + [1.20, 0.2], + [0.75, 0.2], + [1.25, 0.2], + [0.70, 0.2], + [1.30, 0.2], + [1.00, 0.6], + [0.95, 0.6], + [1.05, 0.6], + [0.90, 0.6], + [1.10, 0.6], + [0.85, 0.6], + [1.15, 0.6], + [0.80, 0.6], + [1.20, 0.6], + [0.75, 0.6], + [1.25, 0.6], + [0.70, 0.6], + [1.30, 0.6], +] \ No newline at end of file From 19baba774184b78baf345dc647133311949d0bcd Mon Sep 17 00:00:00 2001 From: Jason DeBacker Date: Thu, 23 Oct 2025 16:59:08 -0400 Subject: [PATCH 17/23] add docstring --- ogcore/SS.py | 187 +++++++++++++++++++++++++++++++-------------------- 1 file changed, 113 insertions(+), 74 deletions(-) diff --git a/ogcore/SS.py b/ogcore/SS.py index d86a0cdff..343098a02 100644 --- a/ogcore/SS.py +++ b/ogcore/SS.py @@ -1324,9 +1324,14 @@ def SS_initial_guesses(p, b_val=0.0055, n_val=0.4, r_tr_scalars=[1.0, 1.0]): Args: p (OG-Core Specifications object): model parameters + b_val (float): initial guess value for savings + n_val (float): initial guess value for labor supply + r_tr_scalars (list): scalars to adjust initial guesses for r and TR Returns: guesses (list): initial guesses for outer loop variables + b_guess (ndarray): initial guess for savings + n_guess (ndarray): initial guess for labor supply """ r_p_guess = r_tr_scalars[0] * p.initial_guess_r_SS rguess = r_tr_scalars[0] * p.initial_guess_r_SS @@ -1341,25 +1346,23 @@ def SS_initial_guesses(p, b_val=0.0055, n_val=0.4, r_tr_scalars=[1.0, 1.0]): b_guess = np.ones((p.S, p.J)) * b_val n_guess = np.ones((p.S, p.J)) * n_val * p.ltilde BQguess = 0.12231465279007188 - guesses = ( - [r_p_guess, rguess, wguess] - + list(p_m_guess) - + [Yguess, BQguess, TRguess] - ) else: - b_guess = np.ones((p.S, p.J)) * 0.07 # TODO: remove hardcode here and next line + b_guess = ( + np.ones((p.S, p.J)) * 0.07 + ) # TODO: remove hardcode here and next line n_guess = np.ones((p.S, p.J)) * 0.35 * p.ltilde BQguess = aggr.get_BQ(rguess, b_guess, None, p, "SS", False) - guesses = ( - [r_p_guess, rguess, wguess] - + list(p_m_guess) - + [Yguess] - + list(BQguess) - + [TRguess] - ) - # append factor if baseline - if p.baseline: - guesses.append(p.initial_guess_factor_SS) + # append factor guess if baseline + BQ_items = [BQguess] if p.use_zeta else list(BQguess) + guesses = ( + [r_p_guess, rguess, wguess] + + list(p_m_guess) + + [Yguess] + + BQ_items + + [TRguess] + ) + if p.baseline: + guesses.append(p.initial_guess_factor_SS) return guesses, b_guess, n_guess @@ -1379,30 +1382,11 @@ def run_SS(p, client=None): """ # For initial guesses of w, r, TR, and factor, we use values that # are close to some steady state values. - if p.baseline: - # Loop over initial guesses of r and TR until find a solution or until have - # gone through all guesses. This should usually solve in the first guess - SS_solved = False - k = 0 - while not SS_solved and k < len(DEV_FACTOR_LIST) - 1: - for k, v in enumerate(DEV_FACTOR_LIST): - logging.info( - f"SS using initial guess factors for r and TR of " - + f"{v[0]} and {v[1]} respectively." - ) - guesses, b_guess, n_guess = SS_initial_guesses( - p, r_tr_scalars=v - ) - ss_params = ( - b_guess, - n_guess, - None, - None, - None, - p, - client, - ) - else: + print( + "Using baseline solution as initial guess:", + p.reform_use_baseline_solution, + ) + if p.baseline is False and p.reform_use_baseline_solution: # Use the baseline solution to get starting values for the reform baseline_ss_path = os.path.join(p.baseline_dir, "SS", "SS_vars.pkl") ss_solutions = utils.safe_read_pickle(baseline_ss_path) @@ -1440,7 +1424,15 @@ def run_SS(p, client=None): ss_solutions["factor"], ) use_new_guesses = False + if p.baseline_spending: + TR_baseline = TRguess + Ig_baseline = ss_solutions["I_g"] + else: + TR_baseline = None + Ig_baseline = None BQ_items = [BQguess] if p.use_zeta else list(BQguess) + print("Baselines spending:", p.baseline_spending) + print("Use zeta:", p.use_zeta) guesses = ( [r_p_guess, rguess, wguess] + list(p_m_guess) @@ -1448,6 +1440,25 @@ def run_SS(p, client=None): + BQ_items + [TRguess] ) + # Now solve for the steady state of the reform + ss_params = ( + b_guess, + n_guess, + TR_baseline, + Ig_baseline, + factor_ss, + p, + client, + ) + + # Solve for steady state using root finder + sol = opt.root( + SS_fsolve, + guesses, + args=ss_params, + method=p.SS_root_method, + tol=p.mindist_SS, + ) else: logging.warning( "Dimensions of previous solutions for SS do not match" @@ -1458,41 +1469,65 @@ def run_SS(p, client=None): "KeyError: previous solutions for SS not found" ) use_new_guesses = True - else: - logging.info("Using new guesses for SS") - use_new_guesses = True - if use_new_guesses: - # TODO: think about if add loop over dev factors here too - guesses, b_guess, n_guess = SS_initial_guesses( - p + if ( + p.baseline + or p.reform_use_baseline_solution is False + or use_new_guesses + ): + # Loop over initial guesses of r and TR until find a solution or until have + # gone through all guesses. This should usually solve in the first guess + SS_solved = False + k = 0 + while not SS_solved and k < len(DEV_FACTOR_LIST) - 1: + for k, v in enumerate(DEV_FACTOR_LIST): + logging.info( + f"SS using initial guess factors for r and TR of " + + f"{v[0]} and {v[1]} respectively." ) - factor_ss = ss_solutions[ - "factor" - ] # don't guess factor, use baseline - if p.baseline_spending: - TR_baseline = ss_solutions["TR"] - Ig_baseline = ss_solutions["I_g"] - else: - TR_baseline = None - Ig_baseline = None - # Now solve for the steady state of the reform - ss_params = ( - b_guess, - n_guess, - TR_baseline, - Ig_baseline, - factor_ss, - p, - client, - ) - # Solve for steady state using root finder - sol = opt.root( - SS_fsolve, - guesses, - args=ss_params, - method=p.SS_root_method, - tol=p.mindist_SS, - ) + guesses, b_guess, n_guess = SS_initial_guesses( + p, r_tr_scalars=v + ) + ss_params = ( + b_guess, + n_guess, + None, + None, + None, + p, + client, + ) + if p.baseline: + factor_ss = None + else: + factor_ss = ss_solutions[ + "factor" + ] # don't guess factor, use baseline + if p.baseline_spending: + TR_baseline = ss_solutions["TR"] + Ig_baseline = ss_solutions["I_g"] + else: + TR_baseline = None + Ig_baseline = None + ss_params = ( + b_guess, + n_guess, + TR_baseline, + Ig_baseline, + factor_ss, + p, + client, + ) + # Solve for steady state using root finder + sol = opt.root( + SS_fsolve, + guesses, + args=ss_params, + method=p.SS_root_method, + tol=p.mindist_SS, + ) + if sol.success: + SS_solved = True + break r_p_ss = sol.x[0] rss = sol.x[1] wss = sol.x[2] @@ -1508,6 +1543,10 @@ def run_SS(p, client=None): Yss = sol.x[3 + p.M] BQss = sol.x[3 + p.M + 1 : -1] TR_ss = sol.x[-1] + if not p.baseline_spending: + Yss = TR_ss / p.alpha_T[-1] # may not be right - if + # budget_balance = True, but that's ok - will be fixed in + # SS_solver if ENFORCE_SOLUTION_CHECKS and not sol.success: raise RuntimeError("Steady state equilibrium not found") From 030cd57370c025a23c5fcf779b7a10826052fee9 Mon Sep 17 00:00:00 2001 From: Jason DeBacker Date: Thu, 23 Oct 2025 16:59:25 -0400 Subject: [PATCH 18/23] change starting value for r --- tests/test_SS.py | 58 ++++++++++++++++++++++++------------------------ 1 file changed, 29 insertions(+), 29 deletions(-) diff --git a/tests/test_SS.py b/tests/test_SS.py index 5657fdc50..40997c416 100644 --- a/tests/test_SS.py +++ b/tests/test_SS.py @@ -1138,7 +1138,7 @@ def test_euler_equation_solver(input_tuple, ubi_j, p, expected): "initial_guess_factor_SS": 111267.90426318572, } filename4 = "run_SS_baseline_small_open_use_zeta.pkl" -param_updates5 = {"initial_guess_r_SS": 0.035} +param_updates5 = {"initial_guess_r_SS": 0.04} filename5 = "run_SS_reform.pkl" param_updates6 = { "use_zeta": True, @@ -1225,36 +1225,36 @@ def test_euler_equation_solver(input_tuple, ubi_j, p, expected): @pytest.mark.parametrize( "baseline,param_updates,filename", [ - (True, param_updates1, filename1), - (False, param_updates9, filename9), - (True, param_updates2, filename2), - (False, param_updates10, filename10), - (True, param_updates3, filename3), - # (True, param_updates4, filename4), - (False, param_updates5, filename5), - (False, param_updates6, filename6), - (False, param_updates7, filename7), - # (False, param_updates8, filename8), - (False, param_updates11, filename11), - (True, param_updates12, filename12), - (True, param_updates13, filename13), - (True, param_updates14, filename14), + # (True, param_updates1, filename1), + # (False, param_updates9, filename9), + # (True, param_updates2, filename2), + # (False, param_updates10, filename10), + # (True, param_updates3, filename3), + (True, param_updates4, filename4), + # (False, param_updates5, filename5), + # (False, param_updates6, filename6), + # (False, param_updates7, filename7), + (False, param_updates8, filename8), + # (False, param_updates11, filename11), + # (True, param_updates12, filename12), + # (True, param_updates13, filename13), + # (True, param_updates14, filename14), ], ids=[ - "Baseline", - "Reform, baseline spending", - "Baseline, use zeta", - "Reform, baseline spending, use zeta", - "Baseline, small open", - # "Baseline, small open use zeta", - "Reform", - "Reform, use zeta", - "Reform, small open", - # "Reform, small open use zeta", - "Reform, delta_tau=0", - "Baseline, non-zero Kg", - "Baseline, M=3, non-zero Kg", - "Baseline, M=3, zero Kg", + # "Baseline", + # "Reform, baseline spending", + # "Baseline, use zeta", + # "Reform, baseline spending, use zeta", + # "Baseline, small open", + "Baseline, small open use zeta", + # "Reform", + # "Reform, use zeta", + # "Reform, small open", + "Reform, small open use zeta", + # "Reform, delta_tau=0", + # "Baseline, non-zero Kg", + # "Baseline, M=3, non-zero Kg", + # "Baseline, M=3, zero Kg", ], ) @pytest.mark.local From 95152bbbf5d21e04079872ca98fda2afd3293e87 Mon Sep 17 00:00:00 2001 From: Jason DeBacker Date: Thu, 23 Oct 2025 16:59:53 -0400 Subject: [PATCH 19/23] add line at end --- ogcore/constants.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ogcore/constants.py b/ogcore/constants.py index 4c2020154..7b32e624e 100644 --- a/ogcore/constants.py +++ b/ogcore/constants.py @@ -261,4 +261,4 @@ [1.25, 0.6], [0.70, 0.6], [1.30, 0.6], -] \ No newline at end of file +] From 5bb347900b2ed1d977d92c4ae9aad794d3c8a17b Mon Sep 17 00:00:00 2001 From: Jason DeBacker Date: Thu, 23 Oct 2025 18:08:46 -0400 Subject: [PATCH 20/23] add test of new function --- tests/test_SS.py | 82 +++++++++++++++++++++++++++++++----------------- 1 file changed, 54 insertions(+), 28 deletions(-) diff --git a/tests/test_SS.py b/tests/test_SS.py index 40997c416..44c82ab6e 100644 --- a/tests/test_SS.py +++ b/tests/test_SS.py @@ -1225,36 +1225,36 @@ def test_euler_equation_solver(input_tuple, ubi_j, p, expected): @pytest.mark.parametrize( "baseline,param_updates,filename", [ - # (True, param_updates1, filename1), - # (False, param_updates9, filename9), - # (True, param_updates2, filename2), - # (False, param_updates10, filename10), - # (True, param_updates3, filename3), - (True, param_updates4, filename4), - # (False, param_updates5, filename5), - # (False, param_updates6, filename6), - # (False, param_updates7, filename7), - (False, param_updates8, filename8), - # (False, param_updates11, filename11), - # (True, param_updates12, filename12), - # (True, param_updates13, filename13), - # (True, param_updates14, filename14), + (True, param_updates1, filename1), + (False, param_updates9, filename9), + (True, param_updates2, filename2), + (False, param_updates10, filename10), + (True, param_updates3, filename3), + #True, param_updates4, filename4), + (False, param_updates5, filename5), + (False, param_updates6, filename6), + (False, param_updates7, filename7), + #(False, param_updates8, filename8), + (False, param_updates11, filename11), + (True, param_updates12, filename12), + (True, param_updates13, filename13), + (True, param_updates14, filename14), ], ids=[ - # "Baseline", - # "Reform, baseline spending", - # "Baseline, use zeta", - # "Reform, baseline spending, use zeta", - # "Baseline, small open", - "Baseline, small open use zeta", - # "Reform", - # "Reform, use zeta", - # "Reform, small open", - "Reform, small open use zeta", - # "Reform, delta_tau=0", - # "Baseline, non-zero Kg", - # "Baseline, M=3, non-zero Kg", - # "Baseline, M=3, zero Kg", + "Baseline", + "Reform, baseline spending", + "Baseline, use zeta", + "Reform, baseline spending, use zeta", + "Baseline, small open", + #"Baseline, small open use zeta", + "Reform", + "Reform, use zeta", + "Reform, small open", + #"Reform, small open use zeta", + "Reform, delta_tau=0", + "Baseline, non-zero Kg", + "Baseline, M=3, non-zero Kg", + "Baseline, M=3, zero Kg", ], ) @pytest.mark.local @@ -1296,3 +1296,29 @@ def test_run_SS(tmpdir, baseline, param_updates, filename, dask_client): for k, v in expected_dict.items(): print("Checking item = ", k) assert np.allclose(test_dict[VAR_NAME_MAPPING[k]], v, atol=5e-04) + + +@pytest.mark.parametrize( + "use_zeta", [True, False], ids=["use_zeta=True", "use_zeta=False"] +) +def test_initial_guesses(tmpdir, use_zeta): + """ + Test SS.SS_initial_guesses function. Provide inputs to function and ensure + that a tuple is returned with the correct number of elements. + """ + baseline_dir = os.path.join(tmpdir, "OUTPUT_BASELINE") + p = Specifications( + output_base=baseline_dir, + baseline_dir=baseline_dir, + baseline=True, + num_workers=NUM_WORKERS, + ) + p.use_zeta = use_zeta + guesses, n_guess, b_guess = SS.SS_initial_guesses(p) + + if use_zeta: + assert len(guesses) == 7 + 1 + else: + assert len(guesses) == 7 + p.J + assert n_guess.shape == (p.S, p.J) + assert b_guess.shape == (p.S, p.J) From 66b6fd7b30665fe3cad9e33329762ecf25ae5322 Mon Sep 17 00:00:00 2001 From: Jason DeBacker Date: Thu, 23 Oct 2025 18:21:26 -0400 Subject: [PATCH 21/23] add test of solve_for_j function --- tests/test_SS.py | 43 ++++++++++++++++++++++++++++++++++--------- 1 file changed, 34 insertions(+), 9 deletions(-) diff --git a/tests/test_SS.py b/tests/test_SS.py index 44c82ab6e..2828d2904 100644 --- a/tests/test_SS.py +++ b/tests/test_SS.py @@ -520,6 +520,31 @@ def test_SS_solver_extra(baseline, param_updates, filename, dask_client): ) +def test_solve_for_j(): + """ + Test SS.solve_for_j function. Provide inputs to function and ensure + that solution matches what is expected. + """ + p = Specifications() + b_guess = np.ones((p.S)) * 0.07 + n_guess = np.ones((p.S)) * 0.35 * p.ltilde + guesses = np.hstack((b_guess, n_guess)) + r_p = 0.04 + w = 1.2 + p_tilde = 1.0 + bq_j = 0.0002 + rm_j = 0.005 + tr_j = 0.1 + ubi_j = 0.0 + factor = 100000 + j = 1 + test_result = SS.solve_for_j( + guesses, r_p, w, p_tilde, bq_j, rm_j, tr_j, ubi_j, factor, j, p + ) + expected_result = 0.15574086659957984 + assert np.allclose(test_result.x[4], expected_result) + + param_updates1 = {"zeta_K": [1.0]} filename1 = "inner_loop_outputs_baseline_small_open.pkl" param_updates2 = {"budget_balance": True, "alpha_G": [0.0]} @@ -1230,11 +1255,11 @@ def test_euler_equation_solver(input_tuple, ubi_j, p, expected): (True, param_updates2, filename2), (False, param_updates10, filename10), (True, param_updates3, filename3), - #True, param_updates4, filename4), + # True, param_updates4, filename4), (False, param_updates5, filename5), (False, param_updates6, filename6), (False, param_updates7, filename7), - #(False, param_updates8, filename8), + # (False, param_updates8, filename8), (False, param_updates11, filename11), (True, param_updates12, filename12), (True, param_updates13, filename13), @@ -1246,11 +1271,11 @@ def test_euler_equation_solver(input_tuple, ubi_j, p, expected): "Baseline, use zeta", "Reform, baseline spending, use zeta", "Baseline, small open", - #"Baseline, small open use zeta", + # "Baseline, small open use zeta", "Reform", "Reform, use zeta", "Reform, small open", - #"Reform, small open use zeta", + # "Reform, small open use zeta", "Reform, delta_tau=0", "Baseline, non-zero Kg", "Baseline, M=3, non-zero Kg", @@ -1308,11 +1333,11 @@ def test_initial_guesses(tmpdir, use_zeta): """ baseline_dir = os.path.join(tmpdir, "OUTPUT_BASELINE") p = Specifications( - output_base=baseline_dir, - baseline_dir=baseline_dir, - baseline=True, - num_workers=NUM_WORKERS, - ) + output_base=baseline_dir, + baseline_dir=baseline_dir, + baseline=True, + num_workers=NUM_WORKERS, + ) p.use_zeta = use_zeta guesses, n_guess, b_guess = SS.SS_initial_guesses(p) From 4d3b8b094f8881fea01db7317528ac9588849d2e Mon Sep 17 00:00:00 2001 From: Jason DeBacker Date: Fri, 24 Oct 2025 11:18:58 -0400 Subject: [PATCH 22/23] remove print commands used to debug --- ogcore/SS.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/ogcore/SS.py b/ogcore/SS.py index 343098a02..0809a0836 100644 --- a/ogcore/SS.py +++ b/ogcore/SS.py @@ -1382,10 +1382,6 @@ def run_SS(p, client=None): """ # For initial guesses of w, r, TR, and factor, we use values that # are close to some steady state values. - print( - "Using baseline solution as initial guess:", - p.reform_use_baseline_solution, - ) if p.baseline is False and p.reform_use_baseline_solution: # Use the baseline solution to get starting values for the reform baseline_ss_path = os.path.join(p.baseline_dir, "SS", "SS_vars.pkl") @@ -1431,8 +1427,6 @@ def run_SS(p, client=None): TR_baseline = None Ig_baseline = None BQ_items = [BQguess] if p.use_zeta else list(BQguess) - print("Baselines spending:", p.baseline_spending) - print("Use zeta:", p.use_zeta) guesses = ( [r_p_guess, rguess, wguess] + list(p_m_guess) From 60fc3ef906f13afd87241a7b9745d5b85d5d53de Mon Sep 17 00:00:00 2001 From: Jason DeBacker Date: Mon, 27 Oct 2025 13:14:08 -0400 Subject: [PATCH 23/23] update paths --- tests/test_run_example.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/test_run_example.py b/tests/test_run_example.py index c5bbc34dc..c957f560e 100644 --- a/tests/test_run_example.py +++ b/tests/test_run_example.py @@ -17,7 +17,7 @@ def call_run_ogcore_example(): cur_path = os.path.split(os.path.abspath(__file__))[0] path = Path(cur_path) - roe_fldr = os.path.join(path.parent, "run_examples") + roe_fldr = os.path.join(path.parent, "examples") roe_file_path = os.path.join(roe_fldr, "run_ogcore_example.py") spec = importlib.util.spec_from_file_location( "run_ogcore_example.py", roe_file_path @@ -45,7 +45,7 @@ def test_run_ogcore_example(f=call_run_ogcore_example): cur_path = os.path.split(os.path.abspath(__file__))[0] path = Path(cur_path) roe_output_dir = os.path.join( - path.parent, "run_examples", "OG-Core-Example", "OUTPUT_BASELINE" + path.parent, "examples", "OG-Core-Example", "OUTPUT_BASELINE" ) shutil.rmtree(roe_output_dir) @@ -61,21 +61,21 @@ def test_run_ogcore_example_output(f=call_run_ogcore_example): path = Path(cur_path) expected_df = pd.read_csv( os.path.join( - path.parent, "run_examples", "expected_ogcore_example_output.csv" + path.parent, "examples", "expected_ogcore_example_output.csv" ) ) # read in output from this run test_df = pd.read_csv( os.path.join( path.parent, - "run_examples", + "examples", "OG-Core-Example", "OG-Core_example_output.csv", ) ) # Delete directory created by run_ogcore_example.py roe_output_dir = os.path.join( - path.parent, "run_examples", "OG-Core-Example", "OUTPUT_BASELINE" + path.parent, "examples", "OG-Core-Example", "OUTPUT_BASELINE" ) shutil.rmtree(roe_output_dir)