Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 31 additions & 3 deletions src/estimation/gridsearch.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import itertools
import warnings
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
Expand All @@ -10,7 +11,9 @@
warnings.filterwarnings("ignore", message="Polyfit may be poorly conditioned")


def run_1d_gridsearch(func, params, loc, gridspec, n_seeds, n_cores):
def run_1d_gridsearch(
func, params, loc, gridspec, n_seeds, n_cores, initial_states_path
):
"""Run a grid search over one parameter."""
seeds = _get_seeds(n_seeds)
grid = np.linspace(*gridspec)
Expand All @@ -19,7 +22,18 @@ def run_1d_gridsearch(func, params, loc, gridspec, n_seeds, n_cores):
for point, seed in itertools.product(grid, seeds):
p = params.copy(deep=True)
p.loc[loc, "value"] = point
arguments.append({"params": p, "seed": seed})
if initial_states_path is not None:
# go back from seed to seed index
path = Path(initial_states_path.format(seed=int((seed - 500) / 100_000)))
else:
path = None
arguments.append(
{
"params": p,
"seed": seed,
"initial_states_path": path,
}
)

results = joblib_batch_evaluator(
func=func,
Expand Down Expand Up @@ -68,6 +82,7 @@ def run_2d_gridsearch(
n_cores,
mask=None,
names=("x_1", "x_2"),
initial_states_path=None,
):
"""Run a grid search over two parameters."""
# naming: _x refers to loc1, _y to loc2 and z to function values
Expand All @@ -93,7 +108,20 @@ def run_2d_gridsearch(
p = params.copy(deep=True)
p.loc[loc1, "value"] = x
p.loc[loc2, "value"] = y
arguments.append({"params": p, "seed": seed})

if initial_states_path is not None:
# go back from seed to seed index
path = Path(initial_states_path.format(seed=int(seed / 100_000 - 500)))
else:
path = None

arguments.append(
{
"params": p,
"seed": seed,
"initial_states_path": path,
}
)

results = joblib_batch_evaluator(
func=func,
Expand Down
21 changes: 20 additions & 1 deletion src/estimation/msm_criterion.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def get_parallelizable_msm_criterion(
spring_end_date,
mode,
debug,
initial_states_path=None,
):
"""Get a parallelizable msm criterion function."""
pmsm = functools.partial(
Expand All @@ -37,6 +38,7 @@ def get_parallelizable_msm_criterion(
spring_end_date=spring_end_date,
mode=mode,
debug=debug,
initial_states_path=initial_states_path,
)
return pmsm

Expand Down Expand Up @@ -67,6 +69,7 @@ def _build_and_evaluate_msm_func(
spring_end_date,
mode,
debug,
initial_states_path,
):
""" """
params_hash = hash_array(params["value"].to_numpy())
Expand All @@ -79,10 +82,11 @@ def _build_and_evaluate_msm_func(
start_date=fall_start_date,
end_date=fall_end_date,
debug=debug,
initial_states_path=initial_states_path,
)
res_fall["share_known_cases"].to_pickle(share_known_path)

if mode in ["spring", "combined"]:
if mode == "spring":
res_spring = _build_and_evaluate_msm_func_one_season(
params=params,
seed=seed + 84587,
Expand All @@ -91,7 +95,20 @@ def _build_and_evaluate_msm_func(
end_date=spring_end_date,
debug=debug,
group_share_known_case_path=share_known_path,
initial_states_path=initial_states_path,
)
elif mode == "combined":
res_spring = _build_and_evaluate_msm_func_one_season(
params=params,
seed=seed + 84587,
prefix=prefix,
start_date=spring_start_date,
end_date=spring_end_date,
debug=debug,
group_share_known_case_path=share_known_path,
initial_states_path=None,
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if None is correct here. @janosg How do you make sure here that the simulation is continued?

)

if mode == "fall":
res = res_fall
elif mode == "spring":
Expand Down Expand Up @@ -140,6 +157,7 @@ def _build_and_evaluate_msm_func_one_season(
start_date,
end_date,
debug,
initial_states_path,
group_share_known_case_path=None,
):
"""Build and evaluate a msm criterion function.
Expand All @@ -155,6 +173,7 @@ def _build_and_evaluate_msm_func_one_season(
group_share_known_case_path=group_share_known_case_path,
debug=debug,
return_last_states=False,
initial_states_path=initial_states_path,
)
params_hash = hash_array(params["value"].to_numpy())
path = BLD / "exploration" / f"{prefix}_{params_hash}_{os.getpid()}"
Expand Down
14 changes: 12 additions & 2 deletions tests/test_gridsearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,29 +6,39 @@
from src.estimation.gridsearch import run_2d_gridsearch


def func1(params, seed, initial_states_path): # noqa: U100
return {"value": params["value"] @ params["value"]}


def func2(params, seed, initial_states_path): # noqa: U100
return {"value": (params.loc[0, "value"] - 0.1) ** 2}


def test_2d_gridsearch():
_, grid, best_index, _ = run_2d_gridsearch(
func=lambda params, seed: {"value": params["value"] @ params["value"]},
func=func1,
params=pd.DataFrame([0, 0], columns=["value"]),
loc1=[0],
gridspec1=(-1, 1, 21),
loc2=[1],
gridspec2=(-3, 3, 21),
n_seeds=1,
n_cores=1,
initial_states_path=None,
)

assert_array_almost_equal(grid[best_index], np.zeros(2))


def test_1d_gridsearch():
_, grid, best_index, _ = run_1d_gridsearch(
func=lambda params, seed: {"value": (params.loc[0, "value"] - 0.1) ** 2},
func=func2,
params=pd.DataFrame([0], columns=["value"]),
loc=[0],
gridspec=(-1, 1, 21),
n_seeds=1,
n_cores=1,
initial_states_path=None,
)

assert np.allclose(grid[best_index], 0.1)