diff --git a/docs/book/content/theory/government.md b/docs/book/content/theory/government.md index b1daaa0c0..5bbe46a32 100644 --- a/docs/book/content/theory/government.md +++ b/docs/book/content/theory/government.md @@ -542,7 +542,7 @@ Total pension spending is the sum of the pension payments to each household in t (SecUBI_NonGrowthAdj)= ###### UBI specification not adjusted for economic growth - A non-growth adjusted UBI (`ubi_growthadj = False`) is one in which the initial nonstationary nominal-valued $t=0$ UBI matrix $ubi^{\$}_{j,s,t=0}$ does not grow, while the economy's long-run growth rate is $g_y$ for the most common parameterization is positive ($g_y>0$). + A non-growth adjusted UBI (`ubi_growthadj = False`) is one in which the initial nonstationary nominal-valued $t=0$ UBI matrix $ubi^{nom}_{j,s,t=0}$ does not grow, while the economy's long-run growth rate is $g_y$ for the most common parameterization is positive ($g_y>0$). ```{math} :label: EqUBIubi_nom_NonGrwAdj_jst diff --git a/ogcore/SS.py b/ogcore/SS.py index 2878dfe6b..402e52747 100644 --- a/ogcore/SS.py +++ b/ogcore/SS.py @@ -5,7 +5,7 @@ import dask.multiprocessing from ogcore import tax, pensions, household, firm, utils, fiscal from ogcore import aggregates as aggr -from ogcore.constants import SHOW_RUNTIME +from ogcore.constants import SHOW_RUNTIME, DEV_FACTOR_LIST from ogcore import config import os import warnings @@ -307,11 +307,23 @@ def inner_loop(outer_loop_vars, p, client): futures.append(f) try: - results = client.gather(futures, timeout=300) + # Wait for futures with timeout, then gather results + from distributed import wait + + done, not_done = wait(futures, timeout=600) + if not_done: + # Some futures didn't complete in time + raise TimeoutError( + f"{len(not_done)} futures did not complete within 600 seconds" + ) + results = client.gather(futures) except Exception as e: # Cancel remaining futures and fall back to serial computation - print( - f"Dask computation failed ({e}), falling back to serial computation" + import logging + + logging.warning( + f"Dask computation failed with error: {e}. " + "Falling back to serial computation." ) for f in futures: f.cancel() @@ -427,7 +439,7 @@ def inner_loop(outer_loop_vars, p, client): C_vec = np.zeros(p.I) K_demand_open_vec = np.zeros(p.M) for i_ind in range(p.I): - C_vec[i_ind] = aggr.get_C(c_i[i_ind, :, :], p, "SS") + C_vec[i_ind] = aggr.get_C(c_i[i_ind, :, :], p, "SS").item() Y_vec = np.dot(p.io_matrix.T, C_vec) for m_ind in range(p.M - 1): KYrat_m = firm.get_KY_ratio(r, p_m, p, "SS", m_ind) @@ -994,7 +1006,7 @@ def SS_solver( c_i_ss_mat[i_ind, :, :], p, "SS", - ) + ).item() ( total_tax_revenue, @@ -1305,6 +1317,56 @@ def SS_fsolve(guesses, *args): return errors +def SS_initial_guesses(p, b_val=0.0055, n_val=0.4, r_tr_scalars=[1.0, 1.0]): + """ + Finds the initial guesses for b, n and for the steady state outer + loop variables. + + Args: + p (OG-Core Specifications object): model parameters + b_val (float): initial guess value for savings + n_val (float): initial guess value for labor supply + r_tr_scalars (list): scalars to adjust initial guesses for r and TR + + Returns: + guesses (list): initial guesses for outer loop variables + b_guess (ndarray): initial guess for savings + n_guess (ndarray): initial guess for labor supply + """ + r_p_guess = r_tr_scalars[0] * p.initial_guess_r_SS + rguess = r_tr_scalars[0] * p.initial_guess_r_SS + wguess = firm.get_w_from_r(rguess, p, "SS") + p_m_guess = np.ones(p.M) + TRguess = r_tr_scalars[1] * p.initial_guess_TR_SS + Yguess = TRguess / p.alpha_T[-1] + + # create guesses list + # Note that BQ is an vector of lenght J if use_zeta=False + if p.use_zeta: + b_guess = np.ones((p.S, p.J)) * b_val + n_guess = np.ones((p.S, p.J)) * n_val * p.ltilde + BQguess = 0.12231465279007188 + else: + b_guess = ( + np.ones((p.S, p.J)) * 0.07 + ) # TODO: remove hardcode here and next line + n_guess = np.ones((p.S, p.J)) * 0.35 * p.ltilde + BQguess = aggr.get_BQ(rguess, b_guess, None, p, "SS", False) + # append factor guess if baseline + BQ_items = [BQguess] if p.use_zeta else list(BQguess) + guesses = ( + [r_p_guess, rguess, wguess] + + list(p_m_guess) + + [Yguess] + + BQ_items + + [TRguess] + ) + if p.baseline: + guesses.append(p.initial_guess_factor_SS) + + return guesses, b_guess, n_guess + + def run_SS(p, client=None): """ Solve for steady-state equilibrium of OG-Core. @@ -1318,140 +1380,9 @@ def run_SS(p, client=None): results """ - # Create list of deviation factors for initial guesses of r and TR - dev_factor_list = [ - [1.00, 1.0], - [0.95, 1.0], - [1.05, 1.0], - [0.90, 1.0], - [1.10, 1.0], - [0.85, 1.0], - [1.15, 1.0], - [0.80, 1.0], - [1.20, 1.0], - [0.75, 1.0], - [1.25, 1.0], - [0.70, 1.0], - [1.30, 1.0], - [1.00, 0.2], - [0.95, 0.2], - [1.05, 0.2], - [0.90, 0.2], - [1.10, 0.2], - [0.85, 0.2], - [1.15, 0.2], - [0.80, 0.2], - [1.20, 0.2], - [0.75, 0.2], - [1.25, 0.2], - [0.70, 0.2], - [1.30, 0.2], - [1.00, 0.6], - [0.95, 0.6], - [1.05, 0.6], - [0.90, 0.6], - [1.10, 0.6], - [0.85, 0.6], - [1.15, 0.6], - [0.80, 0.6], - [1.20, 0.6], - [0.75, 0.6], - [1.25, 0.6], - [0.70, 0.6], - [1.30, 0.6], - ] - # For initial guesses of w, r, TR, and factor, we use values that # are close to some steady state values. - if p.baseline: - # Loop over initial guesses of r and TR until find a solution or until have - # gone through all guesses. This should usually solve in the first guess - SS_solved = False - k = 0 - while not SS_solved and k < len(dev_factor_list) - 1: - for k, v in enumerate(dev_factor_list): - logging.info( - f"SS using initial guess factors for r and TR of " - + f"{v[0]} and {v[1]} respectively." - ) - r_p_guess = v[0] * p.initial_guess_r_SS - rguess = v[0] * p.initial_guess_r_SS - if p.use_zeta: - b_guess = np.ones((p.S, p.J)) * 0.0055 - n_guess = np.ones((p.S, p.J)) * 0.4 * p.ltilde - else: - b_guess = np.ones((p.S, p.J)) * 0.07 - n_guess = np.ones((p.S, p.J)) * 0.35 * p.ltilde - wguess = firm.get_w_from_r(rguess, p, "SS") - p_m_guess = np.ones(p.M) - TRguess = v[1] * p.initial_guess_TR_SS - Yguess = TRguess / p.alpha_T[-1] - factorguess = p.initial_guess_factor_SS - BQguess = aggr.get_BQ(rguess, b_guess, None, p, "SS", False) - ss_params_baseline = ( - b_guess, - n_guess, - None, - None, - None, - p, - client, - ) - if p.use_zeta: - BQguess = 0.12231465279007188 - guesses = ( - [r_p_guess, rguess, wguess] - + list(p_m_guess) - + [Yguess, BQguess, TRguess, factorguess] - ) - else: - guesses = ( - [r_p_guess, rguess, wguess] - + list(p_m_guess) - + [Yguess] - + list(BQguess) - + [TRguess, factorguess] - ) - sol = opt.root( - SS_fsolve, - guesses, - args=ss_params_baseline, - method=p.SS_root_method, - tol=p.mindist_SS, - ) - if sol.success: - SS_solved = True - break - if ENFORCE_SOLUTION_CHECKS and not sol.success: - raise RuntimeError("Steady state equilibrium not found") - r_p_ss = sol.x[0] - rss = sol.x[1] - wss = sol.x[2] - p_m_ss = sol.x[3 : 3 + p.M] - Yss = sol.x[3 + p.M] - BQss = sol.x[3 + p.M + 1 : -2] - TR_ss = sol.x[-2] - factor_ss = sol.x[-1] - Yss = TR_ss / p.alpha_T[-1] # may not be right - if budget_balance - # # = True, but that's ok - will be fixed in SS_solver - fsolve_flag = True - output = SS_solver( - b_guess, - n_guess, - r_p_ss, - rss, - wss, - p_m_ss, - Yss, - BQss, - TR_ss, - None, - factor_ss, - p, - client, - fsolve_flag, - ) - else: + if p.baseline is False and p.reform_use_baseline_solution: # Use the baseline solution to get starting values for the reform baseline_ss_path = os.path.join(p.baseline_dir, "SS", "SS_vars.pkl") ss_solutions = utils.safe_read_pickle(baseline_ss_path) @@ -1475,22 +1406,53 @@ def run_SS(p, client=None): BQguess, TRguess, Yguess, - factor, + factor_ss, ) = ( ss_solutions["b_sp1"], ss_solutions["n"], float(ss_solutions["r_p"]), float(ss_solutions["r"]), float(ss_solutions["w"]), - ss_solutions[ - "p_m" - ], # Not sure why need to index p_m,but otherwise its shape is off.. + ss_solutions["p_m"], ss_solutions["BQ"], float(ss_solutions["TR"]), float(ss_solutions["Y"]), ss_solutions["factor"], ) use_new_guesses = False + if p.baseline_spending: + TR_baseline = TRguess + Ig_baseline = ss_solutions["I_g"] + else: + TR_baseline = None + Ig_baseline = None + BQ_items = [BQguess] if p.use_zeta else list(BQguess) + guesses = ( + [r_p_guess, rguess, wguess] + + list(p_m_guess) + + [Yguess] + + BQ_items + + [TRguess] + ) + # Now solve for the steady state of the reform + ss_params = ( + b_guess, + n_guess, + TR_baseline, + Ig_baseline, + factor_ss, + p, + client, + ) + + # Solve for steady state using root finder + sol = opt.root( + SS_fsolve, + guesses, + args=ss_params, + method=p.SS_root_method, + tol=p.mindist_SS, + ) else: logging.warning( "Dimensions of previous solutions for SS do not match" @@ -1501,140 +1463,115 @@ def run_SS(p, client=None): "KeyError: previous solutions for SS not found" ) use_new_guesses = True - else: - logging.info("Using new guesses for SS") - use_new_guesses = True - if use_new_guesses: - if p.use_zeta: - b_guess = np.ones((p.S, p.J)) * 0.0055 - n_guess = np.ones((p.S, p.J)) * 0.4 * p.ltilde - else: - b_guess = np.ones((p.S, p.J)) * 0.07 - n_guess = np.ones((p.S, p.J)) * 0.35 * p.ltilde - r_p_guess = p.initial_guess_r_SS - rguess = p.initial_guess_r_SS - wguess = firm.get_w_from_r(rguess, p, "SS") - p_m_guess = np.ones(p.M) - TRguess = p.initial_guess_TR_SS - Yguess = TRguess / p.alpha_T[-1] - factor = p.initial_guess_factor_SS - BQguess = aggr.get_BQ(rguess, b_guess, None, p, "SS", False) - if p.use_zeta: - BQguess = 0.12231465279007188 - if p.baseline_spending: - TR_baseline = TRguess - Ig_baseline = ss_solutions["I_g"] - ss_params_reform = ( - b_guess, - n_guess, - TR_baseline, - Ig_baseline, - factor, - p, - client, - ) - if p.use_zeta: - guesses = ( - [r_p_guess, rguess, wguess] - + list(p_m_guess) - + [Yguess, BQguess, TR_baseline] + if ( + p.baseline + or p.reform_use_baseline_solution is False + or use_new_guesses + ): + # Loop over initial guesses of r and TR until find a solution or until have + # gone through all guesses. This should usually solve in the first guess + SS_solved = False + k = 0 + while not SS_solved and k < len(DEV_FACTOR_LIST) - 1: + for k, v in enumerate(DEV_FACTOR_LIST): + logging.info( + f"SS using initial guess factors for r and TR of " + + f"{v[0]} and {v[1]} respectively." ) - else: - guesses = ( - [r_p_guess, rguess, wguess] - + list(p_m_guess) - + [Yguess] - + list(BQguess) - + [TR_baseline] + guesses, b_guess, n_guess = SS_initial_guesses( + p, r_tr_scalars=v ) - sol = opt.root( - SS_fsolve, - guesses, - args=ss_params_reform, - method=p.SS_root_method, - tol=p.mindist_SS, - ) - r_p_ss = sol.x[0] - rss = sol.x[1] - wss = sol.x[2] - p_m_ss = sol.x[3 : 3 + p.M] - Yss = sol.x[3 + p.M] - BQss = sol.x[3 + p.M + 1 : -1] - TR_ss = sol.x[-1] - else: - ss_params_reform = ( - b_guess, - n_guess, - None, - None, - factor, - p, - client, - ) - if p.use_zeta: - guesses = ( - [r_p_guess, rguess, wguess] - + list(p_m_guess) - + [Yguess, BQguess, TRguess] + ss_params = ( + b_guess, + n_guess, + None, + None, + None, + p, + client, ) - else: - guesses = ( - [r_p_guess, rguess, wguess] - + list(p_m_guess) - + [Yguess] - + list(BQguess) - + [TRguess] + if p.baseline: + factor_ss = None + else: + factor_ss = ss_solutions[ + "factor" + ] # don't guess factor, use baseline + if p.baseline_spending: + TR_baseline = ss_solutions["TR"] + Ig_baseline = ss_solutions["I_g"] + else: + TR_baseline = None + Ig_baseline = None + ss_params = ( + b_guess, + n_guess, + TR_baseline, + Ig_baseline, + factor_ss, + p, + client, ) - sol = opt.root( - SS_fsolve, - guesses, - args=ss_params_reform, - method=p.SS_root_method, - tol=p.mindist_SS, - ) - r_p_ss = sol.x[0] - rss = sol.x[1] - wss = sol.x[2] - p_m_ss = sol.x[3 : 3 + p.M] - Yss = sol.x[3 + p.M] - BQss = sol.x[3 + p.M + 1 : -1] - TR_ss = sol.x[-1] + # Solve for steady state using root finder + sol = opt.root( + SS_fsolve, + guesses, + args=ss_params, + method=p.SS_root_method, + tol=p.mindist_SS, + ) + if sol.success: + SS_solved = True + break + r_p_ss = sol.x[0] + rss = sol.x[1] + wss = sol.x[2] + p_m_ss = sol.x[3 : 3 + p.M] + if p.baseline: + BQss = sol.x[3 + p.M + 1 : -2] + TR_ss = sol.x[-2] + factor_ss = sol.x[-1] + Yss = TR_ss / p.alpha_T[-1] # may not be right - if + # budget_balance = True, but that's ok - will be fixed in + # SS_solver + else: + Yss = sol.x[3 + p.M] + BQss = sol.x[3 + p.M + 1 : -1] + TR_ss = sol.x[-1] + if not p.baseline_spending: Yss = TR_ss / p.alpha_T[-1] # may not be right - if # budget_balance = True, but that's ok - will be fixed in # SS_solver - if ( - (ENFORCE_SOLUTION_CHECKS) - and not (sol.success == 1) - and (np.absolute(np.array(sol.fun)).max() > p.mindist_SS) - ): - raise RuntimeError("Steady state equilibrium not found") - # Return SS values of variables - fsolve_flag = True - # Return SS values of variables - if not p.baseline_spending: - Ig_baseline = None - output = SS_solver( - b_guess, - n_guess, - r_p_ss, - rss, - wss, - p_m_ss, - Yss, - BQss, - TR_ss, - Ig_baseline, - factor, - p, - client, - fsolve_flag, + + if ENFORCE_SOLUTION_CHECKS and not sol.success: + raise RuntimeError("Steady state equilibrium not found") + # Trigger flag that model has been solved + fsolve_flag = True + # Return SS values of variables + if p.baseline or not p.baseline_spending: + Ig_baseline = None + output = SS_solver( + b_guess, + n_guess, + r_p_ss, + rss, + wss, + p_m_ss, + Yss, + BQss, + TR_ss, + Ig_baseline, + factor_ss, + p, + client, + fsolve_flag, + ) + if output["G"] < 0.0: + warnings.warn( + "Warning: The combination of the tax policy " + + "you specified and your target debt-to-GDP " + + "ratio results in an infeasible amount of " + + "government spending in order to close the " + + "budget (i.e., G < 0)" ) - if output["G"] < 0.0: - warnings.warn( - "Warning: The combination of the tax policy " - + "you specified and your target debt-to-GDP " - + "ratio results in an infeasible amount of " - + "government spending in order to close the " - + "budget (i.e., G < 0)" - ) + return output diff --git a/ogcore/TPI.py b/ogcore/TPI.py index 5ac4ed47e..db31f8bea 100644 --- a/ogcore/TPI.py +++ b/ogcore/TPI.py @@ -777,7 +777,38 @@ def run_TPI(p, client=None): scattered_p_future, ) futures.append(f) - results = client.gather(futures) + try: + # Wait for futures with timeout, then gather results + from distributed import wait + + done, not_done = wait(futures, timeout=600) + if not_done: + # Some futures didn't complete in time + raise TimeoutError( + f"{len(not_done)} futures did not complete within 600 seconds" + ) + results = client.gather(futures) + except Exception as e: + # Cancel remaining futures and fall back to serial computation + logging.warning( + f"Dask computation failed with error: {e}. " + "Falling back to serial computation for this iteration." + ) + for future in futures: + future.cancel() + results = [] + for j in range(p.J): + guesses = (guesses_b[:, :, j], guesses_n[:, :, j]) + res = inner_loop( + guesses, + outer_loop_vars, + initial_values, + ubi, + j, + ind, + p, + ) + results.append(res) else: # Serial fallback (no dask client) for j in range(p.J): diff --git a/ogcore/constants.py b/ogcore/constants.py index 8cffeb548..7b32e624e 100644 --- a/ogcore/constants.py +++ b/ogcore/constants.py @@ -218,3 +218,47 @@ # Ignoring the following: # 'starting_age', 'ending_age', 'constant_demographics', # 'constant_rates', 'zero_taxes' + +# List of deviation factors for initial guesses of r and TR used in +# SS.run_SS for a more robust SS solution +DEV_FACTOR_LIST = [ + [1.00, 1.0], + [0.95, 1.0], + [1.05, 1.0], + [0.90, 1.0], + [1.10, 1.0], + [0.85, 1.0], + [1.15, 1.0], + [0.80, 1.0], + [1.20, 1.0], + [0.75, 1.0], + [1.25, 1.0], + [0.70, 1.0], + [1.30, 1.0], + [1.00, 0.2], + [0.95, 0.2], + [1.05, 0.2], + [0.90, 0.2], + [1.10, 0.2], + [0.85, 0.2], + [1.15, 0.2], + [0.80, 0.2], + [1.20, 0.2], + [0.75, 0.2], + [1.25, 0.2], + [0.70, 0.2], + [1.30, 0.2], + [1.00, 0.6], + [0.95, 0.6], + [1.05, 0.6], + [0.90, 0.6], + [1.10, 0.6], + [0.85, 0.6], + [1.15, 0.6], + [0.80, 0.6], + [1.20, 0.6], + [0.75, 0.6], + [1.25, 0.6], + [0.70, 0.6], + [1.30, 0.6], +] diff --git a/ogcore/parameter_plots.py b/ogcore/parameter_plots.py index 27a4eb05f..3d66a29dd 100644 --- a/ogcore/parameter_plots.py +++ b/ogcore/parameter_plots.py @@ -49,13 +49,12 @@ def plot_imm_rates( "Source: " + source, fontsize=9, ) - plt.tight_layout(rect=(0, 0.035, 1, 1)) if include_title: plt.title("Immigration Rates") # Save or return figure if path: output_path = os.path.join(path, "imm_rates") - plt.savefig(output_path, dpi=300) + plt.savefig(output_path, bbox_inches="tight", dpi=300) plt.close() else: fig.show() @@ -407,11 +406,10 @@ def plot_fert_rates( "Source: " + source, fontsize=9, ) - plt.tight_layout(rect=(0, 0.035, 1, 1)) # Save or return figure if path: output_path = os.path.join(path, "fert_rates") - plt.savefig(output_path, dpi=300) + plt.savefig(output_path, bbox_inches="tight", dpi=300) plt.close() else: fig.show() @@ -816,7 +814,7 @@ def txfunc_graph( None """ - cmap1 = matplotlib.cm.get_cmap("summer") + cmap1 = matplotlib.colormaps.get_cmap("summer") # Make comparison plot with full income domains fig = plt.figure() diff --git a/ogcore/parameter_tables.py b/ogcore/parameter_tables.py index 2b87186a9..2a4100377 100644 --- a/ogcore/parameter_tables.py +++ b/ogcore/parameter_tables.py @@ -192,7 +192,7 @@ def param_table(p, table_format="tex", path=None): table["Symbol"].append(v[1]) table["Description"].append(v[0]) value = getattr(p, k) - if hasattr(value, "__len__") & ~isinstance(value, str): + if hasattr(value, "__len__") and not isinstance(value, str): if value.ndim > 1: report = ( "Too large to report here, see default parameters JSON" diff --git a/pytest.ini b/pytest.ini index bdc4031dd..4fdda3c2b 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,5 +1,9 @@ # pytest.ini [pytest] +filterwarnings = + ignore::RuntimeWarning:.*invalid value encountered.* + ignore::RuntimeWarning:.*divide by zero encountered in divide.* + ignore::RuntimeWarning:.*invalid value encountered in power.* minversion = 6.0 testpaths = ./tests diff --git a/tests/test_SS.py b/tests/test_SS.py index 5657fdc50..2828d2904 100644 --- a/tests/test_SS.py +++ b/tests/test_SS.py @@ -520,6 +520,31 @@ def test_SS_solver_extra(baseline, param_updates, filename, dask_client): ) +def test_solve_for_j(): + """ + Test SS.solve_for_j function. Provide inputs to function and ensure + that solution matches what is expected. + """ + p = Specifications() + b_guess = np.ones((p.S)) * 0.07 + n_guess = np.ones((p.S)) * 0.35 * p.ltilde + guesses = np.hstack((b_guess, n_guess)) + r_p = 0.04 + w = 1.2 + p_tilde = 1.0 + bq_j = 0.0002 + rm_j = 0.005 + tr_j = 0.1 + ubi_j = 0.0 + factor = 100000 + j = 1 + test_result = SS.solve_for_j( + guesses, r_p, w, p_tilde, bq_j, rm_j, tr_j, ubi_j, factor, j, p + ) + expected_result = 0.15574086659957984 + assert np.allclose(test_result.x[4], expected_result) + + param_updates1 = {"zeta_K": [1.0]} filename1 = "inner_loop_outputs_baseline_small_open.pkl" param_updates2 = {"budget_balance": True, "alpha_G": [0.0]} @@ -1138,7 +1163,7 @@ def test_euler_equation_solver(input_tuple, ubi_j, p, expected): "initial_guess_factor_SS": 111267.90426318572, } filename4 = "run_SS_baseline_small_open_use_zeta.pkl" -param_updates5 = {"initial_guess_r_SS": 0.035} +param_updates5 = {"initial_guess_r_SS": 0.04} filename5 = "run_SS_reform.pkl" param_updates6 = { "use_zeta": True, @@ -1230,7 +1255,7 @@ def test_euler_equation_solver(input_tuple, ubi_j, p, expected): (True, param_updates2, filename2), (False, param_updates10, filename10), (True, param_updates3, filename3), - # (True, param_updates4, filename4), + # True, param_updates4, filename4), (False, param_updates5, filename5), (False, param_updates6, filename6), (False, param_updates7, filename7), @@ -1296,3 +1321,29 @@ def test_run_SS(tmpdir, baseline, param_updates, filename, dask_client): for k, v in expected_dict.items(): print("Checking item = ", k) assert np.allclose(test_dict[VAR_NAME_MAPPING[k]], v, atol=5e-04) + + +@pytest.mark.parametrize( + "use_zeta", [True, False], ids=["use_zeta=True", "use_zeta=False"] +) +def test_initial_guesses(tmpdir, use_zeta): + """ + Test SS.SS_initial_guesses function. Provide inputs to function and ensure + that a tuple is returned with the correct number of elements. + """ + baseline_dir = os.path.join(tmpdir, "OUTPUT_BASELINE") + p = Specifications( + output_base=baseline_dir, + baseline_dir=baseline_dir, + baseline=True, + num_workers=NUM_WORKERS, + ) + p.use_zeta = use_zeta + guesses, n_guess, b_guess = SS.SS_initial_guesses(p) + + if use_zeta: + assert len(guesses) == 7 + 1 + else: + assert len(guesses) == 7 + p.J + assert n_guess.shape == (p.S, p.J) + assert b_guess.shape == (p.S, p.J) diff --git a/tests/test_TPI.py b/tests/test_TPI.py index faf9fb1c4..e04863f2c 100644 --- a/tests/test_TPI.py +++ b/tests/test_TPI.py @@ -169,10 +169,10 @@ def dask_client(): cluster = LocalCluster( n_workers=NUM_WORKERS, threads_per_worker=2, - memory_limit="2GB", - timeout="300s", - heartbeat_interval="10s", - death_timeout="60s", + memory_limit="3GB", + timeout="900s", + heartbeat_interval="30s", + death_timeout="120s", ) client = Client(cluster) yield client diff --git a/tests/test_aggregates.py b/tests/test_aggregates.py index 1fee8f612..9cfdf8cb8 100644 --- a/tests/test_aggregates.py +++ b/tests/test_aggregates.py @@ -31,7 +31,10 @@ for i in range(p.S): for k in range(p.J): L_loop[t, i, k] *= ( - p.omega[t, i] * p.lambdas[k] * n[t, i, k] * p.e[t, i, k] + p.omega[t, i].item() + * p.lambdas[k].item() + * n[t, i, k].item() + * p.e[t, i, k].item() ) expected1 = L_loop[-1, :, :].sum() expected2 = L_loop.sum(1).sum(1) diff --git a/tests/test_run_example.py b/tests/test_run_example.py index c5bbc34dc..c957f560e 100644 --- a/tests/test_run_example.py +++ b/tests/test_run_example.py @@ -17,7 +17,7 @@ def call_run_ogcore_example(): cur_path = os.path.split(os.path.abspath(__file__))[0] path = Path(cur_path) - roe_fldr = os.path.join(path.parent, "run_examples") + roe_fldr = os.path.join(path.parent, "examples") roe_file_path = os.path.join(roe_fldr, "run_ogcore_example.py") spec = importlib.util.spec_from_file_location( "run_ogcore_example.py", roe_file_path @@ -45,7 +45,7 @@ def test_run_ogcore_example(f=call_run_ogcore_example): cur_path = os.path.split(os.path.abspath(__file__))[0] path = Path(cur_path) roe_output_dir = os.path.join( - path.parent, "run_examples", "OG-Core-Example", "OUTPUT_BASELINE" + path.parent, "examples", "OG-Core-Example", "OUTPUT_BASELINE" ) shutil.rmtree(roe_output_dir) @@ -61,21 +61,21 @@ def test_run_ogcore_example_output(f=call_run_ogcore_example): path = Path(cur_path) expected_df = pd.read_csv( os.path.join( - path.parent, "run_examples", "expected_ogcore_example_output.csv" + path.parent, "examples", "expected_ogcore_example_output.csv" ) ) # read in output from this run test_df = pd.read_csv( os.path.join( path.parent, - "run_examples", + "examples", "OG-Core-Example", "OG-Core_example_output.csv", ) ) # Delete directory created by run_ogcore_example.py roe_output_dir = os.path.join( - path.parent, "run_examples", "OG-Core-Example", "OUTPUT_BASELINE" + path.parent, "examples", "OG-Core-Example", "OUTPUT_BASELINE" ) shutil.rmtree(roe_output_dir)