Calibration routine to first observed death #39
base: pgv5_clinical_status
Changes from 18 commits
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,4 +1,5 @@ | ||
| **/profiling.json | ||
| experiments/**/output/ | ||
|
|
||
| # Data | ||
| *.csv | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,21 @@ | ||
| # Running the calibration routine | ||
| To generate the results analysis report from the `reports/` directory, first calibrate the model using the phase 1 calibration script. Ensure that the uv environment is synced and the Rust binaries have been built: | ||
| ``` | ||
| uv sync --all-packages | ||
| uv run cargo build -r | ||
| uv run python scripts/phase_1_calibration.py | ||
| ``` | ||
|
|
||
| Then, to render the analysis report, ensure that `tinytex` is installed with | ||
|
|
||
| ``` | ||
| quarto install tinytex | ||
| ``` | ||
|
|
||
| and then render the document using | ||
|
|
||
| ``` | ||
| uv run quarto render experiments/phase1/reports/calibration.qmd | ||
| ``` | ||
|
|
||
| The resulting file should be a PDF in the reports directory. | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,55 @@ | ||
| { | ||
| "epimodel.GlobalParams": { | ||
| "seed": 1234, | ||
| "max_time": 100.0, | ||
| "synth_population_file": "input/people_test.csv", | ||
| "symptomatic_reporting_prob": 0.5, | ||
|
||
| "initial_prevalence": 0.0, | ||
| "imported_cases_timeseries": { | ||
| "include": true, | ||
| "filename": "./experiments/phase1/calibration/output/importation_timeseries.csv" | ||
| }, | ||
| "infectiousness_rate_fn": {"Constant": { | ||
| "rate": 1.0, | ||
| "duration": 5.0 | ||
| } | ||
| }, | ||
| "probability_mild_given_infect": 0.7, | ||
| "infect_to_mild_mu": 0.1, | ||
| "infect_to_mild_sigma": 0.0, | ||
| "probability_severe_given_mild": 0.2, | ||
| "mild_to_severe_mu": 0.1, | ||
| "mild_to_severe_sigma": 0.1, | ||
| "mild_to_resolved_mu": 0.1, | ||
| "mild_to_resolved_sigma": 0.1, | ||
| "probability_critical_given_severe": 0.2, | ||
| "severe_to_critical_mu": 0.1, | ||
| "severe_to_critical_sigma": 0.1, | ||
| "severe_to_resolved_mu": 0.1, | ||
| "severe_to_resolved_sigma": 0.1, | ||
| "probability_dead_given_critical": 0.2, | ||
| "critical_to_dead_mu": 0.1, | ||
| "critical_to_dead_sigma": 0.1, | ||
| "critical_to_resolved_mu": 0.1, | ||
| "critical_to_resolved_sigma": 0.1, | ||
| "settings_properties": {"Home": {"alpha": 0.0}, | ||
| "Workplace": {"alpha": 0.0}, | ||
| "School": {"alpha": 0.0}, | ||
| "CensusTract": {"alpha": 0.0}}, | ||
| "itinerary_ratios": {"Home": 0.25, "Workplace": 0.25, "School": 0.25, "CensusTract": 0.25}, | ||
| "prevalence_report": { | ||
| "write": true, | ||
| "filename": "person_property_count.csv", | ||
| "period": 1.0 | ||
| }, | ||
| "incidence_report": { | ||
| "write": true, | ||
| "filename": "incidence_report.csv", | ||
| "period": 1.0 | ||
| }, | ||
| "transmission_report": { | ||
| "write": false, | ||
| "filename": "transmission_report.csv" | ||
| } | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,32 @@ | ||
| { | ||
| "priors": { | ||
| "symptomatic_reporting_prob": { | ||
| "distribution": "beta", | ||
| "parameters": { | ||
| "alpha": 5, | ||
| "beta": 5 | ||
| } | ||
| }, | ||
| "settings_properties.Home.alpha": { | ||
| "distribution": "uniform", | ||
| "parameters": { | ||
| "min": 0.0, | ||
| "max": 1.0 | ||
| } | ||
| }, | ||
| "probability_mild_given_infect": { | ||
| "distribution": "beta", | ||
| "parameters": { | ||
| "alpha": 7, | ||
| "beta": 3 | ||
| } | ||
| }, | ||
| "infectiousness_rate_fn.Constant.rate": { | ||
| "distribution": "uniform", | ||
| "parameters": { | ||
| "min": 0.1, | ||
| "max": 2.0 | ||
| } | ||
| } | ||
| } | ||
| } |
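As a rough illustration of what the priors file above encodes, the sketch below draws one parameter set from two of the listed priors using only the Python standard library. The helpers `sample_prior` and `sample_particle` are hypothetical names introduced here; they are not part of `calibrationtools`, whose own prior handling may differ.

```python
import json
import random

# Two entries copied from the priors file above.
PRIORS_JSON = """
{
  "priors": {
    "symptomatic_reporting_prob": {
      "distribution": "beta",
      "parameters": {"alpha": 5, "beta": 5}
    },
    "infectiousness_rate_fn.Constant.rate": {
      "distribution": "uniform",
      "parameters": {"min": 0.1, "max": 2.0}
    }
  }
}
"""


def sample_prior(spec, rng):
    """Draw one value from a prior spec (hypothetical helper, not calibrationtools)."""
    params = spec["parameters"]
    if spec["distribution"] == "beta":
        return rng.betavariate(params["alpha"], params["beta"])
    if spec["distribution"] == "uniform":
        return rng.uniform(params["min"], params["max"])
    raise ValueError(f"Unsupported distribution: {spec['distribution']}")


def sample_particle(priors, rng):
    """Draw one value per parameter, keyed by the dotted parameter name."""
    return {name: sample_prior(spec, rng) for name, spec in priors.items()}


rng = random.Random(1234)
priors = json.loads(PRIORS_JSON)["priors"]
particle = sample_particle(priors, rng)
print(particle)
```

The dotted keys (for example `infectiousness_rate_fn.Constant.rate`) are kept as plain strings here; mapping them onto the nested parameter file is left to the calibration tooling.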
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,258 @@ | ||
| --- | ||
| title: "Phase I calibration" | ||
| date: "2026-03-09" | ||
| format: pdf | ||
| --- | ||
|
|
||
| # Overview | ||
| In this phase, we aim to calibrate the model to the timing of the first reported death due to COVID-19 in the state of Indiana. The first reported death in Indiana occurred on March 16th, 2020, 10 days after the first confirmed case. | ||
|
|
||
| # Results | ||
| ```{python} | ||
| #| echo: false | ||
| import pickle | ||
| from calibrationtools.calibration_results import CalibrationResults, Particle | ||
| from pathlib import Path | ||
| import os | ||
| import seaborn as sns | ||
| import matplotlib.pyplot as plt | ||
| import numpy as np | ||
| import polars as pl | ||
| import json | ||
| from calibrationtools import default_particle_reader | ||
| import tempfile | ||
| from ixa_epi_covid import CovidModel | ||
|
|
||
|
|
||
| os.chdir('../../../') | ||
|
|
||
| wd = Path("experiments", "phase1") | ||
|
|
||
| # Load results object from the calibration directory | ||
| with open(wd / "calibration" / "output" / "results.pkl", "rb") as fp: | ||
| results: CalibrationResults = pickle.load(fp) | ||
| ``` | ||
|
|
||
|
|
||
| Quantiles for each fitted parameter | ||
| ```{python} | ||
| #| echo: false | ||
| diagnostics = results.get_diagnostics() | ||
| print( | ||
| json.dumps( | ||
| { | ||
| k1: {k2: np.format_float_positional(v2, precision=3) for k2, v2 in v1.items()} | ||
| for k1, v1 in diagnostics["quantiles"].items() | ||
| }, | ||
| indent=4, | ||
| ) | ||
| ) | ||
| ``` | ||
|
|
||
| Histograms of posterior samples compared to the probability density of each prior. To generate the histogram, we sample $n$ particles from the posterior population, where $n$ is the effective sample size of the population, in order to reflect the weight distribution of the posterior. For each plot, the prior is plotted as a blue line and the histogram of the posterior samples is plotted in orange with a kernel density estimator overlay. | ||
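To make the resampling step concrete, here is a minimal sketch of drawing $n$ particles in proportion to their weights, with $n$ set to the effective sample size. It assumes the common Kish definition, $\mathrm{ESS} = 1 / \sum_i w_i^2$ for normalized weights; how `results.ess` is actually computed is not shown in this PR, so treat that as an assumption. The particles and weights are made-up values.

```python
import random


def effective_sample_size(weights):
    """Kish effective sample size: 1 / sum of squared normalized weights."""
    total = sum(weights)
    normalized = [w / total for w in weights]
    return 1.0 / sum(w * w for w in normalized)


def resample(particles, weights, n, rng):
    """Draw n particles with replacement, proportional to their weights."""
    return rng.choices(particles, weights=weights, k=n)


# Hypothetical weighted posterior population of four particles.
particles = [{"rate": r} for r in (0.4, 0.8, 1.2, 1.6)]
weights = [0.1, 0.4, 0.4, 0.1]

ess = effective_sample_size(weights)
n = int(ess)
posterior_samples = resample(particles, weights, n, random.Random(0))
print(ess, n, posterior_samples)
```

With equal weights the effective sample size equals the population size; the more uneven the weights, the fewer samples are drawn.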
|
||
|
|
||
| ```{python} | ||
| #| echo: false | ||
| posterior_samples = results.sample_posterior_particles(n=int(results.ess)) | ||
|
|
||
| for param in results.fitted_params: | ||
|
Collaborator
These types of plots seem like they would be of interest every time the model is calibrated. Is it possible to implement a method on
Collaborator
Author
We haven't landed |
||
| vals = [p[param] for p in posterior_samples] | ||
| min_val = min(vals) | ||
| max_val = max(vals) | ||
|
|
||
| sns.histplot(x=vals, stat="density", kde=True, color='orange', edgecolor='black') | ||
| eval_points = np.arange( | ||
| min_val - np.var(vals), max_val + np.var(vals), 0.01 | ||
| ) | ||
| param_prior = None | ||
| for prior in results.priors.priors: | ||
| if prior.param == param: | ||
| param_prior = prior | ||
| break | ||
| if param_prior is None: | ||
| raise ValueError(f"Could not find prior {param}") | ||
|
|
||
| density_vals = [ | ||
| param_prior.probability_density(Particle({param: v})) | ||
| for v in eval_points | ||
| ] | ||
|
|
||
| sns.lineplot( | ||
| data=pl.DataFrame({param: list(eval_points), "density": density_vals}), | ||
| x=param, | ||
| y="density", | ||
| ) | ||
| plt.title("Posterior versus prior distribution") | ||
| plt.xlabel(" ".join(param.split("."))) | ||
| plt.ylabel("Density") | ||
| plt.tight_layout() | ||
| plt.show() | ||
|
|
||
| ``` | ||
| ```{python} | ||
| #| echo: false | ||
| ## Obtain the importation time series and death incidence data frames for a random sample from the posterior | ||
| # Re-generate a random sample of parameter sets from the posterior | ||
| particle_count = 100 | ||
| particles = results.sample_posterior_particles(n=particle_count) | ||
|
Collaborator
A similar point to the one above. This seems like something you will do with every calibrated model and is a substantial amount of code to produce for every model report. Is this something that can be moved into a
Collaborator
Author
Something like re-run particles is valuable and could reasonably be included among the CalibrationResults methods. I was also over-complicating what needed to be done: the cleaner approach is to modify what covid_model returns. It can now grab whichever file reports you want in a dictionary, and the for loop is easier to read. |
||
| default_params_file = wd / 'input' / 'default_params.json' | ||
|
|
||
| with open(default_params_file, "rb") as fp: | ||
| default_params = json.load(fp) | ||
|
|
||
| default_params['epimodel.GlobalParams']['max_time'] = 200 | ||
|
|
||
| mrp_defaults = { | ||
| 'ixa_inputs': default_params, | ||
| "config_inputs": { | ||
| "exe_file": "./target/release/ixa-epi-covid", | ||
| "output_dir": "./experiments/phase1/calibration/output", | ||
| "force_overwrite": True, | ||
| }, | ||
| "importation_inputs": {"state": "Indiana", "year": 2020}, | ||
| } | ||
|
|
||
| uniq_id = 0 | ||
| model = CovidModel() | ||
| importation_curves = [] | ||
| prevalence_data = [] | ||
|
|
||
| with tempfile.TemporaryDirectory() as tmpdir: | ||
| for p in particles: | ||
| model_inputs = default_particle_reader( | ||
| p, | ||
| default_params=mrp_defaults, | ||
| parameter_headers=["ixa_inputs", "epimodel.GlobalParams"], | ||
| ) | ||
|
|
||
| model_inputs["config_inputs"]["output_dir"] = str( | ||
| Path(tmpdir, f"{uniq_id}") | ||
| ) | ||
| os.makedirs(model_inputs["config_inputs"]["output_dir"], exist_ok=True) | ||
| importation_path = Path( | ||
| tmpdir, f"{uniq_id}", "importation_timeseries.csv" | ||
| ) | ||
|
|
||
| model_inputs["ixa_inputs"]["epimodel.GlobalParams"][ | ||
| "imported_cases_timeseries" | ||
| ]["filename"] = str(importation_path) | ||
| model.simulate(model_inputs) | ||
| prevalence_data.append( | ||
| pl.read_csv( | ||
| Path( | ||
| tmpdir, | ||
| f"{uniq_id}", | ||
| model_inputs["ixa_inputs"]["epimodel.GlobalParams"][ | ||
| "prevalence_report" | ||
| ]["filename"], | ||
| ) | ||
| ).with_columns(pl.lit(uniq_id).alias("id")) | ||
| ) | ||
| importation_curves.append( | ||
| pl.read_csv(importation_path).with_columns( | ||
| pl.lit(uniq_id).alias("id") | ||
| ) | ||
| ) | ||
| uniq_id += 1 | ||
|
|
||
| importations = pl.concat(importation_curves) | ||
| all_prevalence_data = pl.concat(prevalence_data) | ||
| deaths = ( | ||
| all_prevalence_data | ||
| .filter(pl.col("symptom_status") == "Dead") | ||
| .group_by("t", "id") | ||
| .agg(pl.sum("count")) | ||
| ) | ||
| ``` | ||
|
|
||
| For a final SMC step with a tolerance threshold above zero, we will observe some variance in the date of the first reported death. Here we show a histogram of the timing of the first death across the sampled posterior simulations. The vertical dashed line indicates the observed timing of the first death in Indiana (March 16, 2020, or 75 days after the start of the simulation on January 1, 2020). Although the distribution is centered around the observed timing, there is more weight on earlier first reported deaths, indicating that imported infections occur faster in the simulated model than they occurred in real life. This may be due to the proportional sampling of Indiana importations from the domestic-level imported infections, when in fact they may have been less probable due to mobility flow patterns. | ||
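The day offsets used in the plots (65 and 75) can be checked with a short date calculation. This sketch assumes simulation time `t = 0` corresponds to January 1, 2020, as stated in the report; the helper name `days_since_start` is introduced here for illustration only.

```python
from datetime import date

SIM_START = date(2020, 1, 1)     # assumed to correspond to t = 0
FIRST_CASE = date(2020, 3, 6)    # first confirmed case in Indiana
FIRST_DEATH = date(2020, 3, 16)  # first reported death in Indiana


def days_since_start(d, start=SIM_START):
    """Number of whole days between the simulation start and date d."""
    return (d - start).days


first_case_day = days_since_start(FIRST_CASE)
first_death_day = days_since_start(FIRST_DEATH)
print(first_case_day, first_death_day, first_death_day - first_case_day)
```

2020 is a leap year, so February contributes 29 days to both offsets.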
|
||
|
|
||
| ```{python} | ||
| #| echo: false | ||
|
|
||
| first_deaths = deaths.filter(pl.col("count") > 0).group_by("id").agg(pl.min("t")) | ||
| sns.histplot( | ||
| data=first_deaths, | ||
| x="t", | ||
| ) | ||
| plt.axvline(x=75, color="black", linestyle="--", label="Observed (March 16, 2020)") | ||
| plt.title(f"Distribution of first death times ({particle_count} posterior samples)") | ||
| plt.xlabel("Time of first death (days since simulation start)") | ||
| plt.ylabel("Number of simulations") | ||
| plt.legend() | ||
| plt.tight_layout() | ||
| plt.show() | ||
|
|
||
| ``` | ||
|
|
||
| We can show the posterior variance in the imported infections time series by overlaying samples from the simulated particles. Each semi-transparent point in the plot below is a nonzero daily importation count from a single particle, and the black markers show the mean across particles at each time point. | ||
| ```{python} | ||
| #| echo: false | ||
| sns.scatterplot( | ||
| data=importations.filter(pl.col('imported_infections') > 0), | ||
| x="time", | ||
| y="imported_infections", | ||
| alpha=0.05, | ||
| ) | ||
| sns.pointplot( | ||
| data=importations, | ||
| x="time", | ||
| y="imported_infections", | ||
| linestyle="none", | ||
| errorbar=None, | ||
| marker="_", | ||
| native_scale=True, | ||
| color='black' | ||
| ) | ||
| plt.title(f"Imported infections over time ({particle_count} posterior samples)") | ||
| plt.xlabel("Time (days since simulation start)") | ||
| plt.ylabel("Number of imported infections (daily)") | ||
| plt.tight_layout() | ||
|
Collaborator
This plot needs a legend |
||
| plt.show() | ||
| ``` | ||
| We can project the cumulative number of deaths observed in the model up to the evaluation timepoint. Because the last SMC step has a distance tolerance threshold greater than zero, we see some simulations accrue deaths before the first reported death in the data. | ||
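The effect of a nonzero tolerance can be sketched with a simple acceptance rule. The distance function used by the calibration is not shown in this PR; the example below assumes an absolute difference in days between the simulated and observed first death, and the simulated days and tolerance are made-up values.

```python
OBSERVED_FIRST_DEATH_DAY = 75  # March 16, 2020, in days since January 1, 2020


def distance(simulated_day, observed_day=OBSERVED_FIRST_DEATH_DAY):
    """Assumed distance: absolute difference in days (not confirmed by the source)."""
    return abs(simulated_day - observed_day)


def accept(simulated_day, tolerance):
    """A particle is kept when its distance is within the tolerance."""
    return distance(simulated_day) <= tolerance


# Hypothetical first-death days from five simulations and a hypothetical tolerance.
simulated_first_deaths = [68, 72, 75, 79, 90]
tolerance = 5

accepted = [d for d in simulated_first_deaths if accept(d, tolerance)]
early = [d for d in accepted if d < OBSERVED_FIRST_DEATH_DAY]
print(accepted, early)
```

Any tolerance above zero admits particles whose first death falls before day 75, which is why some accepted trajectories accrue deaths ahead of the observed date.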
| ```{python} | ||
| #| echo: false | ||
|
|
||
| sns.lineplot( | ||
| data=deaths.filter(pl.col('t') <= 75), | ||
| x="t", | ||
| y="count", | ||
| ) | ||
| plt.title(f"Total observed deaths over time ({particle_count} posterior samples)") | ||
| plt.xlabel("Time (days since simulation start)") | ||
| plt.ylabel("Number of deaths") | ||
| plt.axvline(x=75, color="black", linestyle="--", label="First death reported (March 16, 2020)") | ||
| plt.legend() | ||
| plt.tight_layout() | ||
| plt.show() | ||
| ``` | ||
|
|
||
| Finally, we can plot the same trajectories for the number of infections observed in the model over time, which includes both imported and locally acquired infections. This figure shows the issue with calibrating the model with a small population size (1,000 people). The susceptible pool is rapidly depleted by local transmission, and a high proportion of the population has already recovered by the evaluation timepoint, when the model results are compared to the first reported death in the data. | ||
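The depletion effect can be illustrated with a deliberately simplified stand-in: a deterministic, discrete-time SIR model, not the agent-based model used in this PR. The transmission rate, infectious duration, and single seed infection below are hypothetical values chosen for illustration, not calibrated estimates.

```python
def simulate_sir(population, beta, infectious_days, n_days, seed_infections=1.0):
    """Daily-step deterministic SIR; returns a list of (S, I, R) tuples per day."""
    gamma = 1.0 / infectious_days
    s = population - seed_infections
    i = seed_infections
    r = 0.0
    trajectory = [(s, i, r)]
    for _ in range(n_days):
        new_infections = beta * s * i / population
        new_recoveries = gamma * i
        s -= new_infections
        i += new_infections - new_recoveries
        r += new_recoveries
        trajectory.append((s, i, r))
    return trajectory


# Hypothetical parameters: 1,000 people, rate 1.0 per day, 5-day infectious period.
POPULATION = 1000
trajectory = simulate_sir(POPULATION, beta=1.0, infectious_days=5.0, n_days=75)

s_final, i_final, r_final = trajectory[75]
susceptible_fraction = s_final / POPULATION
print(f"Susceptible fraction at day 75: {susceptible_fraction:.3f}")
```

Under these assumed parameters most of the population has left the susceptible compartment well before day 75, which mirrors the qualitative issue described above for a 1,000-person population.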
|
|
||
| ```{python} | ||
| #| echo: false | ||
|
Collaborator
Can this plot be rotated? It's hard to examine at its current size. |
||
| sir_data = all_prevalence_data.group_by( | ||
| 't', 'infection_status', 'id' | ||
| ).agg( | ||
| pl.sum('count') | ||
| ) | ||
|
|
||
| sns.lineplot( | ||
| data = sir_data, | ||
| x='t', | ||
| y='count', | ||
| hue='infection_status', | ||
| units='id', | ||
| estimator=None, | ||
| alpha=0.05 | ||
| ) | ||
| plt.title(f"Individuals by infection status over time ({particle_count} posterior samples)") | ||
| plt.xlabel("Time (days since simulation start)") | ||
| plt.axvline(x=65, color="black", linestyle="--", label="First case reported (March 6, 2020)") | ||
| plt.axvline(x=75, color="red", linestyle="--", label="First death reported (March 16, 2020)") | ||
| plt.ylabel("Number of people") | ||
| plt.tight_layout() | ||
| plt.show() | ||
| ``` | ||