Skip to content

Commit b37ad1f

Browse files
committed
override population
1 parent 5bfd950 commit b37ad1f

File tree

2 files changed

+68
-40
lines changed

2 files changed

+68
-40
lines changed

experiments/phase1/reports/calibration.qmd

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,12 @@ import matplotlib.pyplot as plt
1919
import numpy as np
2020
import polars as pl
2121
import json
22-
from calibrationtools import default_particle_reader
22+
from calibrationtools import ParticleReader
2323
import tempfile
2424
import polars as pl
2525
import os
2626
from ixa_epi_covid import CovidModel
27+
from mrp.api import apply_dict_overrides
2728
2829
2930
os.chdir('../../../')
@@ -129,6 +130,10 @@ with open(default_params_file, "rb") as fp:
129130
default_params = json.load(fp)
130131
131132
default_params['epimodel.GlobalParams']['max_time'] = 200
133+
ixa_overrides = {
134+
"synth_population_file": "/mnt/S_CFA_Predict/team-CMEI/synthetic_populations/cbsa_all_work_school_household_2020-04-24/cbsa_all_work_school_household/IN/Bloomington IN.csv"
135+
}
136+
default_params = apply_dict_overrides(default_params, {'epimodel.GlobalParams': ixa_overrides})
132137
133138
mrp_defaults = {
134139
'ixa_inputs': default_params,
@@ -137,7 +142,11 @@ mrp_defaults = {
137142
"output_dir": "./experiments/phase1/calibration/output",
138143
"force_overwrite": True,
139144
},
140-
"importation_inputs": {"state": "Indiana", "year": 2020},
145+
"importation_inputs": {
146+
"state": "Indiana",
147+
"year": 2020,
148+
"symptomatic_reporting_prob": 0.5
149+
},
141150
}
142151
143152
uniq_id = 0
@@ -146,12 +155,9 @@ importation_curves = []
146155
prevalence_data = []
147156
148157
with tempfile.TemporaryDirectory() as tmpdir:
158+
reader = ParticleReader(results.priors.params, mrp_defaults)
149159
for p in particles:
150-
model_inputs = default_particle_reader(
151-
p,
152-
default_params=mrp_defaults,
153-
parameter_headers=["ixa_inputs", "epimodel.GlobalParams"],
154-
)
160+
model_inputs = reader.read_particle(p)
155161
156162
model_inputs["config_inputs"]["output_dir"] = str(
157163
Path(tmpdir, f"{uniq_id}")

scripts/phase_1_calibration.py

Lines changed: 55 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -12,26 +12,59 @@
1212
MultivariateNormalKernel,
1313
SeedKernel,
1414
)
15+
from mrp.api import apply_dict_overrides
1516

1617
from ixa_epi_covid import CovidModel
1718

18-
with open(Path("experiments", "phase1", "input", "priors.json"), "r") as f:
19-
priors = json.load(f)
19+
# Run-specific parameters declaration ------------------------------------------------------
20+
# Default model and parameters
21+
exe_file = Path("target", "release", "ixa-epi-covid")
22+
output_dir = Path("experiments", "phase1", "calibration", "output")
23+
default_ixa_params_file = Path(
24+
"experiments", "phase1", "input", "default_params.json"
25+
)
26+
ixa_overrides = {
27+
"synth_population_file": "/mnt/S_CFA_Predict/team-CMEI/synthetic_populations/cbsa_all_work_school_household_2020-04-24/cbsa_all_work_school_household/IN/Bloomington IN.csv"
28+
}
29+
force_overwrite = True
2030

21-
with open(
22-
Path("experiments", "phase1", "input", "default_params.json"), "r"
23-
) as f:
31+
# State importation model declaration parameters
32+
state = "Indiana"
33+
year = 2020
34+
35+
# Calibration inputs
36+
priors_file = Path("experiments", "phase1", "input", "priors.json")
37+
tolerance_values = [30.0, 20.0, 10.0, 5.0] # , 2.0, 0.01]
38+
generation_particle_count = 500
39+
target_data = 75
40+
41+
42+
# Output processing function for calibration
43+
def outputs_to_distance(model_output: pl.DataFrame, target_data: int):
44+
first_death_observed = model_output.filter(
45+
(pl.col("event") == "Dead") & (pl.col("count") > 0)
46+
).filter(pl.col("t_upper") == pl.min("t_upper"))
47+
if first_death_observed.height > 0:
48+
return abs(target_data - first_death_observed.item(0, "t_upper"))
49+
else:
50+
return 1000
51+
52+
53+
# Load environment files, defaults, and setup configurations ---------------------
54+
with open(default_ixa_params_file, "r") as f:
2455
default_params = json.load(f)
2556

57+
58+
default_params = apply_dict_overrides(
59+
default_params, {"epimodel.GlobalParams": ixa_overrides}
60+
)
61+
2662
mrp_defaults = {
2763
"ixa_inputs": default_params,
2864
"config_inputs": {
29-
"exe_file": "./target/release/ixa-epi-covid",
30-
"output_dir": "./experiments/phase1/calibration/output",
31-
"force_overwrite": True,
32-
# "ixa_overrides": {
33-
# "synth_population_file": Path(os.path.expanduser(os.getenv("SYNTH_POPULATION_DIR"))) / "in.csv"
34-
# }
65+
"exe_file": str(exe_file),
66+
"output_dir": str(output_dir),
67+
"force_overwrite": force_overwrite,
3568
},
3669
"importation_inputs": {
3770
"state": "Indiana",
@@ -49,45 +82,34 @@
4982

5083
output_dir.mkdir(parents=True, exist_ok=False)
5184

52-
P = priors
85+
# Create the model and sampler objects ------------------------------------------------
86+
with open(priors_file, "r") as f:
87+
priors = json.load(f)
88+
89+
P: dict[dict, dict] = priors
5390
K = IndependentKernels(
5491
[
55-
MultivariateNormalKernel(
56-
[p for p in P["priors"].keys()],
57-
),
92+
MultivariateNormalKernel(list(P["priors"].keys())),
5893
SeedKernel("seed"),
5994
]
6095
)
6196

6297
model = CovidModel()
6398

64-
65-
def outputs_to_distance(model_output: pl.DataFrame, target_data: int):
66-
first_death_observed = model_output.filter(
67-
(pl.col("event") == "Dead") & (pl.col("count") > 0)
68-
).filter(pl.col("t_upper") == pl.min("t_upper"))
69-
if first_death_observed.height > 0:
70-
return abs(target_data - first_death_observed.item(0, "t_upper"))
71-
else:
72-
return 1000
73-
74-
7599
sampler = ABCSampler(
76-
generation_particle_count=500,
77-
tolerance_values=[30.0, 20.0, 10.0, 5.0],
100+
generation_particle_count=generation_particle_count,
101+
tolerance_values=tolerance_values,
78102
priors=P,
79103
perturbation_kernel=K,
80104
variance_adapter=AdaptMultivariateNormalVariance(),
81105
outputs_to_distance=outputs_to_distance,
82-
target_data=75,
106+
target_data=target_data,
83107
model_runner=model,
84108
seed=123,
85109
)
86110

87-
results = sampler.run(
88-
default_params=mrp_defaults,
89-
parameter_headers=["ixa_inputs", "epimodel.GlobalParams"],
90-
)
111+
# Execute the sampler ----------------------------------------------------------------------
112+
results = sampler.run(default_params=mrp_defaults)
91113

92114
print(results)
93115

0 commit comments

Comments
 (0)