Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .github/workflows/python-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ on:
# no more testing on push from dev, because I'm not going to push to a main or dev without
# an active PR
pull_request:
types: [opened, reopened, synchronize]
types: [opened, reopened, synchronize, ready_for_review]

env:
# How to get the name of the branch (see: https://stackoverflow.com/a/71158878)
Expand All @@ -19,6 +19,7 @@ env:

jobs:
decide-to-test:
if: github.event.pull_request.draft == false
name: "Test decision"
runs-on: ubuntu-latest
outputs:
Expand Down Expand Up @@ -94,6 +95,7 @@ jobs:


test:
if: github.event.pull_request.draft == false
name: 🧪 Unit Tests
needs: decide-to-test

Expand Down Expand Up @@ -162,6 +164,7 @@ jobs:
path: test_results/*.txt

upload-test-results:
if: github.event.pull_request.draft == false
# creates one artifact 'test-results' for the entire test matrix with all files
runs-on: ubuntu-latest
needs: test
Expand Down
95 changes: 61 additions & 34 deletions pymob/solvers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import inspect
from scipy.ndimage import gaussian_filter1d
from diffrax import rectilinear_interpolation
from pymob.utils.errors import PymobError

@dataclass(frozen=True)
class SolverBase:
Expand Down Expand Up @@ -177,44 +178,47 @@ def _get_input_vars_shapes(self, ) -> frozendict[str, frozendict[str, Tuple[int,
})
return frozendict(input_vars_shape_dict)

# def _get_input_vars_shapes(self, ) -> frozendict[str, frozendict[str, Tuple[int, ...]]]:
# frozen_shapes_input_vars = frozendict({
# k_input_var: frozendict({
# k_datavar: tuple([
# len(v_coord)
# for _, v_coord in v_datavar.items()
# ])
# for k_datavar, v_datavar in v_input_var.items()
# })
# for k_input_var, v_input_var in self.coordinates_input_vars.items()
# })

# return frozen_shapes_input_vars

# raise NotImplementedError(
# "Currently, it is expected that x_in is already in the correct "+
# "shape (minus the batch dimension), when passed to the Evaluator. "+
# "This is why no definite shapes can be defined on initialization. "+
# "The only broadcasting that can be sensibly made is to promote the "+
# "input vectors to the batch dimension, if this has not yet been done."
# )

def test_matching_batch_dims(self):
bc = self.coordinates.get(self.batch_dimension, None)

if bc is not None:
matching_batch_coords_if_present = [
v[self.batch_dimension] == bc
for k, v in self.coordinates_input_vars.items()
if self.batch_dimension in v
]

if not all(matching_batch_coords_if_present):
raise IndexError(
f"Batch coordinates '{self.batch_dimension}' of input "+
"variables do not have the same size "+
"as the batch dimension of the observations."
)
for input_key, input_dict in self.coordinates_input_vars.items():
matching_batch_coords_if_present = {}
for k, v in input_dict.items():
if self.batch_dimension in v:
matching_batch_coords_if_present.update({k:
v[self.batch_dimension] == bc
})

if not all(list(matching_batch_coords_if_present.values())):
raise PymobError(
f"The batch coordinates of the '{input_key}' input variable "+
f"'{k}': {dict(v)} differ from the "+
f"batch coordinates of the observations {{'{self.batch_dimension}': {bc}.}} "+
"\n\n" +

"Why does this error occur?\n" +
"--------------------------\n" +
"Pymob internally requires that all model inputs (theta, x_in, y0) to have "+
"equally sized batch dimensions and broadcasts parameters and input "+
"according to the respective batch dimensions, if they are not "+
"provided in the expanded form. "+
"Differing coordinates would lead to inhomogeneous batch dimensions, "+
"which cannot be processed by the solver. "+
f"You have possibly set `sim.model_parameters['{input_key}']` " +
f"with an xarray.Dataset that has different {self.batch_dimension}"+
"-coordinates than `sim.observations`."
"\n\n"

"How to fix this error?\n"+
"----------------------\n"+
"Make sure all model inputs have the same batch coordinates. "+
"Try using the pymob method `SimulationBase.parse_input(...) -> "+
"https://pymob.readthedocs.io/en/stable/api/pymob.html#pymob.simulation.SimulationBase.parse_input "
"to prepare 'x_in' and 'y0'. E.g.: \n"+
"* `sim.model_parameters['y0'] = sim.parse_input('y0', drop_dims=['time'])`\n"+
"* `sim.model_parameters['x_in'] = sim.parse_input('x_in', reference_data=sim.observations)` "
)

def test_x_coordinates(self):
x = self.coordinates[self.x_dim]
Expand Down Expand Up @@ -260,6 +264,29 @@ def preprocess_parameters(self, parameters, num_backend: ModuleType=numpy):

return ode_args_broadcasted, pp_args_broadcasted

def test_batch_dim_consistency(self, X_in, Y_0, ode_args, pp_args, num_backend: ModuleType=numpy):
# This method is currently not used. It may come in handy later on, but it
# would need to be called during Evaluator.__call__(), which creates unnecessary
# overhead calculations, which would slow down pymob. Instead the check is done
# during the dispatch_constructor() call. Specifically in
# SolverBase.test_matching_batch_dims

shapes = {
# xin data shape
"theta": [oa.shape[0] for oa in ode_args] + [pa.shape[0] for pa in pp_args],
"y0": [y0.shape[0] for y0 in Y_0],
"x_in": [xin[0].shape[0] for xin in X_in],
}

_shapes = numpy.concatenate(list(shapes.values()))

if not all(x == _shapes[0] for x in _shapes):
raise PymobError(
f"The sizes of the batch dimensions ('{self.batch_dimension}') of theta, " +
f"x_in and y0 did not match: {shapes}. This problem is often caused if " +
"the components of sim.model_parameters have not been harmonized, "
)


def _broadcast_args(self, arg_dict: frozendict[str, numpy.ndarray], num_backend: ModuleType=numpy):
# simply broadcast the parameters along the batch dimension
Expand Down
11 changes: 11 additions & 0 deletions pymob/utils/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,17 @@
import importlib
import textwrap

class PymobError(Exception):
"""Exception raised for custom error scenarios.

Attributes:
message -- explanation of the error
"""

def __init__(self, message):
self.message = message
super().__init__(self.message)

def errormsg(msg):
return textwrap.fill(textwrap.dedent(msg))

Expand Down