Skip to content
Open
Show file tree
Hide file tree
Changes from 30 commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
ee90231
Setup for future work
merkuns Jul 17, 2025
9ad885c
fixed imports
merkuns Jul 21, 2025
843a619
removed hash function
merkuns Jul 21, 2025
e744a35
updated model formulation (non-negative, Rackauckas) and added a prop…
merkuns Jul 25, 2025
3e586bd
adapted simulation to new model formulation
merkuns Jul 25, 2025
e721f9f
Evaluator now working for UDE system (as specified in the case study …
merkuns Jul 31, 2025
1edd7cc
Got UDESolver to work with external input data
merkuns Aug 1, 2025
f990481
Made infer_ode_states() compatible with UDEs
merkuns Aug 1, 2025
3a49fec
Test for UDESolver
merkuns Aug 7, 2025
f46868e
OptaxBackend + adjacent functions first commit (no multiple runs, no …
merkuns Aug 19, 2025
c701785
New model formulation (superclass UDEBase + subclass Func)
merkuns Aug 19, 2025
85f6fc0
Enabled multiple run optimization, added sim.inferer.run() method and…
merkuns Aug 20, 2025
1e87bd6
Added progress bar, enabled saving of optax config parameters and mad…
merkuns Aug 22, 2025
8a272da
Removed all mentions of args from the UDESolver and added x_in to sta…
merkuns Aug 25, 2025
57df392
Added possibility to split data into training and validation data
merkuns Aug 26, 2025
5feb146
Bugfix for 1D models
merkuns Aug 26, 2025
d1348b4
Moved UDEBase and helper functions to utils and made UDEBase's hash f…
merkuns Aug 26, 2025
212afbc
Made UDE distributions follow SciPy protocol
merkuns Aug 26, 2025
adb0515
Bugfix + more and better warnings + NotImplementedErrors
merkuns Aug 27, 2025
788bd77
New setup for case study
merkuns Aug 28, 2025
8a9e341
Bugfixes + changes to test_solvers to reflect new model formulation
merkuns Aug 29, 2025
aa012e7
Added UDE case study import to unit tests
merkuns Sep 1, 2025
b8e7ba1
Made equinox and optax optional dependencies
merkuns Sep 1, 2025
66bc6e6
UDE solver (slight changes) and inferer test (brand new!)
merkuns Sep 2, 2025
e00dc18
Removed posterior_predictive_checks(), added plot_posterior_predictio…
merkuns Sep 2, 2025
f690510
Bugfix + removed inference_optax.multiple_runs_plot from Config
merkuns Sep 2, 2025
ed6d30d
Made usage of y_mlp in the model formulation redundant
merkuns Sep 2, 2025
25f76af
New inferer using standard evaluator call (experimental) + fixed UDES…
merkuns Sep 3, 2025
5360b89
New testing notebooks
merkuns Sep 4, 2025
5ecd5d2
Added option to turn off clipping + alternative model optimization us…
merkuns Sep 5, 2025
cd3a175
Added provisional implementation of idata and posterior_predictive_ch…
merkuns Sep 16, 2025
04a2af4
New and much nicer model formulation
merkuns Sep 16, 2025
371a132
Some clean-up + preparations for hyperparameter search
merkuns Sep 22, 2025
7c4f692
Made activation functions customizable, removed automatic sorting of …
merkuns Oct 2, 2025
6706819
Bugfix + preparation for hyperparameter search
merkuns Oct 15, 2025
10d036b
Setup for hyperparameter search
merkuns Oct 17, 2025
625d6d7
Fixed memory issues and cleaned up the code a little
merkuns Oct 21, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/python-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ jobs:
python -m pip install flake8 pytest
pip install .[pyabc,pymoo,interactive,numpyro]
pip install -e case_studies/lotka_volterra_case_study
pip install -e case_studies/lotka_volterra_UDE_case_study

- name: Lint with flake8
if: env.full-test == 'true' && needs.decide-to-test.outputs.changes == 'true' && needs.decide-to-test.outputs.tagged_commit == 'false' && github.event_name == 'pull_request'
Expand Down
4 changes: 4 additions & 0 deletions case_studies/lotka_volterra_UDE_case_study/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
__pycache__
results
*.code-workspace
*.egg-info
15 changes: 15 additions & 0 deletions case_studies/lotka_volterra_UDE_case_study/.pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v2.3.0
hooks:
- id: check-yaml
- id: check-toml

- repo: local
hooks:
- id: pytest-check
name: pytest-check
entry: test.sh
language: script
pass_filenames: false
always_run: true
1 change: 1 addition & 0 deletions case_studies/lotka_volterra_UDE_case_study/__init__.py
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This file needs to be deleted

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll be leaving some comments here to preserve my current knowledge for later when we deal with the refactoring.

And you are completely right, this has to be deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
__version__ = "1.0.0"
Binary file not shown.
Binary file not shown.
482 changes: 482 additions & 0 deletions case_studies/lotka_volterra_UDE_case_study/interactive.ipynb

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from . import data
from . import mod
from . import plot
from . import prob
from . import sim

__version__ = "1.0.0"
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might be useful to make this version 0.1.0 (according to semantic versioning, 1.0.0 usually indicates a package that has matured a bit).

Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
import equinox as eqx
import jax.nn as jnn
import jax.numpy as jnp
import jax
from pymob.utils.UDE import UDEBase

class Func(UDEBase):
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add a docstring here


mlp_depth: int = 3
mlp_width: int = 3
mlp_in_size: int = 2
mlp_out_size: int = 2

alpha: jax.Array
delta: jax.Array

def __init__(self, params, weights=None, bias=None, *, key, **kwargs):
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add a docstring here

self.init_MLP(weights, bias, key=key)
self.init_params(params)

def __call__(self, t, y):
"""
Returns the growth rates of predator and prey depending on their current state.

Parameters
----------
t : scalar
Just here to fulfill the requirements by diffeqsolve(). Has no effect and
can be set to None.
y : jax.ArrayImpl
Array containing two values: the current abundance of prey and predator,
respectively.

Returns:
--------
jax.ArrayImpl
An array containing the growth rates of prey and predators, respectively.
"""

params = self.preprocess_params()

prey, predator = y

dprey_dt_ode = params["alpha"] * prey
dpredator_dt_ode = - params["delta"] * predator
dprey_dt_nn, dpredator_dt_nn = self.mlp(y) * jnp.array([jnp.tanh(prey).astype(float), jnp.tanh(predator).astype(float)])

dprey_dt = dprey_dt_ode + dprey_dt_nn
dpredator_dt = dpredator_dt_ode + dpredator_dt_nn

return jnp.array([dprey_dt.astype(float),dpredator_dt.astype(float)])
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have a suggestion for how to refactor this:

    def __call__(self, t, y):
        """
        Returns the growth rates of predator and prey depending on their current state.

        Parameters
        ----------
        t : scalar
            Just here to fulfill the requirements by diffeqsolve(). Has no effect and
            can be set to None.
        y : jax.ArrayImpl
            Array containing two values: the current abundance of prey and predator,
            respectively.

        Returns:
        --------
        jax.ArrayImpl
            An array containing the growth rates of prey and predators, respectively.
        """

        params = self.preprocess_params()

        dprey_dt, dpredator_dt = self.model(t, y, mlp=self.mlp, **params)

        return jnp.array([dprey_dt.astype(float),dpredator_dt.astype(float)])
    
    @staticmethod
    def model(t, y, mlp, alpha, delta,):
        prey, predator = y
        
        dprey_dt_ode = alpha * prey 
        dpredator_dt_ode = - delta * predator
        dprey_dt_nn, dpredator_dt_nn = mlp(y) * jnp.array([jnp.tanh(prey).astype(float), jnp.tanh(predator).astype(float)])

        dprey_dt = dprey_dt_ode + dprey_dt_nn
        dpredator_dt = dpredator_dt_ode + dpredator_dt_nn

        return dprey_dt, dpredator_dt

Would that work?

This way model would satisfy a normal model description, could be used to infer the ode_states if sim.model is a callable class and not a method. Then the UDEBase.model signature could also be used to derive the needed arguments with mappar, etc. You could then also factor out the call method to UDEBase Think about it.

In general I like the idea of using class based callables for models very much, because they are more flexible. We could also think about defining a ModelBase (like SolverBase). Which could share common functionality like inferring the ode_states, mapping arguments, etc. This could make SimulationBase a bit more slim in the long run, which I'd like a lot. But this for sure does not concern your implementation, but maybe something to keep in mind. :)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this should work. I also think that pushing __call__() to the UDEBase should work. Then a greater share of the behind-the-scenes stuff would actually be hidden from the user which would be a great change.

The biggest problem with the current model formulation is that the parent class UDEBase knows nothing about the model parameters because they are only defined in the subclass. If a method that was defined in the UDEBase is inherited and then called by the subclass, information about the parameters is available and things start to work. That's the reason for the sometimes strange way this is set up, with some things happening in the UDEBase and some having to be done in the user-defined model. But I think your idea will make this much nicer :)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done ✅


class Func1D(UDEBase):

mlp_depth: int = 3
mlp_width: int = 3
mlp_in_size: int = 1
mlp_out_size: int = 1

r: jax.Array

def __init__(self, params, weights=None, bias=None, *, key, **kwargs):
self.init_MLP(weights, bias, key=key)
self.init_params(params)

def __call__(self, t, y):
"""
Returns the growth rates of predator and prey depending on their current state.

Parameters
----------
t : scalar
Just here to fulfill the requirements by diffeqsolve(). Has no effect and
can be set to None.
y : jax.ArrayImpl
Array containing two values: the current abundance of prey and predator,
respectively.

Returns:
--------
jax.ArrayImpl
An array containing the growth rates of prey and predators, respectively.
"""

params = self.preprocess_params()

X = y

dX_dt = params["r"] * X + self.mlp(y)

return jnp.array(dX_dt.astype(float))
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from lotka_volterra_case_study.plot import *
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from lotka_volterra_case_study.prob import *
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from lotka_volterra_case_study.sim import Simulation_v2
72 changes: 72 additions & 0 deletions case_studies/lotka_volterra_UDE_case_study/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@

[build-system]
requires = ["setuptools"]
build-backend = "setuptools.build_meta"

[project]
name = "lotka_volterra_UDE_case_study"
version = "1.0.0"
authors = [
{ name="Florian Schunck", email="[email protected]" },
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be you :-)

]
description = "Lotka Volterra Predator-Prey case study"
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also add the UDE component here

readme = "README.md"
requires-python = ">=3.10"
dependencies=[
"pymob[numpyro] >= 0.5.0a19",
"preliz",
]
license = {file = "LICENSE"}
classifiers = [
"Development Status :: 4 - Beta",
"Programming Language :: Python :: 3",
"Natural Language :: English",
"License :: OSI Approved :: GNU General Public License v3 (GPLv3)",
"Operating System :: OS Independent",
"Topic :: Scientific/Engineering :: Bio-Informatics",
]

[project.urls]
"Homepage" = "https://github.com/flo-schu/lotka_volterra_case_study"
"Issue Tracker" = "https://github.com/flo-schu/lotka_volterra_case_study/issues"
Comment on lines +30 to +31
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Update once the UDE case study is in a separate repo


[project.optional-dependencies]
dev = [
"pytest >= 7.3",
"bumpver",
"pre-commit",
"ipykernel",
"ipywidgets"
]

[tool.setuptools.packages.find]
include = ["lotka_volterra_UDE_case_study*"]

[tool.bumpver]
current_version = "1.0.0"
version_pattern = "MAJOR.MINOR.PATCH[PYTAGNUM]"
commit_message = "bump version {old_version} -> {new_version}"
tag_message = "{new_version}"
tag_scope = "default"
pre_commit_hook = ""
post_commit_hook = ""
commit = true
tag = true
push = true

[tool.bumpver.file_patterns]
"pyproject.toml" = [
'current_version = "{version}"',
'version = "{version}"'
]
"lotka_volterra_case_study/__init__.py" = [
'__version__ = "{version}"'
]
"README.md" = [
'git clone [email protected]:flo-schu/lotka_volterra_case_study/{version}'
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

as above

]

[tool.pytest.ini_options]
markers = [
"slow='mark test as slow.'"
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
[case-study]
name = lotka_volterra_UDE_case_study
pymob_version = 0.6.4
scenario = InfererTest
package = case_studies
modules = sim mod prob data plot
simulation = Simulation
observations = UDE_obs_inferer_test.nc
logging = DEBUG

[simulation]
y0 =
x_in =
input_files =
n_ode_states = 2
batch_dimension = batch_id
x_dimension = time
modeltype = deterministic
seed = 1

[data-structure]
prey = dimensions=['batch_id','time'] min=nan max=nan observed=True
predator = dimensions=['batch_id','time'] min=0.11841753125190735 max=5.719013214111328 observed=True

[solverbase]
x_dim = time
exclude_kwargs_model = t time x_in y x Y X
exclude_kwargs_postprocessing = t time interpolation results

[jax-solver]
diffrax_solver = Dopri5
rtol = 1e-06
atol = 1e-07
pcoeff = 0.0
icoeff = 1.0
dcoeff = 0.0
max_steps = 100000
throw_exception = True

[inference]
eps = 1e-08
objective_function = total_average
n_objectives = 1
objective_names =
extra_vars =
n_predictions = 100

[model-parameters]

[error-model]

[multiprocessing]
cores = 1

[inference.pyabc]
sampler = SingleCoreSampler
population_size = 100
minimum_epsilon = 0.0
min_eps_diff = 0.0
max_nr_populations = 1000
database_path = C:\Users\Markus\AppData\Local\Temp/pyabc.db

[inference.pyabc.redis]
password = nopassword
port = 1111
eval.n_predictions = 50
eval.history_id = -1
eval.model_id = 0

[inference.pymoo]
algortihm = UNSGA3
population_size = 100
max_nr_populations = 1000
ftol = 1e-05
xtol = 1e-07
cvtol = 1e-07
verbose = True

[inference.numpyro]
gaussian_base_distribution = False
kernel = nuts
init_strategy = init_to_uniform
chains = 1
draws = 2000
warmup = 1000
thinning = 1
nuts_draws = 2000
nuts_step_size = 0.8
nuts_max_tree_depth = 10
nuts_target_accept_prob = 0.8
nuts_dense_mass = True
nuts_adapt_step_size = True
nuts_adapt_mass_matrix = True
svi_iterations = 10000
svi_learning_rate = 0.0001

[inference.optax]
UDE_parameters = alpha = value=1.3 dims=[] hyper=False free=False delta = value=1.8 dims=[] prior=uniform(loc=1.0,scale=2.0) hyper=False free=True
MLP_weight_dist = normal()
MLP_bias_dist = normal()
length_strategy = 0.1 1
steps_strategy = 1000 1000
lr_strategy = 0.003 0.003
clip_strategy = 0.1 0.1
batch_size = 32
data_split = 0.8
multiple_runs_target = 3
multiple_runs_limit = 5
multiple_runs_plot = 5

[report]
debug_report = False
pandoc_output_format = html
model = True
parameters = True
parameters_format = pandas
diagnostics = True
diagnostics_with_batch_dim_vars = False
diagnostics_exclude_vars =
goodness_of_fit = True
goodness_of_fit_use_predictions = True
goodness_of_fit_nrmse_mode = range
table_parameter_estimates = True
table_parameter_estimates_format = csv
table_parameter_estimates_significant_figures = 3
table_parameter_estimates_error_metric = sd
table_parameter_estimates_parameters_as_rows = True
table_parameter_estimates_with_batch_dim_vars = False
table_parameter_estimates_exclude_vars =
table_parameter_estimates_override_names =
plot_trace = True
plot_parameter_pairs = True

Loading
Loading