Skip to content

Commit 4ee8307

Browse files
committed
update model output return
1 parent fe6ec72 commit 4ee8307

File tree

4 files changed

+56
-60
lines changed

4 files changed

+56
-60
lines changed

experiments/phase1/reports/calibration.qmd

Lines changed: 19 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -52,12 +52,12 @@ print(
5252
)
5353
```
5454

55-
Histograms of posterior samples compared to the probability density of each prior. To generate the histogram, we sample $n$ particles from the posterior population, where $n$ is the effective sample size of the population, in order to reflect the weight distribution of the posterior. For each plot, the prior is plotted as a blue line and the histogram of the posterior samples is plotted in orange with a kernel density estimator overlay.
55+
Histograms of posterior samples compared to the probability density of each prior. For each plot, the prior is plotted as a blue line and the histogram of the posterior samples is plotted in orange with a kernel density estimator overlay. To generate the histogram, we sample $n$ particles from the posterior population, where $n$ is the effective sample size of the population. Re-sampling particles in this manner reflects the posterior weight distribution, the probability mass function over accepted particles determined by the ABC-SMC algorithm perturbation kernel and prior distributions.
5656

5757
```{python}
5858
#| echo: false
5959
#| eval: false
60-
# Preferred API ----------
60+
# Preferred API for development in calibrationtools ----------
6161
for param in results.fitted_params:
6262
# get values and use weights to alter histogram instead of sampling with ESS
6363
values = results.posterior_particles.get_values_of(param)
@@ -120,7 +120,7 @@ for param in results.fitted_params:
120120
```
121121
```{python}
122122
#| echo: false
123-
## Obtaining the importation time series and death incidence data frmaes for a random smaple form the posterior
123+
## Obtaining the importation time series and death incidence data frames for a random sample from the posterior
124124
# Re-generating a random sample of parameter sets from posterior
125125
particle_count = int(min(100, results.ess))
126126
particles = results.sample_posterior_particles(n=particle_count)
@@ -129,18 +129,22 @@ default_params_file = wd / 'input' / 'default_params.json'
129129
with open(default_params_file, "rb") as fp:
130130
default_params = json.load(fp)
131131
132-
default_params['epimodel.GlobalParams']['max_time'] = 200
133132
ixa_overrides = {
134-
"synth_population_file": "/mnt/S_CFA_Predict/team-CMEI/synthetic_populations/cbsa_all_work_school_household_2020-04-24/cbsa_all_work_school_household/IN/Bloomington IN.csv"
133+
"synth_population_file": "/mnt/S_CFA_Predict/team-CMEI/synthetic_populations/cbsa_all_work_school_household_2020-04-24/cbsa_all_work_school_household/IN/Bloomington IN.csv",
134+
"imported_cases_timeseries": {
135+
"filename": "./experiments/phase1/projection/imported_cases_timeseries.csv"
136+
},
137+
"max_time": 200
135138
}
136139
default_params = apply_dict_overrides(default_params, {'epimodel.GlobalParams': ixa_overrides})
137140
138141
mrp_defaults = {
139142
'ixa_inputs': default_params,
140143
"config_inputs": {
141144
"exe_file": "./target/release/ixa-epi-covid",
142-
"output_dir": "./experiments/phase1/calibration/output",
145+
"output_dir": "./experiments/phase1/projection/output",
143146
"force_overwrite": True,
147+
"outputs_to_read": ['prevalence_report', 'imported_cases_timeseries']
144148
},
145149
"importation_inputs": {
146150
"state": "Indiana",
@@ -153,41 +157,16 @@ uniq_id = 0
153157
model = CovidModel()
154158
importation_curves = []
155159
prevalence_data = []
160+
os.makedirs(mrp_defaults["config_inputs"]["output_dir"], exist_ok=True)
161+
162+
reader = ParticleReader(results.priors.params, mrp_defaults)
163+
for p in particles:
164+
model_inputs = reader.read_particle(p)
165+
outputs = model.simulate(model_inputs)
156166
157-
with tempfile.TemporaryDirectory() as tmpdir:
158-
reader = ParticleReader(results.priors.params, mrp_defaults)
159-
for p in particles:
160-
model_inputs = reader.read_particle(p)
161-
162-
model_inputs["config_inputs"]["output_dir"] = str(
163-
Path(tmpdir, f"{uniq_id}")
164-
)
165-
os.makedirs(model_inputs["config_inputs"]["output_dir"], exist_ok=True)
166-
importation_path = Path(
167-
tmpdir, f"{uniq_id}", "importation_timeseries.csv"
168-
)
169-
170-
model_inputs["ixa_inputs"]["epimodel.GlobalParams"][
171-
"imported_cases_timeseries"
172-
]["filename"] = str(importation_path)
173-
model.simulate(model_inputs)
174-
prevalence_data.append(
175-
pl.read_csv(
176-
Path(
177-
tmpdir,
178-
f"{uniq_id}",
179-
model_inputs["ixa_inputs"]["epimodel.GlobalParams"][
180-
"prevalence_report"
181-
]["filename"],
182-
)
183-
).with_columns(pl.lit(uniq_id).alias("id"))
184-
)
185-
importation_curves.append(
186-
pl.read_csv(importation_path).with_columns(
187-
pl.lit(uniq_id).alias("id")
188-
)
189-
)
190-
uniq_id += 1
167+
importation_curves.append(outputs["imported_cases_timeseries"].with_columns(pl.lit(uniq_id).alias("id")))
168+
prevalence_data.append(outputs["prevalence_report"].with_columns(pl.lit(uniq_id).alias("id")))
169+
uniq_id += 1
191170
192171
importations = pl.concat(importation_curves)
193172
all_prevalence_data = pl.concat(prevalence_data)

scripts/phase_1_calibration.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,26 +26,32 @@
2626
)
2727
ixa_overrides = {
2828
"synth_population_file": "/mnt/S_CFA_Predict/team-CMEI/synthetic_populations/cbsa_all_work_school_household_2020-04-24/cbsa_all_work_school_household/IN/Bloomington IN.csv",
29-
"first_death_terminates_run": True
29+
"first_death_terminates_run": True,
3030
}
3131
force_overwrite = False
32+
outputs_to_read = ["incidence_report"]
3233

3334
# State importation model declaration parameters
3435
state = "Indiana"
3536
year = 2020
37+
symptomatic_reporting_prob_default = 0.5
3638

3739
# Calibration inputs
3840
priors_file = Path("experiments", "phase1", "input", "priors.json")
39-
tolerance_values = [2.0, 0.1] # , 2.0, 0.01]
41+
tolerance_values = [2.0, 0.1]
4042
generation_particle_count = 500
4143
target_data = 75
4244

4345

4446
# Output processing function for calibration
45-
def outputs_to_distance(model_output: pl.DataFrame, target_data: int):
46-
first_death_observed = model_output.filter(
47-
(pl.col("event") == "Dead") & (pl.col("count") > 0)
48-
).filter(pl.col("t_upper") == pl.min("t_upper"))
47+
def outputs_to_distance(
48+
model_output: dict[str, pl.DataFrame], target_data: int
49+
):
50+
first_death_observed = (
51+
model_output["incidence_report"]
52+
.filter((pl.col("event") == "Dead") & (pl.col("count") > 0))
53+
.filter(pl.col("t_upper") == pl.min("t_upper"))
54+
)
4955
if first_death_observed.height > 0:
5056
return abs(target_data - first_death_observed.item(0, "t_upper"))
5157
else:
@@ -74,11 +80,12 @@ def outputs_to_distance(model_output: pl.DataFrame, target_data: int):
7480
"exe_file": str(exe_file),
7581
"output_dir": str(output_dir),
7682
"force_overwrite": force_overwrite,
83+
"outputs_to_read": outputs_to_read,
7784
},
7885
"importation_inputs": {
79-
"state": "Indiana",
80-
"year": 2020,
81-
"symptomatic_reporting_prob": 0.5,
86+
"state": state,
87+
"year": year,
88+
"symptomatic_reporting_prob": symptomatic_reporting_prob_default,
8289
},
8390
}
8491

src/ixa_epi_covid/covid_model.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def simulate(model_inputs: dict[str, Any]) -> pl.DataFrame:
8585
str(input_file_path),
8686
"--output",
8787
config_inputs["output_dir"],
88-
"-f",
88+
"--force-overwrite",
8989
"--no-stats",
9090
]
9191

@@ -99,11 +99,21 @@ def simulate(model_inputs: dict[str, Any]) -> pl.DataFrame:
9999
raise e
100100

101101
# Read the model incidence report from the specified location and return as a DataFrame
102-
incidence_report_filename = ixa_inputs["epimodel.GlobalParams"][
103-
"incidence_report"
104-
]["filename"]
105-
incidence_report_path = Path(
106-
config_inputs["output_dir"], incidence_report_filename
107-
)
108-
109-
return pl.read_csv(incidence_report_path)
102+
outputs = {}
103+
for output in config_inputs["outputs_to_read"]:
104+
fp = ixa_inputs["epimodel.GlobalParams"][output]["filename"]
105+
if Path(config_inputs["output_dir"], fp).exists():
106+
outputs.update(
107+
{
108+
output: pl.read_csv(
109+
Path(config_inputs["output_dir"], fp)
110+
)
111+
}
112+
)
113+
elif Path(fp).exists():
114+
outputs.update({output: pl.read_csv(Path(fp))})
115+
else:
116+
raise FileNotFoundError(
117+
f"Expected output file {fp} not found. Looked in {config_inputs['output_dir']}"
118+
)
119+
return outputs

src/model.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
use ixa::{ExecutionPhase, prelude::*};
22

33
use crate::{
4-
infection_importation, infection_propagation_loop, population_loader, reports, settings,
5-
symptom_status_manager, abort_run
4+
abort_run, infection_importation, infection_propagation_loop, population_loader, reports,
5+
settings, symptom_status_manager,
66
};
77

88
pub fn initialize_model(context: &mut Context, seed: u64, max_time: f64) -> Result<(), IxaError> {

0 commit comments

Comments
 (0)