# diff --git a/pycode/examples/simulation/ode_seir_contact_matrix_example.py b/pycode/examples/simulation/ode_seir_contact_matrix_example.py
# new file mode 100644
# index 0000000000..5576980320
# --- /dev/null
# +++ b/pycode/examples/simulation/ode_seir_contact_matrix_example.py
# @@ -0,0 +1,269 @@
#############################################################################
# Copyright (C) 2020-2025 MEmilio
#
# Authors: Henrik Zunker
#
# Contact: Martin J. Kuehn
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#############################################################################
"""
Example: For a given country, this example loads the age-structured contact
matrix, resolves the total population, builds a simple age-resolved ODE SEIR
model, and runs a simulation.
"""

import io

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import requests

from memilio.epidata.getContactData import (get_available_countries,
                                            load_contact_matrix)
from memilio.simulation import AgeGroup, ContactMatrix, Damping
from memilio.simulation.oseir import InfectionState as State
from memilio.simulation.oseir import (Model, interpolate_simulation_result,
                                      simulate)

POPULATION_URL = (
    "https://raw.githubusercontent.com/kieshaprem/synthetic-contact-matrices/"
    "master/generate_synthetic_matrices/input/pop/popage_total2020.csv"
)


def get_population_by_age(country_name: str) -> dict:
    """
    Load the population of one country in 5-year age groups.

    The CSV at POPULATION_URL is downloaded into memory (nothing is written
    to disk). The contact matrices end at 75+, while the CSV goes up to 100,
    so every age column from 75 upwards is summed into a single "75+" group.

    Source: https://github.com/kieshaprem/synthetic-contact-matrices
    Note: Data is originally in thousands, converted here to absolute numbers.

    :param country_name: Country name as listed in the CSV (case-insensitive).
    :returns: Dict mapping the age group label ("0-4", ..., "70-74", "75+")
        to the number of people in that group.
    :raises RuntimeError: If the download fails, the country is not listed or
        the data cannot be processed. The original error is chained as cause.
    """
    try:
        # Download without saving to disk
        response = requests.get(POPULATION_URL, timeout=10)
        response.raise_for_status()

        # Read into Pandas
        df = pd.read_csv(io.StringIO(response.text))

        # Clean column names
        df.columns = df.columns.str.strip()

        # Filter by country (case-insensitive)
        country_col = "Region, subregion, country or area *"
        row = df[df[country_col].str.lower() == country_name.lower()]

        if row.empty:
            raise ValueError(f"Country '{country_name}' not found in data.")

        # Extract relevant columns (age0, age5, ..., age100)
        age_cols = [c for c in df.columns if c.startswith('age')]

        # Convert from thousands to absolute numbers
        pop_data = row.iloc[0][age_cols].astype(float) * 1000

        # Columns up to 70 (0-4, ..., 70-74)
        cols_up_to_70 = [f'age{i}' for i in range(0, 75, 5)]

        # Columns from 75 (75-79, ..., 100+); only those present in the data
        cols_75_plus = [
            f'age{i}' for i in range(75, 105, 5)
            if f'age{i}' in pop_data.index]

        # Labels for the output dict
        labels = [
            "0-4", "5-9", "10-14", "15-19", "20-24", "25-29", "30-34", "35-39",
            "40-44", "45-49", "50-54", "55-59", "60-64", "65-69", "70-74",
            "75+"]

        # zip stops after the 15 single columns; "75+" is filled separately.
        final_pop = {
            label: pop_data[col]
            for label, col in zip(labels, cols_up_to_70)}
        final_pop["75+"] = pop_data[cols_75_plus].sum()

        return final_pop

    except Exception as e:
        # Keep the single RuntimeError contract for callers, but chain the
        # cause so the original traceback is not lost.
        raise RuntimeError(f"Error loading population data: {e}") from e


def build_country_seir_model(
        contact_matrix: np.ndarray,
        population_by_age: dict,
        transmission_probability: float = 0.06,
        exposed_share: float = 1e-5,
        infected_share: float = 5e-6):
    """
    Build a simple age-resolved ODE SEIR model.

    :param contact_matrix: Square contact matrix; its first dimension
        defines the number of age groups of the model.
    :param population_by_age: Population per age group, either as a mapping
        (label -> population, as returned by get_population_by_age) or as
        a plain sequence/array of populations in age group order.
    :param transmission_probability: Value set as
        TransmissionProbabilityOnContact for every age group.
    :param exposed_share: Share of each group's population that starts in
        the Exposed compartment (at least one person per group).
    :param infected_share: Share of each group's population that starts in
        the Infected compartment (at least one person per group).
    :returns: The configured model.
    """
    contact_matrix = np.asarray(contact_matrix, dtype=float)
    num_groups = contact_matrix.shape[0]

    model = Model(num_groups)
    contacts = ContactMatrix(contact_matrix)
    contacts.minimum = np.zeros_like(contact_matrix)
    model.parameters.ContactPatterns.cont_freq_mat[0] = contacts

    # 60% contact reduction at t=30
    model.parameters.ContactPatterns.cont_freq_mat.add_damping(Damping(
        coeffs=np.ones((num_groups, num_groups)) * 0.6, t=30.0, level=0, type=0))

    # Use actual population distribution. Mappings are reduced to their
    # values; anything else is treated as a sequence of populations.
    values = getattr(population_by_age, "values", None)
    if callable(values):
        group_pop = np.array(list(values())).flatten()
    else:
        group_pop = np.asarray(population_by_age, dtype=float).flatten()

    if len(group_pop) != num_groups:
        # If dimensions mismatch (e.g. contact matrix has different age
        # groups), fall back to a uniform distribution of the total sum.
        print(
            f"Warning: Population groups ({len(group_pop)}) do not match "
            f"contact matrix groups ({num_groups}). "
            "Using uniform distribution.")
        total_pop = np.sum(group_pop)
        group_weights = np.full(num_groups, 1.0 / num_groups)
        group_pop = group_weights * total_pop

    # At least one exposed and one infected person per age group.
    exposed_init = np.maximum(1.0, group_pop * exposed_share)
    infected_init = np.maximum(1.0, group_pop * infected_share)

    # Epidemiological parameters are equal for all age groups; the initial
    # conditions scale with the population of each group.
    for idx in range(num_groups):
        age_group = AgeGroup(idx)
        model.parameters.TimeExposed[age_group] = 5.2
        model.parameters.TimeInfected[age_group] = 6.0
        model.parameters.TransmissionProbabilityOnContact[age_group] = (
            transmission_probability)

        model.populations[age_group, State.Exposed] = exposed_init[idx]
        model.populations[age_group, State.Infected] = infected_init[idx]
        model.populations[age_group, State.Recovered] = 0.0
        model.populations.set_difference_from_group_total_AgeGroup(
            (age_group, State.Susceptible), group_pop[idx])

    model.check_constraints()
    return model


def simulate_country_seir(
        country: str,
        days: float = 120.0,
        dt: float = 0.25,
        transmission_probability: float = 0.15,
        exposed_share: float = 1e-5,
        infected_share: float = 5e-6,
        interpolate: bool = True):
    """
    Load contact matrix, population data, build the ODE SEIR model, and run
    a simulation.

    :param country: Country name; resolved by load_contact_matrix and
        get_population_by_age, which both match case-insensitively.
    :param days: End time of the simulation (passed as tmax).
    :param dt: Step size passed to the simulation.
    :param transmission_probability: Forwarded to build_country_seir_model.
    :param exposed_share: Forwarded to build_country_seir_model.
    :param infected_share: Forwarded to build_country_seir_model.
    :param interpolate: If True, the result is passed through
        interpolate_simulation_result before it is returned.
    :returns: The simulation result.
    :raises ValueError: If the country has no contact matrix.
    :raises RuntimeError: If the population data cannot be loaded.
    """
    # No separate availability check: load_contact_matrix already raises a
    # ValueError listing the available sheets for unknown countries. A
    # pre-check would download the workbook a second time and, being
    # case-sensitive, reject names the loaders accept.
    contacts = load_contact_matrix(country, reduce_to_rki_groups=False)
    population = get_population_by_age(country)
    model = build_country_seir_model(
        contacts.values,
        population_by_age=population,
        transmission_probability=transmission_probability,
        exposed_share=exposed_share,
        infected_share=infected_share)

    result = simulate(t0=0.0, tmax=days, dt=dt, model=model)
    if interpolate:
        result = interpolate_simulation_result(result)

    return result


def plot_results(result, country: str):
    """
    Plot aggregated SEIR compartments over time.

    The compartments of all age groups are summed, so the plot shows one
    curve each for Susceptible, Exposed, Infected and Recovered.

    :param result: Simulation result providing as_ndarray() and
        get_num_elements().
    :param country: Country name, used in the plot title only.
    """
    results_arr = result.as_ndarray()
    # Row 0 holds the time points, the remaining rows the compartments.
    times = results_arr[0, :]

    num_compartments = 4  # S, E, I, R
    # get_num_elements includes all compartments and age groups. Divide to get age groups.
    num_groups = result.get_num_elements() // num_compartments

    susceptible = np.zeros(len(times))
    exposed = np.zeros(len(times))
    infected = np.zeros(len(times))
    recovered = np.zeros(len(times))

    for age_group in range(num_groups):
        # First row of this age group (offset by one for the time row).
        first_row = 1 + age_group * num_compartments
        susceptible += results_arr[first_row, :]
        exposed += results_arr[first_row + 1, :]
        infected += results_arr[first_row + 2, :]
        recovered += results_arr[first_row + 3, :]

    # Create plot
    plt.figure(figsize=(10, 6))
    plt.plot(times, susceptible, label='Susceptible', linewidth=2)
    plt.plot(times, exposed, label='Exposed', linewidth=2)
    plt.plot(times, infected, label='Infected', linewidth=2)
    plt.plot(times, recovered, label='Recovered', linewidth=2)

    plt.xlabel('Time (days)', fontsize=12)
    plt.ylabel('Population', fontsize=12)
    plt.title(f'SEIR Model Simulation - {country}',
              fontsize=14, fontweight='bold')
    plt.legend(fontsize=11)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()


def run_demo(country: str,
             days: float = 120.0,
             dt: float = 0.25,
             transmission_probability: float = 0.15,
             exposed_share: float = 1e-5,
             infected_share: float = 5e-6,
             plot: bool = True):
    """
    Run the SEIR simulation demo for a user defined country and parameters.

    Prints the last value of the result and optionally plots the aggregated
    compartments.

    :param country: Country name, see simulate_country_seir.
    :param days: End time of the simulation.
    :param dt: Step size passed to the simulation.
    :param transmission_probability: Forwarded to simulate_country_seir.
    :param exposed_share: Forwarded to simulate_country_seir.
    :param infected_share: Forwarded to simulate_country_seir.
    :param plot: If True, call plot_results on the result.
    :returns: The simulation result.
    """
    result = simulate_country_seir(
        country,
        days=days,
        dt=dt,
        transmission_probability=transmission_probability,
        exposed_share=exposed_share,
        infected_share=infected_share,
    )
    print(result.get_last_value())

    if plot:
        plot_results(result, country)

    return result


if __name__ == "__main__":
    country = "Germany"
    run_demo(country=country)

# diff --git a/pycode/memilio-epidata/memilio/epidata/getContactData.py b/pycode/memilio-epidata/memilio/epidata/getContactData.py
# new file mode 100644
# index 0000000000..717a3ef1be
# --- /dev/null
# +++ b/pycode/memilio-epidata/memilio/epidata/getContactData.py
# @@ -0,0 +1,241 @@
#############################################################################
# Copyright (C) 2020-2025 MEmilio
#
# Authors: Henrik Zunker
#
# Contact: Martin J. Kuehn
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#############################################################################
"""
:strong:`getContactData.py`

Load an age-structured contact matrix for a chosen country based on
Prem et al., 2017. The module can download the supporting ZIP from
https://doi.org/10.1371/journal.pcbi.1005697.s002 (contains the
``MUestimates_all_locations_1.xlsx`` workbook) or read a defined local
workbook path. By default, downloads are done in memory and no
files are written.
"""

import io
import os
import zipfile
from typing import Iterable, List, Optional

import numpy as np
import pandas as pd
import requests

CONTACT_ZIP_URL = (
    "https://journals.plos.org/ploscompbiol/article/file"
    "?id=10.1371/journal.pcbi.1005697.s002&type=supplementary"
)
CONTACT_WORKBOOK_NAME = "MUestimates_all_locations_1.xlsx"

AGE_GROUP_LABELS = [
    "0-4",
    "5-9",
    "10-14",
    "15-19",
    "20-24",
    "25-29",
    "30-34",
    "35-39",
    "40-44",
    "45-49",
    "50-54",
    "55-59",
    "60-64",
    "65-69",
    "70-74",
    "75+",
]

AGE_GROUP_LABELS_RKI = [
    "0-4",
    "5-14",
    "15-34",
    "35-59",
    "60-79",
    "80-99",
]


def _normalize_country_name(country: str) -> str:
    """
    Return a case-insensitive key without whitespace or punctuation.

    :param country: The country name to normalize.
    :returns: Normalized country name (casefolded, alphanumeric characters
        only).
    """
    return "".join(ch for ch in country.casefold() if ch.isalnum())


def _download_contact_workbook(
        url: str = CONTACT_ZIP_URL, target_filename: str = CONTACT_WORKBOOK_NAME) -> bytes:
    """
    Download the ZIP from the url and return the workbook.

    The ZIP is processed in memory. The first archive member whose name
    ends with ``target_filename`` is returned.

    :param url: URL to download the ZIP from.
    :param target_filename: Name of the workbook file within the ZIP.
    :returns: Content of the workbook.
    :raises FileNotFoundError: If no member of the ZIP matches
        ``target_filename``.
    """
    response = requests.get(url, timeout=30)
    response.raise_for_status()
    with zipfile.ZipFile(io.BytesIO(response.content)) as zf:
        # endswith also matches members stored inside a sub-directory.
        candidates = [name for name in zf.namelist()
                      if name.endswith(target_filename)]
        if not candidates:
            raise FileNotFoundError(
                f"'{target_filename}' not found in downloaded workbook.")
        with zf.open(candidates[0]) as f:
            return f.read()


def _load_workbook_bytes(
        contact_path: Optional[str],
        url: str = CONTACT_ZIP_URL,
        target_filename: str = CONTACT_WORKBOOK_NAME) -> bytes:
    """
    Return workbook either from a user path or by downloading the ZIP.

    :param contact_path: Optional local path to the workbook.
    :param url: Url to download the ZIP from if no path is provided.
    :param target_filename: Name of the workbook file within the ZIP.
    :returns: Content of the workbook.
    :raises FileNotFoundError: If ``contact_path`` is given but does not
        exist.
    """
    if contact_path:
        if not os.path.exists(contact_path):
            raise FileNotFoundError(
                f"Contact matrix file not found at {contact_path}")
        with open(contact_path, "rb") as f:
            return f.read()
    return _download_contact_workbook(url=url, target_filename=target_filename)


def list_available_contact_countries(
        contact_path: Optional[str] = None):
    """
    List all country names available in the contact matrix workbook.

    The country names are the sheet names of the workbook.

    :param contact_path: Optional local path to the workbook.
    :returns: List of all country names.
    """
    xls_bytes = _load_workbook_bytes(contact_path)
    # Context manager closes the workbook once the names are read.
    with pd.ExcelFile(io.BytesIO(xls_bytes)) as xls:
        return xls.sheet_names


def get_available_countries(contact_path: Optional[str] = None):
    """
    Get list of all available countries.

    Alias of list_available_contact_countries.

    :param contact_path: Optional local path to the workbook.
    :returns: List of all available countries.
    """
    return list_available_contact_countries(contact_path)


def _select_sheet_name(country: str, sheet_names: Iterable[str]) -> str:
    """
    Select the appropriate sheet name from the workbook for a given country.

    Names are compared after normalization with _normalize_country_name,
    i.e. ignoring case, whitespace and punctuation.

    :param country: Country name as listed in the workbook (case-insensitive).
    :param sheet_names: Iterable of available sheet names.
    :returns: The exact sheet name as found in the workbook.
    :raises ValueError: If no sheet matches the country.
    """
    # Materialize once: the names are needed for the lookup and again for
    # the error message, which would be empty for a one-shot iterable.
    sheet_names = list(sheet_names)
    lookup = {_normalize_country_name(name): name for name in sheet_names}
    key = _normalize_country_name(country)
    if key not in lookup:
        available = ", ".join(sheet_names)
        raise ValueError(
            f"Country '{country}' not found in contact matrices. "
            f"Available sheets: {available}")
    return lookup[key]


def load_contact_matrix(
        country: str,
        contact_path: Optional[str] = None,
        reduce_to_rki_groups: bool = True):
    """
    Load the all-locations contact matrix for the given country. If
    ``contact_path`` is not provided, the function downloads the
    ``MUestimates_all_locations_1.xlsx`` workbook from Prem et al., 2017.

    :param country: Country name as listed in the workbook (case-insensitive).
    :param contact_path: Optional path to ``MUestimates_all_locations_1.xlsx``.
    :param reduce_to_rki_groups: If True, aggregate to the six RKI age groups
        (0-4, 5-14, 15-34, 35-59, 60-79, 80-99 years; the last group is
        filled from the 75+ group of the source). Default True.
    :returns: DataFrame indexed by age group with floats.
    :raises ValueError: If the country is not found, or the sheet contains
        non-numeric entries or is not square.
    """
    xls_bytes = _load_workbook_bytes(contact_path)
    # Context manager closes the workbook once the sheet is read.
    with pd.ExcelFile(io.BytesIO(xls_bytes)) as xls:
        sheet = _select_sheet_name(country, xls.sheet_names)
        df = pd.read_excel(xls, sheet_name=sheet, engine="openpyxl")

    # Ensure numeric values and trim potential trailing rows/cols.
    matrix = df.apply(pd.to_numeric, errors="coerce")
    matrix = matrix.iloc[:len(AGE_GROUP_LABELS), :len(AGE_GROUP_LABELS)]
    matrix.columns = AGE_GROUP_LABELS[:matrix.shape[1]]
    matrix.index = AGE_GROUP_LABELS[:matrix.shape[0]]

    # Non-numeric cells were coerced to NaN above.
    if matrix.isnull().any().any():
        raise ValueError(
            f"Contact matrix for '{country}' contains non-numeric entries.")

    if matrix.shape[0] != matrix.shape[1]:
        raise ValueError(
            f"Contact matrix for '{country}' is not square: {matrix.shape}")

    if reduce_to_rki_groups:
        matrix = _aggregate_to_rki_age_groups(matrix)

    return matrix


def _aggregate_to_rki_age_groups(matrix: pd.DataFrame) -> pd.DataFrame:
    """
    Aggregate a 16x16 age contact matrix to the 6-group RKI scheme.

    Assumes the original columns/rows follow AGE_GROUP_LABELS order.
    Note: The source only provides data up to 70-74 and a 75+ group.
    We map 60-74 to the 60-79 RKI group and 75+ to the 80-99 RKI group.

    Each aggregated entry is the plain (unweighted) mean of the
    corresponding block of the 16x16 matrix.

    :param matrix: The 16x16 contact matrix to aggregate.
    :returns: Aggregated 6x6 contact matrix.
    :raises ValueError: If the matrix does not have the expected 16x16 shape.
    """
    if matrix.shape != (len(AGE_GROUP_LABELS), len(AGE_GROUP_LABELS)):
        raise ValueError(
            f"Expected a {len(AGE_GROUP_LABELS)}x{len(AGE_GROUP_LABELS)} matrix for aggregation.")

    # Indices into AGE_GROUP_LABELS that form each RKI group.
    groups = [
        [0],  # 0-4
        [1, 2],  # 5-14
        [3, 4, 5, 6],  # 15-34
        [7, 8, 9, 10, 11],  # 35-59
        [12, 13, 14],  # 60-79
        [15],  # 80-99 (source has 75+ only)
    ]

    aggregated = pd.DataFrame(
        index=AGE_GROUP_LABELS_RKI, columns=AGE_GROUP_LABELS_RKI, dtype=float)

    # NOTE(review): the block mean also averages over the contact (column)
    # age groups and is not population-weighted - confirm this is the
    # intended aggregation for contact rates.
    for i, rows in enumerate(groups):
        for j, cols in enumerate(groups):
            block = matrix.values[np.ix_(rows, cols)]
            aggregated.iat[i, j] = float(block.mean())

    return aggregated

# diff --git a/pycode/memilio-epidata/memilio/epidata_test/test_epidata_get_contact_data.py b/pycode/memilio-epidata/memilio/epidata_test/test_epidata_get_contact_data.py
# new file mode 100644
# index 0000000000..05ae728ff4
# --- /dev/null
# +++ b/pycode/memilio-epidata/memilio/epidata_test/test_epidata_get_contact_data.py
# @@ -0,0 +1,162 @@
#############################################################################
# Copyright (C) 2020-2025 MEmilio
#
# Authors: Henrik Zunker
#
# Contact: Martin J. Kuehn
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#############################################################################
import io
import os
import tempfile
import unittest
import zipfile
from unittest.mock import Mock, patch

import numpy as np
import pandas as pd

from memilio.epidata.getContactData import (AGE_GROUP_LABELS,
                                            AGE_GROUP_LABELS_RKI,
                                            CONTACT_WORKBOOK_NAME,
                                            list_available_contact_countries,
                                            load_contact_matrix)


class TestGetContactData(unittest.TestCase):
    """Tests for loading contact matrices."""

    def setUp(self):
        """Track temporary files to clean up after each test."""
        self._temp_files = []

    def tearDown(self):
        """Remove every temporary workbook created during the test."""
        for path in self._temp_files:
            try:
                os.remove(path)
            except FileNotFoundError:
                pass

    def _create_workbook(self, sheets: dict):
        """
        Create a temporary xlsx file with one sheet per entry of ``sheets``.

        :param sheets: Dict mapping sheet name to the DataFrame to write.
        :returns: Path of the created workbook.
        """
        fd, path = tempfile.mkstemp(suffix=".xlsx")
        os.close(fd)
        # Register for cleanup before writing, so the file is removed even
        # if writing the workbook fails.
        self._temp_files.append(path)
        with pd.ExcelWriter(path, engine="openpyxl") as writer:
            for sheet_name, df in sheets.items():
                df.to_excel(writer, sheet_name=sheet_name, index=False)
        return path

    def _create_zip_with_workbook(self, sheets: dict):
        """
        Create a ZIP in memory that contains MUestimates_all_locations_1.xlsx.

        :param sheets: Dict mapping sheet name to the DataFrame to write.
        :returns: Content of the ZIP as bytes.
        """
        with io.BytesIO() as buf_zip:
            with zipfile.ZipFile(buf_zip, mode="w") as zf:
                with io.BytesIO() as buf_xlsx:
                    with pd.ExcelWriter(buf_xlsx, engine="openpyxl") as writer:
                        for sheet_name, df in sheets.items():
                            df.to_excel(
                                writer, sheet_name=sheet_name, index=False)
                    zf.writestr(CONTACT_WORKBOOK_NAME, buf_xlsx.getvalue())
            # Read the buffer only after the ZipFile is closed, i.e. after
            # the central directory has been written.
            return buf_zip.getvalue()

    def test_list_available_explicit_path(self):
        """Ensure listing works with a custom workbook path."""
        data = pd.DataFrame(np.ones((16, 16)))
        contact_path = self._create_workbook({"Germany": data, "Spain": data})

        countries = list_available_contact_countries(contact_path=contact_path)
        self.assertEqual(set(countries), {"Germany", "Spain"})

    @patch('memilio.epidata.getContactData.requests.get')
    def test_list_available_downloads_zip(self, mock_get):
        """Listing without a path downloads the ZIP and reads workbook."""
        data = pd.DataFrame(np.ones((16, 16)))
        zip_bytes = self._create_zip_with_workbook(
            {"Germany": data, "Spain": data})
        mock_resp = Mock()
        mock_resp.content = zip_bytes
        mock_resp.raise_for_status = Mock()
        mock_get.return_value = mock_resp

        countries = list_available_contact_countries(contact_path=None)
        self.assertEqual(set(countries), {"Germany", "Spain"})
        mock_resp.raise_for_status.assert_called_once()
        mock_get.assert_called_once()

    def test_load_contact_matrix(self):
        """Contact matrix loads with full 16x16 matrix."""
        matrix_values = np.arange(16 * 16).reshape(16, 16)
        df = pd.DataFrame(matrix_values)
        contact_path = self._create_workbook({"Germany": df})

        matrix = load_contact_matrix(
            "Germany", contact_path=contact_path,
            reduce_to_rki_groups=False)

        self.assertEqual(matrix.shape, (16, 16))
        self.assertEqual(list(matrix.columns), AGE_GROUP_LABELS)
        self.assertEqual(list(matrix.index), AGE_GROUP_LABELS)
        self.assertEqual(matrix.iloc[0, 0], 0)
        self.assertEqual(matrix.iloc[-1, -1], 255)

    def test_load_contact_matrix_country_case_insensitive(self):
        """The country name is matched ignoring case."""
        matrix_values = np.arange(16 * 16).reshape(16, 16)
        df = pd.DataFrame(matrix_values)
        contact_path = self._create_workbook({"Germany": df})

        matrix = load_contact_matrix(
            "germany", contact_path=contact_path,
            reduce_to_rki_groups=False)

        self.assertEqual(matrix.shape, (16, 16))
        self.assertEqual(matrix.iloc[-1, -1], 255)

    def test_load_contact_matrix_rki_groups(self):
        """Default is to aggregate to the 6-group RKI age groups."""
        matrix_values = np.arange(16 * 16).reshape(16, 16)
        df = pd.DataFrame(matrix_values)
        contact_path = self._create_workbook({"Germany": df})

        matrix = load_contact_matrix("Germany", contact_path=contact_path)

        self.assertEqual(matrix.shape, (6, 6))
        self.assertEqual(list(matrix.columns), AGE_GROUP_LABELS_RKI)
        self.assertEqual(list(matrix.index), AGE_GROUP_LABELS_RKI)
        # spot-check a few block means
        self.assertEqual(matrix.iloc[0, 0], 0)  # block with only (0,0)
        # rows 1,2 and col 0 -> values 16 and 32 => mean 24
        self.assertEqual(matrix.iloc[1, 0], 24)
        # rows 3,4,5,6 and cols 3,4,5,6 (15-34 vs 15-34)
        sub = matrix_values[3:7, 3:7].mean()
        self.assertEqual(matrix.iloc[2, 2], sub)

    @patch('memilio.epidata.getContactData.requests.get')
    def test_load_contact_matrix_downloads_no_path(self, mock_get):
        """When no path is given the workbook is downloaded from the DOI ZIP."""
        matrix_values = np.arange(16 * 16).reshape(16, 16)
        df = pd.DataFrame(matrix_values)
        zip_bytes = self._create_zip_with_workbook({"Germany": df})
        mock_resp = Mock()
        mock_resp.content = zip_bytes
        mock_resp.raise_for_status = Mock()
        mock_get.return_value = mock_resp

        matrix = load_contact_matrix("Germany", contact_path=None)
        self.assertEqual(matrix.shape, (6, 6))
        self.assertEqual(list(matrix.columns), AGE_GROUP_LABELS_RKI)
        self.assertEqual(list(matrix.index), AGE_GROUP_LABELS_RKI)
        self.assertEqual(matrix.iloc[0, 0], 0)
        self.assertEqual(matrix.iloc[1, 0], 24)
        mock_resp.raise_for_status.assert_called_once()
        mock_get.assert_called_once()

    def test_load_contact_matrix_unknown_country_raises(self):
        """Loading a non-existing sheet raises ValueError."""
        data = pd.DataFrame(np.zeros((16, 16)))
        contact_path = self._create_workbook({"Germany": data})

        with self.assertRaises(ValueError):
            load_contact_matrix("France", contact_path=contact_path)


if __name__ == '__main__':
    unittest.main()