From b8ef548ef59e5d877ebae637a0c7ac1f5af8ff7a Mon Sep 17 00:00:00 2001 From: HenrZu <69154294+HenrZu@users.noreply.github.com> Date: Tue, 2 Dec 2025 23:10:03 +0100 Subject: [PATCH 01/11] generic contact data and tests --- .../memilio/epidata/getContactData.py | 155 ++++++++++++++++++ .../test_epidata_get_contact_data.py | 140 ++++++++++++++++ 2 files changed, 295 insertions(+) create mode 100644 pycode/memilio-epidata/memilio/epidata/getContactData.py create mode 100644 pycode/memilio-epidata/memilio/epidata_test/test_epidata_get_contact_data.py diff --git a/pycode/memilio-epidata/memilio/epidata/getContactData.py b/pycode/memilio-epidata/memilio/epidata/getContactData.py new file mode 100644 index 0000000000..e72192da2e --- /dev/null +++ b/pycode/memilio-epidata/memilio/epidata/getContactData.py @@ -0,0 +1,155 @@ +############################################################################# +# Copyright (C) 2020-2025 MEmilio +# +# Authors: Henrik Zunker +# +# Contact: Martin J. Kuehn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +############################################################################# +""" +:strong:`getContactData.py` + +Load an age-structured contact matrix for a chosen country based on +Prem et al., 2017. The module can download the supporting ZIP from +https://doi.org/10.1371/journal.pcbi.1005697.s002 (contains the +``MUestimates_all_locations_1.xlsx`` workbook) or read a defined local +workbook path. By default, downloads are done in memory and no +files are written. +""" + +import io +import os +import zipfile +from typing import Iterable, List, Optional + +import pandas as pd +import requests + +CONTACT_ZIP_URL = ( + "https://journals.plos.org/ploscompbiol/article/file" + "?id=10.1371/journal.pcbi.1005697.s002&type=supplementary" +) +CONTACT_WORKBOOK_NAME = "MUestimates_all_locations_1.xlsx" + +# Only kept for tests or explicit overrides; regular calls should pass +# contact_path=None to trigger download. +DEFAULT_CONTACT_PATH = os.path.abspath( + os.path.join(os.path.dirname(__file__), "..", "..", "..", "..", + CONTACT_WORKBOOK_NAME)) + +AGE_GROUP_LABELS = [ + "0-4", + "5-9", + "10-14", + "15-19", + "20-24", + "25-29", + "30-34", + "35-39", + "40-44", + "45-49", + "50-54", + "55-59", + "60-64", + "65-69", + "70-74", + "75+", +] + + +def _normalize_country_name(country: str) -> str: + """Return a case-insensitive key without whitespace or punctuation.""" + return "".join(ch for ch in country.casefold() if ch.isalnum()) + + +def _download_contact_workbook( + url: str = CONTACT_ZIP_URL, target_filename: str = CONTACT_WORKBOOK_NAME) -> bytes: + """Download the ZIP from the DOI and return the workbook bytes in memory.""" + response = requests.get(url, timeout=30) + response.raise_for_status() + with zipfile.ZipFile(io.BytesIO(response.content)) as zf: + candidates = [name for name in zf.namelist() + if name.endswith(target_filename)] + if not candidates: + raise FileNotFoundError( + f"'{target_filename}' not found in downloaded archive.") + with zf.open(candidates[0]) as f: + return f.read() + + +def _load_workbook_bytes( + contact_path: Optional[str], + url: str = CONTACT_ZIP_URL, + target_filename: str = CONTACT_WORKBOOK_NAME) -> bytes: + """Return workbook bytes either from a user path or by downloading the ZIP.""" + if contact_path: + if not os.path.exists(contact_path): + raise FileNotFoundError( + f"Contact matrix file not found at {contact_path}") + with open(contact_path, "rb") as f: + return f.read() + return _download_contact_workbook(url=url, target_filename=target_filename) + + +def list_available_contact_countries( + contact_path: Optional[str] = None) -> List[str]: + """List all country names available in the contact matrix workbook.""" + xls_bytes = _load_workbook_bytes(contact_path) + xls = pd.ExcelFile(io.BytesIO(xls_bytes)) + return xls.sheet_names + + +def _select_sheet_name(country: str, sheet_names: Iterable[str]) -> str: + lookup = {_normalize_country_name(name): name for name in sheet_names} + key = _normalize_country_name(country) + if key not in lookup: + available = ", ".join(sheet_names) + raise ValueError( + f"Country '{country}' not found in contact matrices. " + f"Available sheets: {available}") + return lookup[key] + + +def load_contact_matrix( + country: str, contact_path: Optional[str] = None) -> pd.DataFrame: + """ + Load the all-locations contact matrix for the given country. If + ``contact_path`` is not provided, the function downloads the + ``MUestimates_all_locations_1.xlsx`` workbook from the DOI ZIP. + + :param country: Country name as listed in the workbook (case-insensitive). + :param contact_path: Optional path to ``MUestimates_all_locations_1.xlsx``. + :returns: DataFrame indexed by age group with floats. + """ + xls_bytes = _load_workbook_bytes(contact_path) + xls = pd.ExcelFile(io.BytesIO(xls_bytes)) + sheet_names = xls.sheet_names + sheet = _select_sheet_name(country, sheet_names) + df = pd.read_excel(xls, sheet_name=sheet, engine="openpyxl") + + # Ensure numeric values and trim potential trailing rows/cols. + matrix = df.apply(pd.to_numeric, errors="coerce") + matrix = matrix.iloc[:len(AGE_GROUP_LABELS), :len(AGE_GROUP_LABELS)] + matrix.columns = AGE_GROUP_LABELS[:matrix.shape[1]] + matrix.index = AGE_GROUP_LABELS[:matrix.shape[0]] + + if matrix.isnull().any().any(): + raise ValueError( + f"Contact matrix for '{country}' contains non-numeric entries.") + + if matrix.shape[0] != matrix.shape[1]: + raise ValueError( + f"Contact matrix for '{country}' is not square: {matrix.shape}") + + return matrix diff --git a/pycode/memilio-epidata/memilio/epidata_test/test_epidata_get_contact_data.py b/pycode/memilio-epidata/memilio/epidata_test/test_epidata_get_contact_data.py new file mode 100644 index 0000000000..c439455b94 --- /dev/null +++ b/pycode/memilio-epidata/memilio/epidata_test/test_epidata_get_contact_data.py @@ -0,0 +1,140 @@ +############################################################################# +# Copyright (C) 2020-2025 MEmilio +# +# Authors: Henrik Zunker +# +# Contact: Martin J. Kuehn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +############################################################################# +import io +import os +import tempfile +import unittest +import zipfile +from unittest.mock import Mock, patch + +import numpy as np +import pandas as pd + +from memilio.epidata.getContactData import (AGE_GROUP_LABELS, + CONTACT_WORKBOOK_NAME, + list_available_contact_countries, + load_contact_matrix) + + +class TestGetContactData(unittest.TestCase): + """Tests for loading contact matrices.""" + + def setUp(self): + """Track temporary files to clean up after each test.""" + self._temp_files = [] + + def tearDown(self): + for path in self._temp_files: + try: + os.remove(path) + except FileNotFoundError: + pass + + def _create_workbook(self, sheets: dict) -> str: + """Create a temporary xlsx file with provided sheet_name -> DataFrame.""" + fd, path = tempfile.mkstemp(suffix=".xlsx") + os.close(fd) + with pd.ExcelWriter(path, engine="openpyxl") as writer: + for sheet_name, df in sheets.items(): + df.to_excel(writer, sheet_name=sheet_name, index=False) + self._temp_files.append(path) + return path + + def _create_zip_with_workbook(self, sheets: dict) -> bytes: + """Create a ZIP in memory that contains MUestimates_all_locations_1.xlsx.""" + with io.BytesIO() as buf_zip: + with zipfile.ZipFile(buf_zip, mode="w") as zf: + with io.BytesIO() as buf_xlsx: + with pd.ExcelWriter(buf_xlsx, engine="openpyxl") as writer: + for sheet_name, df in sheets.items(): + df.to_excel( + writer, sheet_name=sheet_name, index=False) + zf.writestr(CONTACT_WORKBOOK_NAME, buf_xlsx.getvalue()) + return buf_zip.getvalue() + + def test_list_available_contact_countries_uses_explicit_path(self): + """Ensure listing works with a custom workbook path.""" + data = pd.DataFrame(np.ones((16, 16))) + contact_path = self._create_workbook({"Germany": data, "Spain": data}) + + countries = list_available_contact_countries(contact_path=contact_path) + self.assertEqual(set(countries), {"Germany", "Spain"}) + + @patch('memilio.epidata.getContactData.requests.get') + def test_list_available_contact_countries_downloads_zip(self, mock_get): + """Listing without a path downloads the ZIP and reads workbook bytes.""" + data = pd.DataFrame(np.ones((16, 16))) + zip_bytes = self._create_zip_with_workbook( + {"Germany": data, "Spain": data}) + mock_resp = Mock() + mock_resp.content = zip_bytes + mock_resp.raise_for_status = Mock() + mock_get.return_value = mock_resp + + countries = list_available_contact_countries(contact_path=None) + self.assertEqual(set(countries), {"Germany", "Spain"}) + mock_resp.raise_for_status.assert_called_once() + mock_get.assert_called_once() + + def test_load_contact_matrix_reads_numeric_and_labels(self): + """Contact matrix is loaded from the given path with expected shape/labels.""" + matrix_values = np.arange(16 * 16).reshape(16, 16) + df = pd.DataFrame(matrix_values) + contact_path = self._create_workbook({"Germany": df}) + + matrix = load_contact_matrix("Germany", contact_path=contact_path) + + self.assertEqual(matrix.shape, (16, 16)) + self.assertEqual(list(matrix.columns), AGE_GROUP_LABELS) + self.assertEqual(list(matrix.index), AGE_GROUP_LABELS) + self.assertEqual(matrix.iloc[0, 0], 0) + self.assertEqual(matrix.iloc[-1, -1], 255) + + @patch('memilio.epidata.getContactData.requests.get') + def test_load_contact_matrix_downloads_when_no_path(self, mock_get): + """When no path is given the workbook is downloaded from the DOI ZIP.""" + matrix_values = np.arange(16 * 16).reshape(16, 16) + df = pd.DataFrame(matrix_values) + zip_bytes = self._create_zip_with_workbook({"Germany": df}) + mock_resp = Mock() + mock_resp.content = zip_bytes + mock_resp.raise_for_status = Mock() + mock_get.return_value = mock_resp + + matrix = load_contact_matrix("Germany", contact_path=None) + self.assertEqual(matrix.shape, (16, 16)) + self.assertEqual(list(matrix.columns), AGE_GROUP_LABELS) + self.assertEqual(list(matrix.index), AGE_GROUP_LABELS) + self.assertEqual(matrix.iloc[0, 0], 0) + self.assertEqual(matrix.iloc[-1, -1], 255) + mock_resp.raise_for_status.assert_called_once() + mock_get.assert_called_once() + + def test_load_contact_matrix_unknown_country_raises(self): + """Loading a non-existing sheet raises ValueError.""" + data = pd.DataFrame(np.zeros((16, 16))) + contact_path = self._create_workbook({"Germany": data}) + + with self.assertRaises(ValueError): + load_contact_matrix("France", contact_path=contact_path) + + +if __name__ == '__main__': + unittest.main() From 4ef23f837b84b3880d72c7928a1435eecbb37381 Mon Sep 17 00:00:00 2001 From: HenrZu <69154294+HenrZu@users.noreply.github.com> Date: Tue, 2 Dec 2025 23:54:37 +0100 Subject: [PATCH 02/11] example --- .../ode_seir_contact_matrix_example.py | 193 ++++++++++++++++++ .../memilio/epidata/getContactData.py | 74 +++++-- .../test_epidata_get_contact_data.py | 50 +++-- 3 files changed, 287 insertions(+), 30 deletions(-) create mode 100644 pycode/examples/simulation/ode_seir_contact_matrix_example.py diff --git a/pycode/examples/simulation/ode_seir_contact_matrix_example.py b/pycode/examples/simulation/ode_seir_contact_matrix_example.py new file mode 100644 index 0000000000..65253866fa --- /dev/null +++ b/pycode/examples/simulation/ode_seir_contact_matrix_example.py @@ -0,0 +1,193 @@ +############################################################################# +# Copyright (C) 2020-2025 MEmilio +# +# Authors: Henrik Zunker +# +# Contact: Martin J. Kuehn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +############################################################################# +""" +Example: For a given country, this example loads the age-structured contact +matrix, resolves the total population, builds a simple age-resolved ODE SEIR +model, and runs a simulation. +""" + +import io + +import numpy as np +import pandas as pd +import requests + +from memilio.epidata.getContactData import (load_contact_matrix) +from memilio.simulation import AgeGroup, ContactMatrix +from memilio.simulation.oseir import InfectionState as State +from memilio.simulation.oseir import (Model, interpolate_simulation_result, + simulate) + +OWID_POPULATION_URL = ( + "https://raw.githubusercontent.com/owid/covid-19-data/master/" + "scripts/input/un/population_latest.csv" +) + + +def _normalize_country_name(country: str): + """Return a case-insensitive key without whitespace or punctuation.""" + return "".join(ch for ch in country.casefold() if ch.isalnum()) + + +def _download_population_table(url: str = OWID_POPULATION_URL): + response = requests.get(url, timeout=10) + response.raise_for_status() + return pd.read_csv(io.StringIO(response.text)) + + +def _pick_population_from_table(df: pd.DataFrame, + country: str): + # Identify candidate population columns provided by the OWID/UN table. + pop_cols = [c + for c + in ["population", "Population", "pop", "PopTotal", "Value"] + if c in df.columns] + if not pop_cols: + return None + + # Try multiple possible country-name columns to find the first match. + country_cols = [ + c + for c + in + ["entity", "Entity", "location", "Location", "country", "Country", + "country_name", "Country Name"] if c in df.columns] + key = _normalize_country_name(country) + for c_col in country_cols: + normalized = df[c_col].astype(str).map(_normalize_country_name) + matches = df.loc[normalized == key] + if matches.empty: + continue + # If multiple years are present, prefer the latest year. + matches = matches.sort_values("year", ascending=False) + pop = matches.iloc[0][pop_cols[0]] + if pd.notna(pop): + return int(pop) + return None + + +def get_population_total(country: str): + """ + get a countrys total population using online data. + """ + try: + df_pop = _download_population_table() + pop_online = _pick_population_from_table( + df_pop, country) + if pop_online: + return pop_online + except Exception: + pass + raise RuntimeError( + "No population found. Provide population_override or extend population_fallback.") + + +def build_country_seir_model( + contact_matrix: np.ndarray, + population_total: int, + transmission_probability: float = 0.06, + exposed_share: float = 1e-5, + infected_share: float = 5e-6): + """ + Build a simple age-resolved ODE SEIR model. + """ + contact_matrix = np.asarray(contact_matrix, dtype=float) + num_groups = contact_matrix.shape[0] + + model = Model(num_groups) + contacts = ContactMatrix(contact_matrix) + contacts.minimum = np.zeros_like(contact_matrix) + model.parameters.ContactPatterns.cont_freq_mat[0] = contacts + + # Distribute population and initial exposed/infected uniformly across age groups. + group_weights = np.full(num_groups, 1.0 / num_groups) + group_pop = group_weights * population_total + exposed_init = np.maximum(1.0, group_pop * exposed_share) + infected_init = np.maximum(1.0, group_pop * infected_share) + + # Set parameters and initial conditions equal for all age groups. + for idx in range(num_groups): + age_group = AgeGroup(idx) + model.parameters.TimeExposed[age_group] = 5.2 + model.parameters.TimeInfected[age_group] = 6.0 + model.parameters.TransmissionProbabilityOnContact[age_group] = ( + transmission_probability) + + model.populations[age_group, State.Exposed] = exposed_init[idx] + model.populations[age_group, State.Infected] = infected_init[idx] + model.populations[age_group, State.Recovered] = 0.0 + model.populations.set_difference_from_group_total_AgeGroup( + (age_group, State.Susceptible), group_pop[idx]) + + model.check_constraints() + return model + + +def simulate_country_seir( + country: str, + days: float = 120.0, + dt: float = 0.25, + transmission_probability: float = 0.06, + exposed_share: float = 1e-5, + infected_share: float = 5e-6, + interpolate: bool = True): + """ + Load contact matrix, fetch population, build the ODE SEIR model, and run + a simulation. Returns (result, contacts_df, population_int). + """ + contacts = load_contact_matrix(country) + population = get_population_total(country) + model = build_country_seir_model( + contacts.values, + population_total=population, + transmission_probability=transmission_probability, + exposed_share=exposed_share, + infected_share=infected_share) + + result = simulate(t0=0.0, tmax=days, dt=dt, model=model) + if interpolate: + result = interpolate_simulation_result(result) + + return result + + +def run_demo(country: str, + days: float = 120.0, + dt: float = 0.25, + transmission_probability: float = 0.06, + exposed_share: float = 1e-5, + infected_share: float = 5e-6): + """ + Run the SEIR simulation demo for a user defined country and parameters. + """ + result = simulate_country_seir( + country, + days=days, + dt=dt, + transmission_probability=transmission_probability, + exposed_share=exposed_share, + infected_share=infected_share, + ) + print(result.get_last_value()) + + +if __name__ == "__main__": + country = "China" + run_demo(country=country) diff --git a/pycode/memilio-epidata/memilio/epidata/getContactData.py b/pycode/memilio-epidata/memilio/epidata/getContactData.py index e72192da2e..ac749b81a6 100644 --- a/pycode/memilio-epidata/memilio/epidata/getContactData.py +++ b/pycode/memilio-epidata/memilio/epidata/getContactData.py @@ -33,6 +33,7 @@ import zipfile from typing import Iterable, List, Optional +import numpy as np import pandas as pd import requests @@ -42,12 +43,6 @@ ) CONTACT_WORKBOOK_NAME = "MUestimates_all_locations_1.xlsx" -# Only kept for tests or explicit overrides; regular calls should pass -# contact_path=None to trigger download. -DEFAULT_CONTACT_PATH = os.path.abspath( - os.path.join(os.path.dirname(__file__), "..", "..", "..", "..", - CONTACT_WORKBOOK_NAME)) - AGE_GROUP_LABELS = [ "0-4", "5-9", @@ -67,15 +62,24 @@ "75+", ] +AGE_GROUP_LABELS_RKI = [ + "0-4", + "5-14", + "15-34", + "35-59", + "60-79", + "80-99", +] + -def _normalize_country_name(country: str) -> str: +def _normalize_country_name(country: str): """Return a case-insensitive key without whitespace or punctuation.""" return "".join(ch for ch in country.casefold() if ch.isalnum()) def _download_contact_workbook( - url: str = CONTACT_ZIP_URL, target_filename: str = CONTACT_WORKBOOK_NAME) -> bytes: - """Download the ZIP from the DOI and return the workbook bytes in memory.""" + url: str = CONTACT_ZIP_URL, target_filename: str = CONTACT_WORKBOOK_NAME): + """Download the ZIP from the DOI and return the workbook.""" response = requests.get(url, timeout=30) response.raise_for_status() with zipfile.ZipFile(io.BytesIO(response.content)) as zf: @@ -83,7 +87,7 @@ def _download_contact_workbook( if name.endswith(target_filename)] if not candidates: raise FileNotFoundError( - f"'{target_filename}' not found in downloaded archive.") + f"'{target_filename}' not found in downloaded workbook.") with zf.open(candidates[0]) as f: return f.read() @@ -91,8 +95,8 @@ def _download_contact_workbook( def _load_workbook_bytes( contact_path: Optional[str], url: str = CONTACT_ZIP_URL, - target_filename: str = CONTACT_WORKBOOK_NAME) -> bytes: - """Return workbook bytes either from a user path or by downloading the ZIP.""" + target_filename: str = CONTACT_WORKBOOK_NAME): + """Return workbook either from a user path or by downloading the ZIP.""" if contact_path: if not os.path.exists(contact_path): raise FileNotFoundError( @@ -103,14 +107,14 @@ def _load_workbook_bytes( def list_available_contact_countries( - contact_path: Optional[str] = None) -> List[str]: + contact_path: Optional[str] = None): """List all country names available in the contact matrix workbook.""" xls_bytes = _load_workbook_bytes(contact_path) xls = pd.ExcelFile(io.BytesIO(xls_bytes)) return xls.sheet_names -def _select_sheet_name(country: str, sheet_names: Iterable[str]) -> str: +def _select_sheet_name(country: str, sheet_names: Iterable[str]): lookup = {_normalize_country_name(name): name for name in sheet_names} key = _normalize_country_name(country) if key not in lookup: @@ -122,14 +126,18 @@ def _select_sheet_name(country: str, sheet_names: Iterable[str]) -> str: def load_contact_matrix( - country: str, contact_path: Optional[str] = None) -> pd.DataFrame: + country: str, + contact_path: Optional[str] = None, + reduce_to_rki_groups: bool = True): """ Load the all-locations contact matrix for the given country. If ``contact_path`` is not provided, the function downloads the - ``MUestimates_all_locations_1.xlsx`` workbook from the DOI ZIP. + ``MUestimates_all_locations_1.xlsx`` workbook from Prem et al., 2017. :param country: Country name as listed in the workbook (case-insensitive). :param contact_path: Optional path to ``MUestimates_all_locations_1.xlsx``. + :param reduce_to_rki_groups: If True, aggregate to the 6 RKI age groups + (0-4, 5-14, 15-34, 35-59, 60-79, 80-99). Default True. :returns: DataFrame indexed by age group with floats. """ xls_bytes = _load_workbook_bytes(contact_path) @@ -152,4 +160,38 @@ def load_contact_matrix( raise ValueError( f"Contact matrix for '{country}' is not square: {matrix.shape}") + if reduce_to_rki_groups: + matrix = _aggregate_to_rki_age_groups(matrix) + return matrix + + +def _aggregate_to_rki_age_groups(matrix: pd.DataFrame): + """ + Aggregate a 16x16 age contact matrix to the 6-group RKI scheme. + + Assumes the original columns/rows follow AGE_GROUP_LABELS order. + Note: The source only provides a 75+ group; we map it entirely to 80-99. + """ + if matrix.shape != (len(AGE_GROUP_LABELS), len(AGE_GROUP_LABELS)): + raise ValueError( + f"Expected a {len(AGE_GROUP_LABELS)}x{len(AGE_GROUP_LABELS)} matrix for aggregation.") + + groups = [ + [0], # 0-4 + [1, 2], # 5-14 + [3, 4, 5, 6], # 15-34 + [7, 8, 9, 10, 11], # 35-59 + [12, 13, 14], # 60-79 + [15], # 80-99 (source has 75+ only) + ] + + aggregated = pd.DataFrame( + index=AGE_GROUP_LABELS_RKI, columns=AGE_GROUP_LABELS_RKI, dtype=float) + + for i, rows in enumerate(groups): + for j, cols in enumerate(groups): + block = matrix.values[np.ix_(rows, cols)] + aggregated.iat[i, j] = float(block.mean()) + + return aggregated diff --git a/pycode/memilio-epidata/memilio/epidata_test/test_epidata_get_contact_data.py b/pycode/memilio-epidata/memilio/epidata_test/test_epidata_get_contact_data.py index c439455b94..05ae728ff4 100644 --- a/pycode/memilio-epidata/memilio/epidata_test/test_epidata_get_contact_data.py +++ b/pycode/memilio-epidata/memilio/epidata_test/test_epidata_get_contact_data.py @@ -28,6 +28,7 @@ import pandas as pd from memilio.epidata.getContactData import (AGE_GROUP_LABELS, + AGE_GROUP_LABELS_RKI, CONTACT_WORKBOOK_NAME, list_available_contact_countries, load_contact_matrix) @@ -47,8 +48,8 @@ def tearDown(self): except FileNotFoundError: pass - def _create_workbook(self, sheets: dict) -> str: - """Create a temporary xlsx file with provided sheet_name -> DataFrame.""" + def _create_workbook(self, sheets: dict): + """Create a temporary xlsx file with provided sheet_name.""" fd, path = tempfile.mkstemp(suffix=".xlsx") os.close(fd) with pd.ExcelWriter(path, engine="openpyxl") as writer: @@ -57,7 +58,7 @@ def _create_workbook(self, sheets: dict) -> str: self._temp_files.append(path) return path - def _create_zip_with_workbook(self, sheets: dict) -> bytes: + def _create_zip_with_workbook(self, sheets: dict): """Create a ZIP in memory that contains MUestimates_all_locations_1.xlsx.""" with io.BytesIO() as buf_zip: with zipfile.ZipFile(buf_zip, mode="w") as zf: @@ -69,7 +70,7 @@ def _create_zip_with_workbook(self, sheets: dict) -> bytes: zf.writestr(CONTACT_WORKBOOK_NAME, buf_xlsx.getvalue()) return buf_zip.getvalue() - def test_list_available_contact_countries_uses_explicit_path(self): + def test_list_available_explicit_path(self): """Ensure listing works with a custom workbook path.""" data = pd.DataFrame(np.ones((16, 16))) contact_path = self._create_workbook({"Germany": data, "Spain": data}) @@ -78,8 +79,8 @@ def test_list_available_contact_countries_uses_explicit_path(self): self.assertEqual(set(countries), {"Germany", "Spain"}) @patch('memilio.epidata.getContactData.requests.get') - def test_list_available_contact_countries_downloads_zip(self, mock_get): - """Listing without a path downloads the ZIP and reads workbook bytes.""" + def test_list_available_downloads_zip(self, mock_get): + """Listing without a path downloads the ZIP and reads workbook.""" data = pd.DataFrame(np.ones((16, 16))) zip_bytes = self._create_zip_with_workbook( {"Germany": data, "Spain": data}) @@ -93,13 +94,15 @@ def test_list_available_contact_countries_downloads_zip(self, mock_get): mock_resp.raise_for_status.assert_called_once() mock_get.assert_called_once() - def test_load_contact_matrix_reads_numeric_and_labels(self): - """Contact matrix is loaded from the given path with expected shape/labels.""" + def test_load_contact_matrix(self): + """Contact matrix loads with full 16x16 matrix.""" matrix_values = np.arange(16 * 16).reshape(16, 16) df = pd.DataFrame(matrix_values) contact_path = self._create_workbook({"Germany": df}) - matrix = load_contact_matrix("Germany", contact_path=contact_path) + matrix = load_contact_matrix( + "Germany", contact_path=contact_path, + reduce_to_rki_groups=False) self.assertEqual(matrix.shape, (16, 16)) self.assertEqual(list(matrix.columns), AGE_GROUP_LABELS) @@ -107,8 +110,27 @@ def test_load_contact_matrix_reads_numeric_and_labels(self): self.assertEqual(matrix.iloc[0, 0], 0) self.assertEqual(matrix.iloc[-1, -1], 255) + def test_load_contact_matrix_rki_groups(self): + """Default is to aggregate to the 6-group RKI age groups.""" + matrix_values = np.arange(16 * 16).reshape(16, 16) + df = pd.DataFrame(matrix_values) + contact_path = self._create_workbook({"Germany": df}) + + matrix = load_contact_matrix("Germany", contact_path=contact_path) + + self.assertEqual(matrix.shape, (6, 6)) + self.assertEqual(list(matrix.columns), AGE_GROUP_LABELS_RKI) + self.assertEqual(list(matrix.index), AGE_GROUP_LABELS_RKI) + # spot-check a few block means + self.assertEqual(matrix.iloc[0, 0], 0) # block with only (0,0) + # rows 1,2 and col 0 -> values 16 and 32 => mean 24 + self.assertEqual(matrix.iloc[1, 0], 24) + # rows 3,4,5,6 and cols 3,4,5,6 (15-34 vs 15-34) + sub = matrix_values[3:7, 3:7].mean() + self.assertEqual(matrix.iloc[2, 2], sub) + @patch('memilio.epidata.getContactData.requests.get') - def test_load_contact_matrix_downloads_when_no_path(self, mock_get): + def test_load_contact_matrix_downloads_no_path(self, mock_get): """When no path is given the workbook is downloaded from the DOI ZIP.""" matrix_values = np.arange(16 * 16).reshape(16, 16) df = pd.DataFrame(matrix_values) @@ -119,11 +141,11 @@ def test_load_contact_matrix_downloads_when_no_path(self, mock_get): mock_get.return_value = mock_resp matrix = load_contact_matrix("Germany", contact_path=None) - self.assertEqual(matrix.shape, (16, 16)) - self.assertEqual(list(matrix.columns), AGE_GROUP_LABELS) - self.assertEqual(list(matrix.index), AGE_GROUP_LABELS) + self.assertEqual(matrix.shape, (6, 6)) + self.assertEqual(list(matrix.columns), AGE_GROUP_LABELS_RKI) + self.assertEqual(list(matrix.index), AGE_GROUP_LABELS_RKI) self.assertEqual(matrix.iloc[0, 0], 0) - self.assertEqual(matrix.iloc[-1, -1], 255) + self.assertEqual(matrix.iloc[1, 0], 24) mock_resp.raise_for_status.assert_called_once() mock_get.assert_called_once() From fff87607aee01f374d7bf1f7368f69d656fdcf11 Mon Sep 17 00:00:00 2001 From: HenrZu <69154294+HenrZu@users.noreply.github.com> Date: Wed, 3 Dec 2025 11:11:54 +0100 Subject: [PATCH 03/11] [ci skip] update doc --- pycode/examples/simulation/ode_seir_contact_matrix_example.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pycode/examples/simulation/ode_seir_contact_matrix_example.py b/pycode/examples/simulation/ode_seir_contact_matrix_example.py index 65253866fa..5f848b66a6 100644 --- a/pycode/examples/simulation/ode_seir_contact_matrix_example.py +++ b/pycode/examples/simulation/ode_seir_contact_matrix_example.py @@ -150,7 +150,7 @@ def simulate_country_seir( interpolate: bool = True): """ Load contact matrix, fetch population, build the ODE SEIR model, and run - a simulation. Returns (result, contacts_df, population_int). + a simulation. """ contacts = load_contact_matrix(country) population = get_population_total(country) @@ -189,5 +189,5 @@ def run_demo(country: str, if __name__ == "__main__": - country = "China" + country = "Germany" run_demo(country=country) From d8faf08294972e60ef14f21c585088667952a87a Mon Sep 17 00:00:00 2001 From: HenrZu <69154294+HenrZu@users.noreply.github.com> Date: Wed, 17 Dec 2025 10:57:28 +0100 Subject: [PATCH 04/11] age resolved pop data --- .../ode_seir_contact_matrix_example.py | 138 ++++++++++-------- 1 file changed, 75 insertions(+), 63 deletions(-) diff --git a/pycode/examples/simulation/ode_seir_contact_matrix_example.py b/pycode/examples/simulation/ode_seir_contact_matrix_example.py index 5f848b66a6..85c5f95997 100644 --- a/pycode/examples/simulation/ode_seir_contact_matrix_example.py +++ b/pycode/examples/simulation/ode_seir_contact_matrix_example.py @@ -35,73 +35,75 @@ from memilio.simulation.oseir import (Model, interpolate_simulation_result, simulate) -OWID_POPULATION_URL = ( - "https://raw.githubusercontent.com/owid/covid-19-data/master/" - "scripts/input/un/population_latest.csv" +POPULATION_URL = ( + "https://raw.githubusercontent.com/kieshaprem/synthetic-contact-matrices/" + "master/generate_synthetic_matrices/input/pop/popage_total2020.csv" ) -def _normalize_country_name(country: str): - """Return a case-insensitive key without whitespace or punctuation.""" - return "".join(ch for ch in country.casefold() if ch.isalnum()) - - -def _download_population_table(url: str = OWID_POPULATION_URL): - response = requests.get(url, timeout=10) - response.raise_for_status() - return pd.read_csv(io.StringIO(response.text)) - - -def _pick_population_from_table(df: pd.DataFrame, - country: str): - # Identify candidate population columns provided by the OWID/UN table. - pop_cols = [c - for c - in ["population", "Population", "pop", "PopTotal", "Value"] - if c in df.columns] - if not pop_cols: - return None - - # Try multiple possible country-name columns to find the first match. - country_cols = [ - c - for c - in - ["entity", "Entity", "location", "Location", "country", "Country", - "country_name", "Country Name"] if c in df.columns] - key = _normalize_country_name(country) - for c_col in country_cols: - normalized = df[c_col].astype(str).map(_normalize_country_name) - matches = df.loc[normalized == key] - if matches.empty: - continue - # If multiple years are present, prefer the latest year. - matches = matches.sort_values("year", ascending=False) - pop = matches.iloc[0][pop_cols[0]] - if pd.notna(pop): - return int(pop) - return None - - -def get_population_total(country: str): +def get_population_by_age(country_name: str): """ - get a countrys total population using online data. + Loads population data for a specific country from a GitHub source (based on UN data). + Returns the population in 5-year steps (Unit: number of people). + + Source: https://github.com/kieshaprem/synthetic-contact-matrices + Note: Data is originally in thousands, converted here to absolute numbers. """ + try: - df_pop = _download_population_table() - pop_online = _pick_population_from_table( - df_pop, country) - if pop_online: - return pop_online - except Exception: - pass - raise RuntimeError( - "No population found. Provide population_override or extend population_fallback.") + # download the file and save only in variable + response = requests.get(POPULATION_URL, timeout=10) + response.raise_for_status() + df = pd.read_csv(io.StringIO(response.text)) + + # clean column names + df.columns = df.columns.str.strip() + + # Filter by country (case-insensitive) + country_col = "Region, subregion, country or area *" + row = df[df[country_col].str.lower() == country_name.lower()] + + if row.empty: + raise ValueError(f"Country '{country_name}' not found in data.") + + # Extract relevant columns + age_cols = [c for c in df.columns if c.startswith('age')] + pop_data = row.iloc[0][age_cols].astype(float) + + # Convert from thousands to absolute numbers + pop_data = pop_data * 1000 + + # --- Aggregation for Memilio / Prem Matrices (usually up to 75+) --- + # Prem matrices in memilio often end at 75+. + # The CSV goes up to 100. We sum everything from 75 upwards. + + # Columns up to 70 (0-4, ..., 70-74) + cols_up_to_70 = [f'age{i}' for i in range(0, 75, 5)] + + # Columns from 75 (75-79, ..., 100+) + cols_75_plus = [ + f'age{i}' for i in range(75, 105, 5) + if f'age{i}' in pop_data.index] + + final_pop = [] + + # 1. Keep groups up to 70-74 + for col in cols_up_to_70: + final_pop.append(pop_data[col]) + + # 2. Sum groups from 75+ + sum_75_plus = pop_data[cols_75_plus].sum() + final_pop.append(sum_75_plus) + + return final_pop + + except Exception as e: + raise RuntimeError(f"Error loading population data: {e}") def build_country_seir_model( contact_matrix: np.ndarray, - population_total: int, + population_by_age: list, transmission_probability: float = 0.06, exposed_share: float = 1e-5, infected_share: float = 5e-6): @@ -116,9 +118,19 @@ def build_country_seir_model( contacts.minimum = np.zeros_like(contact_matrix) model.parameters.ContactPatterns.cont_freq_mat[0] = contacts - # Distribute population and initial exposed/infected uniformly across age groups. - group_weights = np.full(num_groups, 1.0 / num_groups) - group_pop = group_weights * population_total + # Use actual population distribution + group_pop = np.array(population_by_age) + + if len(group_pop) != num_groups: + # If dimensions mismatch (e.g. contact matrix has different age groups), + # we might need to adjust. For now, we assume they match (16 groups for Prem 75+). + # If not, we fallback to uniform distribution of total sum (not ideal but safe). + print( + f"Warning: Population groups ({len(group_pop)}) do not match contact matrix groups ({num_groups}). Using uniform distribution.") + total_pop = np.sum(group_pop) + group_weights = np.full(num_groups, 1.0 / num_groups) + group_pop = group_weights * total_pop + exposed_init = np.maximum(1.0, group_pop * exposed_share) infected_init = np.maximum(1.0, group_pop * infected_share) @@ -152,11 +164,11 @@ def simulate_country_seir( Load contact matrix, fetch population, build the ODE SEIR model, and run a simulation. """ - contacts = load_contact_matrix(country) - population = get_population_total(country) + contacts = load_contact_matrix(country, reduce_to_rki_groups=False) + population = get_population_by_age(country) model = build_country_seir_model( contacts.values, - population_total=population, + population_by_age=population, transmission_probability=transmission_probability, exposed_share=exposed_share, infected_share=infected_share) From e3c700966d85025c7a7f79a90cd5cdcf3c97ffb1 Mon Sep 17 00:00:00 2001 From: HenrZu <69154294+HenrZu@users.noreply.github.com> Date: Wed, 17 Dec 2025 11:08:16 +0100 Subject: [PATCH 05/11] now working --- .../ode_seir_contact_matrix_example.py | 40 +++++++++++-------- 1 file changed, 23 insertions(+), 17 deletions(-) diff --git a/pycode/examples/simulation/ode_seir_contact_matrix_example.py b/pycode/examples/simulation/ode_seir_contact_matrix_example.py index 85c5f95997..68f003e187 100644 --- a/pycode/examples/simulation/ode_seir_contact_matrix_example.py +++ b/pycode/examples/simulation/ode_seir_contact_matrix_example.py @@ -51,12 +51,14 @@ def get_population_by_age(country_name: str): """ try: - # download the file and save only in variable + # Download without saving to disk response = requests.get(POPULATION_URL, timeout=10) response.raise_for_status() + + # Read into Pandas df = pd.read_csv(io.StringIO(response.text)) - # clean column names + # Clean column names df.columns = df.columns.str.strip() # Filter by country (case-insensitive) @@ -66,16 +68,17 @@ def get_population_by_age(country_name: str): if row.empty: raise ValueError(f"Country '{country_name}' not found in data.") - # Extract relevant columns + # Extract relevant columns (age0, age5, ..., age100) age_cols = [c for c in df.columns if c.startswith('age')] + + # Extract data pop_data = row.iloc[0][age_cols].astype(float) # Convert from thousands to absolute numbers pop_data = pop_data * 1000 - # --- Aggregation for Memilio / Prem Matrices (usually up to 75+) --- - # Prem matrices in memilio often end at 75+. - # The CSV goes up to 100. We sum everything from 75 upwards. + # Contact matrices end at 75+. + # The CSV goes up to 100. So, we sum everything from 75 upwards. # Columns up to 70 (0-4, ..., 70-74) cols_up_to_70 = [f'age{i}' for i in range(0, 75, 5)] @@ -85,15 +88,17 @@ def get_population_by_age(country_name: str): f'age{i}' for i in range(75, 105, 5) if f'age{i}' in pop_data.index] - final_pop = [] - - # 1. Keep groups up to 70-74 - for col in cols_up_to_70: - final_pop.append(pop_data[col]) + # Define labels for the output dict + labels = [ + "0-4", "5-9", "10-14", "15-19", "20-24", "25-29", "30-34", "35-39", + "40-44", "45-49", "50-54", "55-59", "60-64", "65-69", "70-74", + "75+"] - # 2. Sum groups from 75+ + final_pop = {} + for i, col in enumerate(cols_up_to_70): + final_pop[labels[i]] = pop_data[col] sum_75_plus = pop_data[cols_75_plus].sum() - final_pop.append(sum_75_plus) + final_pop["75+"] = sum_75_plus return final_pop @@ -119,12 +124,13 @@ def build_country_seir_model( model.parameters.ContactPatterns.cont_freq_mat[0] = contacts # Use actual population distribution - group_pop = np.array(population_by_age) + # transform dict to numpy array + group_pop = np.array(list(population_by_age.values())).flatten() if len(group_pop) != num_groups: # If dimensions mismatch (e.g. contact matrix has different age groups), # we might need to adjust. For now, we assume they match (16 groups for Prem 75+). - # If not, we fallback to uniform distribution of total sum (not ideal but safe). + # If not, we fallback to uniform distribution of total sum. print( f"Warning: Population groups ({len(group_pop)}) do not match contact matrix groups ({num_groups}). Using uniform distribution.") total_pop = np.sum(group_pop) @@ -161,7 +167,7 @@ def simulate_country_seir( infected_share: float = 5e-6, interpolate: bool = True): """ - Load contact matrix, fetch population, build the ODE SEIR model, and run + Load contact matrix, population data, build the ODE SEIR model, and run a simulation. """ contacts = load_contact_matrix(country, reduce_to_rki_groups=False) @@ -201,5 +207,5 @@ def run_demo(country: str, if __name__ == "__main__": - country = "Germany" + country = "India" run_demo(country=country) From b5ed737ccfd9c9477a8b5e5f5a9907f899d7272a Mon Sep 17 00:00:00 2001 From: HenrZu <69154294+HenrZu@users.noreply.github.com> Date: Wed, 17 Dec 2025 11:26:21 +0100 Subject: [PATCH 06/11] add plots --- .../ode_seir_contact_matrix_example.py | 60 +++++++++++++++++-- 1 file changed, 54 insertions(+), 6 deletions(-) diff --git a/pycode/examples/simulation/ode_seir_contact_matrix_example.py b/pycode/examples/simulation/ode_seir_contact_matrix_example.py index 68f003e187..74f60a81b9 100644 --- a/pycode/examples/simulation/ode_seir_contact_matrix_example.py +++ b/pycode/examples/simulation/ode_seir_contact_matrix_example.py @@ -25,12 +25,13 @@ import io +import matplotlib.pyplot as plt import numpy as np import pandas as pd import requests from memilio.epidata.getContactData import (load_contact_matrix) -from memilio.simulation import AgeGroup, ContactMatrix +from memilio.simulation import AgeGroup, ContactMatrix, Damping from memilio.simulation.oseir import InfectionState as State from memilio.simulation.oseir import (Model, interpolate_simulation_result, simulate) @@ -43,7 +44,7 @@ def get_population_by_age(country_name: str): """ - Loads population data for a specific country from a GitHub source (based on UN data). + Loads population data for a specific country from a POPULATION_URL. Returns the population in 5-year steps (Unit: number of people). Source: https://github.com/kieshaprem/synthetic-contact-matrices @@ -123,6 +124,10 @@ def build_country_seir_model( contacts.minimum = np.zeros_like(contact_matrix) model.parameters.ContactPatterns.cont_freq_mat[0] = contacts + # 60% contact reduction at t=30 + model.parameters.ContactPatterns.cont_freq_mat.add_damping(Damping( + coeffs=np.ones((num_groups, num_groups)) * 0.6, t=30.0, level=0, type=0)) + # Use actual population distribution # transform dict to numpy array group_pop = np.array(list(population_by_age.values())).flatten() @@ -162,7 +167,7 @@ def simulate_country_seir( country: str, days: float = 120.0, dt: float = 0.25, - transmission_probability: float = 0.06, + transmission_probability: float = 0.15, exposed_share: float = 1e-5, infected_share: float = 5e-6, interpolate: bool = True): @@ -186,12 +191,50 @@ def simulate_country_seir( return result +def plot_results(result, country: str): + """ + Plot aggregated SEIR compartments over time. + """ + results_arr = result.as_ndarray() + times = results_arr[0, :] + + num_groups = result.get_num_elements() // 4 # 4 compartments to aggregate + + susceptible = np.zeros(len(times)) + exposed = np.zeros(len(times)) + infected = np.zeros(len(times)) + recovered = np.zeros(len(times)) + + for age_group in range(num_groups): + susceptible += results_arr[1 + age_group * 4, :] + exposed += results_arr[2 + age_group * 4, :] + infected += results_arr[3 + age_group * 4, :] + recovered += results_arr[4 + age_group * 4, :] + + # Create plot + plt.figure(figsize=(10, 6)) + plt.plot(times, susceptible, label='Susceptible', linewidth=2) + plt.plot(times, exposed, label='Exposed', linewidth=2) + plt.plot(times, infected, label='Infected', linewidth=2) + plt.plot(times, recovered, label='Recovered', linewidth=2) + + plt.xlabel('Time (days)', fontsize=12) + plt.ylabel('Population', fontsize=12) + plt.title(f'SEIR Model Simulation - {country}', + fontsize=14, fontweight='bold') + plt.legend(fontsize=11) + plt.grid(True, alpha=0.3) + plt.tight_layout() + plt.show() + + def run_demo(country: str, days: float = 120.0, dt: float = 0.25, - transmission_probability: float = 0.06, + transmission_probability: float = 0.15, exposed_share: float = 1e-5, - infected_share: float = 5e-6): + infected_share: float = 5e-6, + plot: bool = True): """ Run the SEIR simulation demo for a user defined country and parameters. """ @@ -205,7 +248,12 @@ def run_demo(country: str, ) print(result.get_last_value()) + if plot: + plot_results(result, country) + + return result + if __name__ == "__main__": - country = "India" + country = "Germany" run_demo(country=country) From c9912e3e2ff5b0472f853d5e22d54437a9420a12 Mon Sep 17 00:00:00 2001 From: HenrZu <69154294+HenrZu@users.noreply.github.com> Date: Wed, 17 Dec 2025 12:26:49 +0100 Subject: [PATCH 07/11] check if country available --- .../simulation/ode_seir_contact_matrix_example.py | 10 +++++++++- .../memilio-epidata/memilio/epidata/getContactData.py | 5 +++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/pycode/examples/simulation/ode_seir_contact_matrix_example.py b/pycode/examples/simulation/ode_seir_contact_matrix_example.py index 74f60a81b9..4625e125da 100644 --- a/pycode/examples/simulation/ode_seir_contact_matrix_example.py +++ b/pycode/examples/simulation/ode_seir_contact_matrix_example.py @@ -30,7 +30,8 @@ import pandas as pd import requests -from memilio.epidata.getContactData import (load_contact_matrix) +from memilio.epidata.getContactData import (get_available_countries, + load_contact_matrix) from memilio.simulation import AgeGroup, ContactMatrix, Damping from memilio.simulation.oseir import InfectionState as State from memilio.simulation.oseir import (Model, interpolate_simulation_result, @@ -175,6 +176,13 @@ def simulate_country_seir( Load contact matrix, population data, build the ODE SEIR model, and run a simulation. """ + # Validate country + available = get_available_countries() + if country not in available: + raise ValueError( + f"Country '{country}' not available. " + f"Use get_available_countries() to see all {len(available)} supported countries.") + contacts = load_contact_matrix(country, reduce_to_rki_groups=False) population = get_population_by_age(country) model = build_country_seir_model( diff --git a/pycode/memilio-epidata/memilio/epidata/getContactData.py b/pycode/memilio-epidata/memilio/epidata/getContactData.py index ac749b81a6..41fa24645f 100644 --- a/pycode/memilio-epidata/memilio/epidata/getContactData.py +++ b/pycode/memilio-epidata/memilio/epidata/getContactData.py @@ -114,6 +114,11 @@ def list_available_contact_countries( return xls.sheet_names +def get_available_countries(contact_path: Optional[str] = None): + """Get list of all available countries.""" + return list_available_contact_countries(contact_path) + + def _select_sheet_name(country: str, sheet_names: Iterable[str]): lookup = {_normalize_country_name(name): name for name in sheet_names} key = _normalize_country_name(country) From 360b21c4fb71195df2296ffaa4d7662cde1e0f1d Mon Sep 17 00:00:00 2001 From: HenrZu <69154294+HenrZu@users.noreply.github.com> Date: Wed, 17 Dec 2025 14:41:14 +0100 Subject: [PATCH 08/11] [ci skip] review comment --- pycode/examples/simulation/ode_seir_contact_matrix_example.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pycode/examples/simulation/ode_seir_contact_matrix_example.py b/pycode/examples/simulation/ode_seir_contact_matrix_example.py index 4625e125da..5576980320 100644 --- a/pycode/examples/simulation/ode_seir_contact_matrix_example.py +++ b/pycode/examples/simulation/ode_seir_contact_matrix_example.py @@ -206,7 +206,9 @@ def plot_results(result, country: str): results_arr = result.as_ndarray() times = results_arr[0, :] - num_groups = result.get_num_elements() // 4 # 4 compartments to aggregate + num_compartments = 4 # S, E, I, R + # get_num_elements includes all compartments and age groups. Divide to get age groups. + num_groups = result.get_num_elements() // num_compartments susceptible = np.zeros(len(times)) exposed = np.zeros(len(times)) From 0c3d314b9375d7b263e951c4c769988c81621bb9 Mon Sep 17 00:00:00 2001 From: Henrik Zunker <69154294+HenrZu@users.noreply.github.com> Date: Thu, 18 Dec 2025 12:35:23 +0100 Subject: [PATCH 09/11] [ci skip] Update pycode/memilio-epidata/memilio/epidata/getContactData.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Martin J. Kühn <62713180+mknaranja@users.noreply.github.com> --- pycode/memilio-epidata/memilio/epidata/getContactData.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pycode/memilio-epidata/memilio/epidata/getContactData.py b/pycode/memilio-epidata/memilio/epidata/getContactData.py index 41fa24645f..ee4de946e9 100644 --- a/pycode/memilio-epidata/memilio/epidata/getContactData.py +++ b/pycode/memilio-epidata/memilio/epidata/getContactData.py @@ -141,8 +141,8 @@ def load_contact_matrix( :param country: Country name as listed in the workbook (case-insensitive). :param contact_path: Optional path to ``MUestimates_all_locations_1.xlsx``. - :param reduce_to_rki_groups: If True, aggregate to the 6 RKI age groups - (0-4, 5-14, 15-34, 35-59, 60-79, 80-99). Default True. + :param reduce_to_rki_groups: If True, aggregate to the six RKI age groups + (0-4, 5-14, 15-34, 35-59, 60-79, 80+ years). Default True. :returns: DataFrame indexed by age group with floats. """ xls_bytes = _load_workbook_bytes(contact_path) From b17c3fcde60ae4cb836dca8db7414083444ec70a Mon Sep 17 00:00:00 2001 From: HenrZu <69154294+HenrZu@users.noreply.github.com> Date: Thu, 18 Dec 2025 13:21:41 +0100 Subject: [PATCH 10/11] [ci skip] extend doc in getContactData --- .../memilio/epidata/getContactData.py | 48 +++++++++++++++++-- 1 file changed, 43 insertions(+), 5 deletions(-) diff --git a/pycode/memilio-epidata/memilio/epidata/getContactData.py b/pycode/memilio-epidata/memilio/epidata/getContactData.py index ee4de946e9..05ec70a56e 100644 --- a/pycode/memilio-epidata/memilio/epidata/getContactData.py +++ b/pycode/memilio-epidata/memilio/epidata/getContactData.py @@ -73,13 +73,24 @@ def _normalize_country_name(country: str): - """Return a case-insensitive key without whitespace or punctuation.""" + """ + Return a case-insensitive key without whitespace or punctuation. + + :param country: The country name to normalize. + :returns: Normalized country name. + """ return "".join(ch for ch in country.casefold() if ch.isalnum()) def _download_contact_workbook( url: str = CONTACT_ZIP_URL, target_filename: str = CONTACT_WORKBOOK_NAME): - """Download the ZIP from the DOI and return the workbook.""" + """ + Download the ZIP from the url and return the workbook. + + :param url: URL to download the ZIP from. + :param target_filename: Name of the workbook file within the ZIP. + :returns: Content of the workbook. + """ response = requests.get(url, timeout=30) response.raise_for_status() with zipfile.ZipFile(io.BytesIO(response.content)) as zf: @@ -96,7 +107,14 @@ def _load_workbook_bytes( contact_path: Optional[str], url: str = CONTACT_ZIP_URL, target_filename: str = CONTACT_WORKBOOK_NAME): - """Return workbook either from a user path or by downloading the ZIP.""" + """ + Return workbook either from a user path or by downloading the ZIP. + + :param contact_path: Optional local path to the workbook. + :param url: Url to download the ZIP from if no path is provided. + :param target_filename: Name of the workbook file within the ZIP. + :returns: Content of the workbook. + """ if contact_path: if not os.path.exists(contact_path): raise FileNotFoundError( @@ -108,18 +126,35 @@ def _load_workbook_bytes( def list_available_contact_countries( contact_path: Optional[str] = None): - """List all country names available in the contact matrix workbook.""" + """ + List all country names available in the contact matrix workbook. + + :param contact_path: Optional local path to the workbook. + :returns: List of all country names. + """ xls_bytes = _load_workbook_bytes(contact_path) xls = pd.ExcelFile(io.BytesIO(xls_bytes)) return xls.sheet_names def get_available_countries(contact_path: Optional[str] = None): - """Get list of all available countries.""" + """ + Get list of all available countries. + + :param contact_path: Optional local path to the workbook. + :returns: List of all available countries. + """ return list_available_contact_countries(contact_path) def _select_sheet_name(country: str, sheet_names: Iterable[str]): + """ + Select the appropriate sheet name from the workbook for a given country. + + :param country: Country name as listed in the workbook (case-insensitive). + :param sheet_names: Iterable of available sheet names. + :returns: The exact sheet name as found in the workbook. + """ lookup = {_normalize_country_name(name): name for name in sheet_names} key = _normalize_country_name(country) if key not in lookup: @@ -177,6 +212,9 @@ def _aggregate_to_rki_age_groups(matrix: pd.DataFrame): Assumes the original columns/rows follow AGE_GROUP_LABELS order. Note: The source only provides a 75+ group; we map it entirely to 80-99. + + :param matrix: The 16x16 contact matrix to aggregate. + :returns: Aggregated 6x6 contact matrix. """ if matrix.shape != (len(AGE_GROUP_LABELS), len(AGE_GROUP_LABELS)): raise ValueError( From 71954a6f3fe83b7c92cf2274a419ded41ec1950f Mon Sep 17 00:00:00 2001 From: HenrZu <69154294+HenrZu@users.noreply.github.com> Date: Thu, 18 Dec 2025 13:56:56 +0100 Subject: [PATCH 11/11] [ci skip] comment extension 60-74 --- pycode/memilio-epidata/memilio/epidata/getContactData.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pycode/memilio-epidata/memilio/epidata/getContactData.py b/pycode/memilio-epidata/memilio/epidata/getContactData.py index 05ec70a56e..717a3ef1be 100644 --- a/pycode/memilio-epidata/memilio/epidata/getContactData.py +++ b/pycode/memilio-epidata/memilio/epidata/getContactData.py @@ -211,7 +211,8 @@ def _aggregate_to_rki_age_groups(matrix: pd.DataFrame): Aggregate a 16x16 age contact matrix to the 6-group RKI scheme. Assumes the original columns/rows follow AGE_GROUP_LABELS order. - Note: The source only provides a 75+ group; we map it entirely to 80-99. + Note: The source only provides data up to 70-74 and a 75+ group. + We map 60-74 to the 60-79 RKI group and 75+ to the 80-99 RKI group. :param matrix: The 16x16 contact matrix to aggregate. :returns: Aggregated 6x6 contact matrix.