-
Notifications
You must be signed in to change notification settings - Fork 0
Ude implementation: UDESolver to evaluate UDE models + OptaxBackend to find optimal solution #177
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Changes from 30 commits
ee90231
9ad885c
843a619
e744a35
3e586bd
e721f9f
1edd7cc
f990481
3a49fec
f46868e
c701785
85f6fc0
1e87bd6
8a272da
57df392
5feb146
d1348b4
212afbc
adb0515
788bd77
8a9e341
aa012e7
b8e7ba1
66bc6e6
e00dc18
f690510
ed6d30d
25f76af
5360b89
5ecd5d2
cd3a175
04a2af4
371a132
7c4f692
6706819
10d036b
625d6d7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,4 @@ | ||
| __pycache__ | ||
| results | ||
| *.code-workspace | ||
| *.egg-info |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,15 @@ | ||
| repos: | ||
| - repo: https://github.com/pre-commit/pre-commit-hooks | ||
| rev: v2.3.0 | ||
| hooks: | ||
| - id: check-yaml | ||
| - id: check-toml | ||
|
|
||
| - repo: local | ||
| hooks: | ||
| - id: pytest-check | ||
| name: pytest-check | ||
| entry: test.sh | ||
| language: script | ||
| pass_filenames: false | ||
| always_run: true |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| __version__ = "1.0.0" |
Large diffs are not rendered by default.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,7 @@ | ||
| from . import data | ||
| from . import mod | ||
| from . import plot | ||
| from . import prob | ||
| from . import sim | ||
|
|
||
| __version__ = "1.0.0" | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It might be useful to make this version 0.1.0 (according to semantic versioning, 1.0.0 usually indicates a package that has matured a bit). |
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,91 @@ | ||
| import equinox as eqx | ||
| import jax.nn as jnn | ||
| import jax.numpy as jnp | ||
| import jax | ||
| from pymob.utils.UDE import UDEBase | ||
|
|
||
| class Func(UDEBase): | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please add a docstring here |
||
|
|
||
| mlp_depth: int = 3 | ||
| mlp_width: int = 3 | ||
| mlp_in_size: int = 2 | ||
| mlp_out_size: int = 2 | ||
|
|
||
| alpha: jax.Array | ||
| delta: jax.Array | ||
|
|
||
| def __init__(self, params, weights=None, bias=None, *, key, **kwargs): | ||
|
||
| self.init_MLP(weights, bias, key=key) | ||
| self.init_params(params) | ||
|
|
||
| def __call__(self, t, y): | ||
| """ | ||
| Returns the growth rates of predator and prey depending on their current state. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| t : scalar | ||
| Just here to fulfill the requirements by diffeqsolve(). Has no effect and | ||
| can be set to None. | ||
| y : jax.ArrayImpl | ||
| Array containing two values: the current abundance of prey and predator, | ||
| respectively. | ||
|
|
||
| Returns: | ||
| -------- | ||
| jax.ArrayImpl | ||
| An array containing the growth rates of prey and predators, respectively. | ||
| """ | ||
|
|
||
| params = self.preprocess_params() | ||
|
|
||
| prey, predator = y | ||
|
|
||
| dprey_dt_ode = params["alpha"] * prey | ||
| dpredator_dt_ode = - params["delta"] * predator | ||
| dprey_dt_nn, dpredator_dt_nn = self.mlp(y) * jnp.array([jnp.tanh(prey).astype(float), jnp.tanh(predator).astype(float)]) | ||
|
|
||
| dprey_dt = dprey_dt_ode + dprey_dt_nn | ||
| dpredator_dt = dpredator_dt_ode + dpredator_dt_nn | ||
|
|
||
| return jnp.array([dprey_dt.astype(float),dpredator_dt.astype(float)]) | ||
|
||
|
|
||
| class Func1D(UDEBase): | ||
|
|
||
| mlp_depth: int = 3 | ||
| mlp_width: int = 3 | ||
| mlp_in_size: int = 1 | ||
| mlp_out_size: int = 1 | ||
|
|
||
| r: jax.Array | ||
|
|
||
| def __init__(self, params, weights=None, bias=None, *, key, **kwargs): | ||
| self.init_MLP(weights, bias, key=key) | ||
| self.init_params(params) | ||
|
|
||
| def __call__(self, t, y): | ||
| """ | ||
| Returns the growth rates of predator and prey depending on their current state. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| t : scalar | ||
| Just here to fulfill the requirements by diffeqsolve(). Has no effect and | ||
| can be set to None. | ||
| y : jax.ArrayImpl | ||
| Array containing two values: the current abundance of prey and predator, | ||
| respectively. | ||
|
|
||
| Returns: | ||
| -------- | ||
| jax.ArrayImpl | ||
| An array containing the growth rates of prey and predators, respectively. | ||
| """ | ||
|
|
||
| params = self.preprocess_params() | ||
|
|
||
| X = y | ||
|
|
||
| dX_dt = params["r"] * X + self.mlp(y) | ||
|
|
||
| return jnp.array(dX_dt.astype(float)) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| from lotka_volterra_case_study.plot import * |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| from lotka_volterra_case_study.prob import * |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| from lotka_volterra_case_study.sim import Simulation_v2 |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,72 @@ | ||
|
|
||
| [build-system] | ||
| requires = ["setuptools"] | ||
| build-backend = "setuptools.build_meta" | ||
|
|
||
| [project] | ||
| name = "lotka_volterra_UDE_case_study" | ||
| version = "1.0.0" | ||
| authors = [ | ||
| { name="Florian Schunck", email="[email protected]" }, | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This should be you :-) |
||
| ] | ||
| description = "Lotka Volterra Predator-Prey case study" | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also add the UDE component here |
||
| readme = "README.md" | ||
| requires-python = ">=3.10" | ||
| dependencies=[ | ||
| "pymob[numpyro] >= 0.5.0a19", | ||
| "preliz", | ||
| ] | ||
| license = {file = "LICENSE"} | ||
| classifiers = [ | ||
| "Development Status :: 4 - Beta", | ||
| "Programming Language :: Python :: 3", | ||
| "Natural Language :: English", | ||
| "License :: OSI Approved :: GNU General Public License v3 (GPLv3)", | ||
| "Operating System :: OS Independent", | ||
| "Topic :: Scientific/Engineering :: Bio-Informatics", | ||
| ] | ||
|
|
||
| [project.urls] | ||
| "Homepage" = "https://github.com/flo-schu/lotka_volterra_case_study" | ||
| "Issue Tracker" = "https://github.com/flo-schu/lotka_volterra_case_study/issues" | ||
|
Comment on lines
+30
to
+31
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Update once the UDE case study is in a separate repo |
||
|
|
||
| [project.optional-dependencies] | ||
| dev = [ | ||
| "pytest >= 7.3", | ||
| "bumpver", | ||
| "pre-commit", | ||
| "ipykernel", | ||
| "ipywidgets" | ||
| ] | ||
|
|
||
| [tool.setuptools.packages.find] | ||
| include = ["lotka_volterra_UDE_case_study*"] | ||
|
|
||
| [tool.bumpver] | ||
| current_version = "1.0.0" | ||
| version_pattern = "MAJOR.MINOR.PATCH[PYTAGNUM]" | ||
| commit_message = "bump version {old_version} -> {new_version}" | ||
| tag_message = "{new_version}" | ||
| tag_scope = "default" | ||
| pre_commit_hook = "" | ||
| post_commit_hook = "" | ||
| commit = true | ||
| tag = true | ||
| push = true | ||
|
|
||
| [tool.bumpver.file_patterns] | ||
| "pyproject.toml" = [ | ||
| 'current_version = "{version}"', | ||
| 'version = "{version}"' | ||
| ] | ||
| "lotka_volterra_case_study/__init__.py" = [ | ||
| '__version__ = "{version}"' | ||
| ] | ||
| "README.md" = [ | ||
| 'git clone [email protected]:flo-schu/lotka_volterra_case_study/{version}' | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. as above |
||
| ] | ||
|
|
||
| [tool.pytest.ini_options] | ||
| markers = [ | ||
| "slow='mark test as slow.'" | ||
| ] | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,133 @@ | ||
| [case-study] | ||
| name = lotka_volterra_UDE_case_study | ||
| pymob_version = 0.6.4 | ||
| scenario = InfererTest | ||
| package = case_studies | ||
| modules = sim mod prob data plot | ||
| simulation = Simulation | ||
| observations = UDE_obs_inferer_test.nc | ||
| logging = DEBUG | ||
|
|
||
| [simulation] | ||
| y0 = | ||
| x_in = | ||
| input_files = | ||
| n_ode_states = 2 | ||
| batch_dimension = batch_id | ||
| x_dimension = time | ||
| modeltype = deterministic | ||
| seed = 1 | ||
|
|
||
| [data-structure] | ||
| prey = dimensions=['batch_id','time'] min=nan max=nan observed=True | ||
| predator = dimensions=['batch_id','time'] min=0.11841753125190735 max=5.719013214111328 observed=True | ||
|
|
||
| [solverbase] | ||
| x_dim = time | ||
| exclude_kwargs_model = t time x_in y x Y X | ||
| exclude_kwargs_postprocessing = t time interpolation results | ||
|
|
||
| [jax-solver] | ||
| diffrax_solver = Dopri5 | ||
| rtol = 1e-06 | ||
| atol = 1e-07 | ||
| pcoeff = 0.0 | ||
| icoeff = 1.0 | ||
| dcoeff = 0.0 | ||
| max_steps = 100000 | ||
| throw_exception = True | ||
|
|
||
| [inference] | ||
| eps = 1e-08 | ||
| objective_function = total_average | ||
| n_objectives = 1 | ||
| objective_names = | ||
| extra_vars = | ||
| n_predictions = 100 | ||
|
|
||
| [model-parameters] | ||
|
|
||
| [error-model] | ||
|
|
||
| [multiprocessing] | ||
| cores = 1 | ||
|
|
||
| [inference.pyabc] | ||
| sampler = SingleCoreSampler | ||
| population_size = 100 | ||
| minimum_epsilon = 0.0 | ||
| min_eps_diff = 0.0 | ||
| max_nr_populations = 1000 | ||
| database_path = C:\Users\Markus\AppData\Local\Temp/pyabc.db | ||
|
|
||
| [inference.pyabc.redis] | ||
| password = nopassword | ||
| port = 1111 | ||
| eval.n_predictions = 50 | ||
| eval.history_id = -1 | ||
| eval.model_id = 0 | ||
|
|
||
| [inference.pymoo] | ||
| algortihm = UNSGA3 | ||
| population_size = 100 | ||
| max_nr_populations = 1000 | ||
| ftol = 1e-05 | ||
| xtol = 1e-07 | ||
| cvtol = 1e-07 | ||
| verbose = True | ||
|
|
||
| [inference.numpyro] | ||
| gaussian_base_distribution = False | ||
| kernel = nuts | ||
| init_strategy = init_to_uniform | ||
| chains = 1 | ||
| draws = 2000 | ||
| warmup = 1000 | ||
| thinning = 1 | ||
| nuts_draws = 2000 | ||
| nuts_step_size = 0.8 | ||
| nuts_max_tree_depth = 10 | ||
| nuts_target_accept_prob = 0.8 | ||
| nuts_dense_mass = True | ||
| nuts_adapt_step_size = True | ||
| nuts_adapt_mass_matrix = True | ||
| svi_iterations = 10000 | ||
| svi_learning_rate = 0.0001 | ||
|
|
||
| [inference.optax] | ||
| UDE_parameters = alpha = value=1.3 dims=[] hyper=False free=False delta = value=1.8 dims=[] prior=uniform(loc=1.0,scale=2.0) hyper=False free=True | ||
| MLP_weight_dist = normal() | ||
| MLP_bias_dist = normal() | ||
| length_strategy = 0.1 1 | ||
| steps_strategy = 1000 1000 | ||
| lr_strategy = 0.003 0.003 | ||
| clip_strategy = 0.1 0.1 | ||
| batch_size = 32 | ||
| data_split = 0.8 | ||
| multiple_runs_target = 3 | ||
| multiple_runs_limit = 5 | ||
| multiple_runs_plot = 5 | ||
|
|
||
| [report] | ||
| debug_report = False | ||
| pandoc_output_format = html | ||
| model = True | ||
| parameters = True | ||
| parameters_format = pandas | ||
| diagnostics = True | ||
| diagnostics_with_batch_dim_vars = False | ||
| diagnostics_exclude_vars = | ||
| goodness_of_fit = True | ||
| goodness_of_fit_use_predictions = True | ||
| goodness_of_fit_nrmse_mode = range | ||
| table_parameter_estimates = True | ||
| table_parameter_estimates_format = csv | ||
| table_parameter_estimates_significant_figures = 3 | ||
| table_parameter_estimates_error_metric = sd | ||
| table_parameter_estimates_parameters_as_rows = True | ||
| table_parameter_estimates_with_batch_dim_vars = False | ||
| table_parameter_estimates_exclude_vars = | ||
| table_parameter_estimates_override_names = | ||
| plot_trace = True | ||
| plot_parameter_pairs = True | ||
|
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This file needs to be deleted
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'll be leaving some comments here to preserve my current knowledge for later when we deal with the refactoring.
And you are completely right, this has to be deleted.