Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
269 changes: 269 additions & 0 deletions pycode/examples/simulation/ode_seir_contact_matrix_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,269 @@
#############################################################################
# Copyright (C) 2020-2025 MEmilio
#
# Authors: Henrik Zunker
#
# Contact: Martin J. Kuehn <[email protected]>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#############################################################################
"""
Example: For a given country, this example loads the age-structured contact
matrix, resolves the total population, builds a simple age-resolved ODE SEIR
model, and runs a simulation.
"""

import io

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import requests

from memilio.epidata.getContactData import (get_available_countries,
load_contact_matrix)
from memilio.simulation import AgeGroup, ContactMatrix, Damping
from memilio.simulation.oseir import InfectionState as State
from memilio.simulation.oseir import (Model, interpolate_simulation_result,
simulate)

POPULATION_URL = (
"https://raw.githubusercontent.com/kieshaprem/synthetic-contact-matrices/"
"master/generate_synthetic_matrices/input/pop/popage_total2020.csv"
)


def get_population_by_age(country_name: str):
"""
Loads population data for a specific country from a POPULATION_URL.
Returns the population in 5-year steps (Unit: number of people).

Source: https://github.com/kieshaprem/synthetic-contact-matrices
Note: Data is originally in thousands, converted here to absolute numbers.
"""

try:
# Download without saving to disk
response = requests.get(POPULATION_URL, timeout=10)
response.raise_for_status()

# Read into Pandas
df = pd.read_csv(io.StringIO(response.text))

# Clean column names
df.columns = df.columns.str.strip()

# Filter by country (case-insensitive)
country_col = "Region, subregion, country or area *"
row = df[df[country_col].str.lower() == country_name.lower()]

if row.empty:
raise ValueError(f"Country '{country_name}' not found in data.")

# Extract relevant columns (age0, age5, ..., age100)
age_cols = [c for c in df.columns if c.startswith('age')]

# Extract data
pop_data = row.iloc[0][age_cols].astype(float)

# Convert from thousands to absolute numbers
pop_data = pop_data * 1000

# Contact matrices end at 75+.
# The CSV goes up to 100. So, we sum everything from 75 upwards.

# Columns up to 70 (0-4, ..., 70-74)
cols_up_to_70 = [f'age{i}' for i in range(0, 75, 5)]

# Columns from 75 (75-79, ..., 100+)
cols_75_plus = [
f'age{i}' for i in range(75, 105, 5)
if f'age{i}' in pop_data.index]

# Define labels for the output dict
labels = [
"0-4", "5-9", "10-14", "15-19", "20-24", "25-29", "30-34", "35-39",
"40-44", "45-49", "50-54", "55-59", "60-64", "65-69", "70-74",
"75+"]

final_pop = {}
for i, col in enumerate(cols_up_to_70):
final_pop[labels[i]] = pop_data[col]
sum_75_plus = pop_data[cols_75_plus].sum()
final_pop["75+"] = sum_75_plus

return final_pop

except Exception as e:
raise RuntimeError(f"Error loading population data: {e}")


def build_country_seir_model(
contact_matrix: np.ndarray,
population_by_age: list,
transmission_probability: float = 0.06,
exposed_share: float = 1e-5,
infected_share: float = 5e-6):
"""
Build a simple age-resolved ODE SEIR model.
"""
contact_matrix = np.asarray(contact_matrix, dtype=float)
num_groups = contact_matrix.shape[0]

model = Model(num_groups)
contacts = ContactMatrix(contact_matrix)
contacts.minimum = np.zeros_like(contact_matrix)
model.parameters.ContactPatterns.cont_freq_mat[0] = contacts

# 60% contact reduction at t=30
model.parameters.ContactPatterns.cont_freq_mat.add_damping(Damping(
coeffs=np.ones((num_groups, num_groups)) * 0.6, t=30.0, level=0, type=0))

# Use actual population distribution
# transform dict to numpy array
group_pop = np.array(list(population_by_age.values())).flatten()

if len(group_pop) != num_groups:
# If dimensions mismatch (e.g. contact matrix has different age groups),
# we might need to adjust. For now, we assume they match (16 groups for Prem 75+).
# If not, we fallback to uniform distribution of total sum.
print(
f"Warning: Population groups ({len(group_pop)}) do not match contact matrix groups ({num_groups}). Using uniform distribution.")
total_pop = np.sum(group_pop)
group_weights = np.full(num_groups, 1.0 / num_groups)
group_pop = group_weights * total_pop

exposed_init = np.maximum(1.0, group_pop * exposed_share)
infected_init = np.maximum(1.0, group_pop * infected_share)

# Set parameters and initial conditions equal for all age groups.
for idx in range(num_groups):
age_group = AgeGroup(idx)
model.parameters.TimeExposed[age_group] = 5.2
model.parameters.TimeInfected[age_group] = 6.0
model.parameters.TransmissionProbabilityOnContact[age_group] = (
transmission_probability)

model.populations[age_group, State.Exposed] = exposed_init[idx]
model.populations[age_group, State.Infected] = infected_init[idx]
model.populations[age_group, State.Recovered] = 0.0
model.populations.set_difference_from_group_total_AgeGroup(
(age_group, State.Susceptible), group_pop[idx])

model.check_constraints()
return model


def simulate_country_seir(
country: str,
days: float = 120.0,
dt: float = 0.25,
transmission_probability: float = 0.15,
exposed_share: float = 1e-5,
infected_share: float = 5e-6,
interpolate: bool = True):
"""
Load contact matrix, population data, build the ODE SEIR model, and run
a simulation.
"""
# Validate country
available = get_available_countries()
if country not in available:
raise ValueError(
f"Country '{country}' not available. "
f"Use get_available_countries() to see all {len(available)} supported countries.")

contacts = load_contact_matrix(country, reduce_to_rki_groups=False)
population = get_population_by_age(country)
model = build_country_seir_model(
contacts.values,
population_by_age=population,
transmission_probability=transmission_probability,
exposed_share=exposed_share,
infected_share=infected_share)

result = simulate(t0=0.0, tmax=days, dt=dt, model=model)
if interpolate:
result = interpolate_simulation_result(result)

return result


def plot_results(result, country: str):
"""
Plot aggregated SEIR compartments over time.
"""
results_arr = result.as_ndarray()
times = results_arr[0, :]

num_compartments = 4 # S, E, I, R
# get_num_elements includes all compartments and age groups. Divide to get age groups.
num_groups = result.get_num_elements() // num_compartments

susceptible = np.zeros(len(times))
exposed = np.zeros(len(times))
infected = np.zeros(len(times))
recovered = np.zeros(len(times))

for age_group in range(num_groups):
susceptible += results_arr[1 + age_group * 4, :]
exposed += results_arr[2 + age_group * 4, :]
infected += results_arr[3 + age_group * 4, :]
recovered += results_arr[4 + age_group * 4, :]

# Create plot
plt.figure(figsize=(10, 6))
plt.plot(times, susceptible, label='Susceptible', linewidth=2)
plt.plot(times, exposed, label='Exposed', linewidth=2)
plt.plot(times, infected, label='Infected', linewidth=2)
plt.plot(times, recovered, label='Recovered', linewidth=2)

plt.xlabel('Time (days)', fontsize=12)
plt.ylabel('Population', fontsize=12)
plt.title(f'SEIR Model Simulation - {country}',
fontsize=14, fontweight='bold')
plt.legend(fontsize=11)
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()


def run_demo(country: str,
days: float = 120.0,
dt: float = 0.25,
transmission_probability: float = 0.15,
exposed_share: float = 1e-5,
infected_share: float = 5e-6,
plot: bool = True):
"""
Run the SEIR simulation demo for a user defined country and parameters.
"""
result = simulate_country_seir(
country,
days=days,
dt=dt,
transmission_probability=transmission_probability,
exposed_share=exposed_share,
infected_share=infected_share,
)
print(result.get_last_value())

if plot:
plot_results(result, country)

return result


if __name__ == "__main__":
country = "Germany"
run_demo(country=country)
Loading