Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 38 additions & 5 deletions src/tlo/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import json
from enum import Enum, auto
from typing import TYPE_CHECKING, Any, Dict, FrozenSet, List, Optional
from typing import TYPE_CHECKING, Any, Dict, FrozenSet, List

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -120,6 +120,14 @@ def __repr__(self) -> str:

class Parameter(Specifiable):
    """Used to specify parameters for disease modules etc.

    In addition to the type, description and categories handled by
    ``Specifiable``, each parameter carries a free-form ``metadata`` dict.
    ``Module.load_parameters_from_dataframe`` writes calibration keys into it
    (``param_label``, ``prior_min``, ``prior_max``).
    """

    def __init__(self,
                 type_: Types,
                 description: str,
                 categories: List[str] = None,
                 *,
                 metadata: Dict[str, Any] = None):
        """Create a parameter specification.

        :param type_: the ``Types`` member describing the value's data type
        :param description: human-readable description of the parameter
        :param categories: allowed category labels (only for categorical types)
        :param metadata: optional initial metadata mapping; keyword-only.
            NOTE(review): the annotation deliberately avoids ``Optional`` —
            this diff removed ``Optional`` from the ``typing`` import, so
            referencing it here would raise ``NameError`` at import time
            (unless the file uses ``from __future__ import annotations`` —
            not visible from this hunk, confirm before re-adding).
        """
        super().__init__(type_, description, categories)
        # Explicit None-check (not `metadata or {}`) so a caller-supplied
        # dict object is kept even when it is currently empty, and each
        # instance gets its own fresh dict by default.
        self.metadata = metadata if metadata is not None else {}


class Property(Specifiable):
Expand Down Expand Up @@ -321,27 +329,41 @@ def load_parameters_from_dataframe(self, resource: pd.DataFrame) -> None:

:param DataFrame resource: DataFrame with a column of the parameter_name and a column of `value`
"""

resource.set_index('parameter_name', inplace=True)
skipped_data_types = ('DATA_FRAME', 'SERIES')
acceptable_labels = ['unassigned', 'undetermined', 'universal', 'local', 'scenario']
param_defaults = {'param_label': 'unassigned', 'prior_min': None, 'prior_max': None }

for _col in param_defaults.keys():
if _col not in resource.columns:
resource[_col] = param_defaults[_col]
# for each supported parameter, convert to the correct type
for parameter_name in resource.index[resource.index.notnull()]:
parameter_definition = self.PARAMETERS[parameter_name]

if parameter_definition.type_.name in skipped_data_types:
continue

# For each parameter, raise error if the value can't be coerced
parameter_value = resource.at[parameter_name, 'value']
parameter_value, prior_min, prior_max = resource.loc[parameter_name, ['value', 'prior_min', 'prior_max']]
parameter_label = resource.at[parameter_name, 'param_label']
assert parameter_label in acceptable_labels, f'unrecognised parameter label {parameter_label}'

error_message = (
f"The value of '{parameter_value}' for parameter '{parameter_name}' "
f"could not be parsed as a {parameter_definition.type_.name} data type"
f"some values are not of type {parameter_definition.type_.name} and "
f"could not be parsed as a {parameter_definition.type_.name} data type. "
f"parameter name is {parameter_name}, values {[parameter_value, prior_min, prior_max]}"
)
if parameter_definition.python_type is list:
try:
# json.loads chosen instead of safe_eval (sic: original said "save_eval")
# because it raises an error instead of silently joining two strings that lack a comma
parameter_value = json.loads(parameter_value)
assert isinstance(parameter_value, list)
if pd.notnull(prior_min):
assert isinstance(json.loads(prior_min), list)
if pd.notnull(prior_max):
assert isinstance(json.loads(prior_max), list)
except (json.decoder.JSONDecodeError, TypeError, AssertionError) as exception:
raise ValueError(error_message) from exception
elif parameter_definition.python_type == pd.Categorical:
Expand All @@ -358,11 +380,22 @@ def load_parameters_from_dataframe(self, resource: pd.DataFrame) -> None:
# All other data types, assign to the python_type defined in Parameter class
try:
parameter_value = parameter_definition.python_type(parameter_value)
if not isinstance(parameter_definition.python_type, pd.Timestamp):
if pd.notnull(prior_min):
parameter_definition.python_type(prior_min)
if pd.notnull(prior_max):
parameter_definition.python_type(prior_max)
except Exception as exception:
raise ValueError(error_message) from exception

# Save the values to the parameters
self.parameters[parameter_name] = parameter_value
# Assign metadata to the Parameter object
parameter_definition.metadata.update(
param_label=parameter_label,
prior_min=prior_min,
prior_max=prior_max
)

def read_parameters(self, data_folder: str | Path) -> None:
"""Read parameter values from file, if required.
Expand Down
12 changes: 10 additions & 2 deletions src/tlo/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import heapq
import itertools
import time
from collections import OrderedDict
from collections import Counter, OrderedDict
from pathlib import Path
from typing import TYPE_CHECKING, Optional

Expand Down Expand Up @@ -116,6 +116,7 @@ def __init__(
self._custom_log_levels = None
self._log_filepath = self._configure_logging(**log_config)


# random number generator
seed_from = "auto" if seed is None else "user"
self._seed = seed
Expand Down Expand Up @@ -307,8 +308,15 @@ def finalise(self, wall_clock_time: Optional[float] = None) -> None:
:param wall_clock_time: Optional argument specifying total time taken to
simulate, to be written out to log before closing.
"""
for module in self.modules.values():
for module_name, module in self.modules.items():
module.on_simulation_end()
if hasattr(module, "PARAMETERS"):
# collect the module's parameter labels
labels = [p.metadata.get("param_label", "not_init_via_load_param") for p in module.PARAMETERS.values()]
labels = Counter(labels)
for label, count in labels.items():
logger.info(key="parameter_stats", data={"module": module_name, "label": label, "count": count})

if wall_clock_time is not None:
logger.info(key="info", data=f"simulate() {wall_clock_time} s")
self.close_output_file()
Expand Down
37 changes: 37 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,43 @@ def test_bools(self):
assert self.module.parameters['bool_true'] is True
assert self.module.parameters['bool_false'] is False

def test_unacceptable_labels(self):
""" label not acceptable for parameter label

should raise an assertion error """
resource = self.resource.copy()
resource['param_label'] = 'free'
with pytest.raises(AssertionError, match="unrecognised parameter label"):
self.module.load_parameters_from_dataframe(resource)

def test_unacceptable_lower_value(self):
""" check unacceptable for lower value

should raise a value error """
resource = self.resource.copy()
resource['prior_min'] = 'a'
with pytest.raises(ValueError):
self.module.load_parameters_from_dataframe(resource)

def test_unacceptable_upper_value(self):
""" check unacceptable for upper value

should raise a value error """
resource = self.resource.copy()
resource['prior_max'] = 'b'
with pytest.raises(ValueError):
self.module.load_parameters_from_dataframe(resource)

def test_list_type_parameter_value_has_list_type_lower_upper_value(self):
""" assign integer and float values to lower and upper values respectively.

should raise a value error for parameter values of type list """
resource = self.resource.copy()
resource['prior_min'] = 1
resource['prior_max'] = 2.0
with pytest.raises(ValueError, match='some values are not of type LIST'):
self.module.load_parameters_from_dataframe(resource)


class TestLoadParametersFromDataframe_Bools_From_Csv:
"""Tests for the load_parameters_from_dataframe method, including handling of bools when loading from csv"""
Expand Down
2 changes: 1 addition & 1 deletion tests/test_simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ def _check_parsed_logs_are_equal(
if key == "_metadata":
assert module_logs_1[key] == module_logs_2[key]
elif (module_name, key) not in module_name_key_pairs_to_skip:
assert module_logs_1[key].equals(module_logs_2[key])
assert module_logs_1[key].equals(module_logs_2[key]), f"{module_name} log {key} not equal"


@pytest.mark.slow
Expand Down