Collaborator

@merkuns merkuns commented Aug 4, 2025

The relevant changes are the ones to...

  • pymob/sim/evaluator.py
  • pymob/simulation.py
  • pymob/solvers/diffrax.py

The rest is just there for testing purposes. As I said before, I intend to write a proper testing file but this will not be ready before your vacation. I still wanted to create a pull request even in this state to at least show you the most relevant changes.

@merkuns changed the title from "Ude implementation: UDESolver to evaluate UDE models" to "Ude implementation: UDESolver to evaluate UDE models + OptaxBackend to find optimal solution" on Aug 19, 2025
@merkuns added the labels enhancement (New feature or request), solver (Solver enhancements for the simulation and models), and simulation (Improvements and fixes of the Simulation class) on Aug 19, 2025
…e cosmetic changes to posterior_predictive_checks()
Collaborator Author

merkuns commented Aug 22, 2025

The UDE toolbox is now mostly complete but still has to be thoroughly tested. The unit tests on GitHub are failing but this might be due to imports not working because of the "wrong" folder structure. On my computer, the tests in test_solvers.py for the UDE solver are working but I haven't written a test_backend_optax.py file for the inferer yet.

Some things that still need to be done (probably not a complete list):

  • Remove all mentions of args from the UDE solver.
  • Add x_in to the standalone solver, the inferer, and the posterior predictive checks.
  • Add the possibility to split the data into training and validation data.
  • If the scipy_to_jax dictionary has the purpose that I think it has, change it a little bit.
  • Add NotImplementedError for all functionalities that don't work with UDEs.
  • Move UDEBase to a suitable place (maybe pymob/utils/udebase.py?)
  • Write a test_backend_optax.py file to test the inferer.
  • Write documentation for the new features.
  • Write a tutorial on how to create and run a UDE model in Pymob.

But wait, there is more:

  • Change inferer so it uses the standard evaluator() function instead of evaluator.standalone_solver()
  • Move plot from posterior_predictive_checks() to plot_posterior()
  • Make equinox an optional dependency
  • Rename multiple_runs_plot -> not necessary anymore, should be removed from config
  • Make changes in UDESolver to accommodate different requirements of equinox and pymob with respect to y(0)
  • Fix bugs

Plan for the next weeks/months:

  • Add idata output
  • Get sim.report() to work with UDEs
  • Write a script that takes some hyperparameters as input and returns idata
  • But most importantly: Start writing your master thesis!

Plan for the next year:

  • Refactor the code so it's more in line with Pymob

…ndalone solver, inferer and posterior predictive checks
@merkuns merkuns linked an issue Aug 25, 2025 that may be closed by this pull request
3 tasks
@merkuns merkuns requested a review from flo-schu August 28, 2025 07:32
@flo-schu
Owner

Hi @merkuns

I see why the tests are currently failing:

- name: Install dependencies
  if: env.full-test == 'true' && needs.decide-to-test.outputs.changes == 'true' && needs.decide-to-test.outputs.tagged_commit == 'false' && github.event_name == 'pull_request'
  run: |
    python -m pip install --upgrade pip
    python -m pip install flake8 pytest
    pip install .[pyabc,pymoo,interactive,numpyro]
    pip install -e case_studies/lotka_volterra_case_study

You need to add the following line to the .github/python-test.yml workflow: pip install -e case_studies/lotka_volterra_UDE_case_study, directly after pip install -e case_studies/lotka_volterra_case_study. Technically, the -e is not necessary, but use it anyway for consistency.

This also installs the case study in the testing environment of the GitHub runner, which should fix the test error in the remote computing environment.

After this is done, I can have a look at the tracer test. I think this one is important, though, because it could be related to the problem you mentioned with the faulty repeated executions of the evaluator, which was solved in this PR (c11b4ac). What you need to do is hunt for the occurrence of a tracer outside of a jitted function. A very typical cause is the update of a dictionary. We can also look at it next week. I will also try to have a proper look at the code then.
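A minimal sketch of the tracer pattern described above, assuming a recent JAX version in which jax.core.Tracer is importable. The function and dictionary names (leaky_step, pure_step, leaky_cache, clean_cache) are hypothetical and are not taken from pymob. The first function updates a dictionary while it is being traced, so the dictionary ends up holding a tracer. The second returns its result, and the dictionary is updated outside of jit.

```python
import jax
import jax.numpy as jnp

leaky_cache = {}  # hypothetical module-level dict, for illustration only


@jax.jit
def leaky_step(x):
    y = 2.0 * x
    # Side effect during tracing: `y` is a tracer here, not a concrete array,
    # and the dictionary keeps a reference to it after tracing has finished.
    leaky_cache["y"] = y
    return y


@jax.jit
def pure_step(x):
    # No side effects: everything the caller needs is returned.
    return 2.0 * x


x = jnp.arange(3.0)

# Leaky variant: the dictionary now holds a tracer that escaped the jitted function.
leaky_step(x)
leaked_is_tracer = isinstance(leaky_cache["y"], jax.core.Tracer)

# Clean variant: the dictionary is updated outside of jit with a concrete array.
clean_cache = {}
clean_cache["y"] = pure_step(x)
clean_is_tracer = isinstance(clean_cache["y"], jax.core.Tracer)

print(leaked_is_tracer, clean_is_tracer)
```

When hunting for such a leak, it may also help to run the suspicious call inside JAX's jax.checking_leaks() context manager, which is intended to raise an error as soon as a tracer escapes.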

But first of all, have a nice weekend!
Flo

Owner

@flo-schu flo-schu left a comment

Hi Markus,

thanks again for your implementation. It's great that you have built a working implementation of this complex feature. There is of course still some need for refactoring. My main remark is that the current implementation often works around problems (necessitated by the complicated codebase and the requirements of the new feature) by creating parallel implementations of existing code instead of integrating the new feature with as few changes as possible. But I also realize that this was hardly possible in the given time, considering the complexity of the task and the little information you had to start with :)

Nevertheless, I looked at the changes to the existing pymob codebase and tried to identify the sites where changes could be possible.

I have also quickly looked at the UDE.py file. Great work (and evidently not easy) to find a way to initialize an MLP with custom biases and weights. I still wonder whether it is possible to craft a UDE parameter that contains the UDE structure, weights and biases, and from which an MLP is constructed somewhere in the backend and placed at a specific site in the model; maybe we can discuss this at some point.

I have not looked at the optax backend yet, because most of it is completely new to me and I didn't have time for an in-depth review.

As discussed, the refactor, which I think is still necessary, has the lowest priority for your thesis.

Before you go on to implement my change requests, I think it is my turn to see if what I have in mind is at all possible.

Again, thanks for implementing this complicated feature and I think for what you had to start with and the complexities of both the framework and the new feature you did a great job!

Flo

Owner

This file needs to be deleted

Collaborator Author

I'll be leaving some comments here to preserve my current knowledge for later when we deal with the refactoring.

And you are completely right, this has to be deleted.

from . import prob
from . import sim

__version__ = "1.0.0"
Owner

It might be useful to make this version 0.1.0 (according to semantic versioning, 1.0.0 usually indicates a package that has matured a bit).

name = "lotka_volterra_UDE_case_study"
version = "1.0.0"
authors = [
{ name="Florian Schunck", email="[email protected]" },
Owner

This should be you :-)

authors = [
{ name="Florian Schunck", email="[email protected]" },
]
description = "Lotka Volterra Predator-Prey case study"
Owner

Also add the UDE component here

Comment on lines +30 to +31
"Homepage" = "https://github.com/flo-schu/lotka_volterra_case_study"
"Issue Tracker" = "https://github.com/flo-schu/lotka_volterra_case_study/issues"
Owner

Update once the UDE case study is in a separate repo

class Optax(PymobModel):
    model_config = ConfigDict(validate_assignment=True, extra="ignore")

    UDE_parameters: Annotated[Modelparameters, BeforeValidator(string_to_modelparams), serialize_modelparams_to_string] = Modelparameters()
Owner

This opens a second line of implementation for defining model parameters. The optax section should only be used to define how the algorithm operates, not to describe the model.

Collaborator Author

The reason the UDE parameters were separated from the "normal" model parameters was that Pymob expects the model parameters to appear as input variables of the model, which is not the case for the current UDE model formulation.

If we make the change suggested by you above, this distinction might not be necessary anymore.

Collaborator Author

Done ✅
The model parameters in the config file were never the problem, sim.model_parameters["parameters"] was. So the UDE_parameters don't exist anymore.

UDE_parameters: Annotated[Modelparameters, BeforeValidator(string_to_modelparams), serialize_modelparams_to_string] = Modelparameters()
MLP_weight_dist: OptionRV = to_rv("normal()")
MLP_bias_dist: OptionRV = to_rv("normal()")
loss_function: Callable = lambda x_obs, x_pred: (x_obs - x_pred)**2
Owner

I don't think callables can be defined from a file, but you noticed that and excluded the loss function in the save method. This is a bit tricky. It is not good form that the loss function can only be defined in script mode, because this makes it unusable for the config -> init -> execution workflow that is used for command-line execution. There are three options: parse the loss function from a string into a callable, which is difficult; implement some default loss functions and choose from them by name; or define the loss function in the Model, which is possibly the most straightforward option.
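A minimal sketch of the "choose from default loss functions by name" option, so that only a string has to be stored in the config file. The registry LOSS_FUNCTIONS, the helper resolve_loss and the two loss names are hypothetical and are not part of pymob. The squared error mirrors the lambda in the diff above.

```python
from typing import Callable, Dict

LossFunction = Callable[[float, float], float]

# Hypothetical registry of default losses, addressable by name from a config file.
LOSS_FUNCTIONS: Dict[str, LossFunction] = {
    "squared_error": lambda x_obs, x_pred: (x_obs - x_pred) ** 2,
    "absolute_error": lambda x_obs, x_pred: abs(x_obs - x_pred),
}


def resolve_loss(name: str) -> LossFunction:
    """Look up a loss function by the name stored in the configuration."""
    try:
        return LOSS_FUNCTIONS[name]
    except KeyError:
        known = ", ".join(sorted(LOSS_FUNCTIONS))
        raise ValueError(
            f"Unknown loss function {name!r}; choose one of: {known}"
        ) from None


# The config only carries a plain string, which serializes without special handling.
config_value = "squared_error"
loss = resolve_loss(config_value)
example_loss = loss(3.0, 1.0)
print(example_loss)
```

The trade-off is the one raised in this thread: a registry keeps the config serializable, but it cannot cover every custom loss a user may want.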

Collaborator Author

@merkuns merkuns Sep 11, 2025

I would prefer the last option. It was a conscious decision to let the user define their own loss function because it is impossible to cover everything by calling them by some pre-defined names. As far as I know, the loss function is a pretty powerful tool to improve training. It seems to be absolutely common to include some prior knowledge about the expected outcome in the loss function. I actually did that as well in my early exploration by including a barrier function to prevent the state variables from getting too close to zero. So I really want to give the users maximum freedom in their choice of a loss function. If you say that the model class is a good spot for that, then we can easily have the users define it there. We can also set a default loss function by defining it in the UDEBase.
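A minimal sketch of the preferred option: a default loss on the base class that a model can override, here with a barrier term that penalizes states close to zero. UDEBaseSketch, BarrierModel and barrier_weight are hypothetical stand-ins, not pymob's actual UDEBase or its interface. Plain Python floats are used instead of JAX arrays to keep the example self-contained, and the barrier assumes strictly positive predictions.

```python
import math
from typing import Sequence


class UDEBaseSketch:
    """Hypothetical stand-in for a base class that provides a default loss."""

    def loss(self, x_obs: Sequence[float], x_pred: Sequence[float]) -> float:
        # Default: sum of squared errors over all observations.
        return sum((o - p) ** 2 for o, p in zip(x_obs, x_pred))


class BarrierModel(UDEBaseSketch):
    """User model that overrides the default loss with extra prior knowledge."""

    barrier_weight = 0.01  # hypothetical weighting of the barrier term

    def loss(self, x_obs: Sequence[float], x_pred: Sequence[float]) -> float:
        base = super().loss(x_obs, x_pred)
        # Log barrier: grows without bound as a predicted state approaches
        # zero from above (assumes every prediction is strictly positive).
        barrier = -sum(math.log(p) for p in x_pred)
        return base + self.barrier_weight * barrier


observed = [0.1, 2.0]
predicted = [0.1, 2.0]

default_loss = UDEBaseSketch().loss(observed, predicted)
barrier_loss = BarrierModel().loss(observed, predicted)
print(default_loss, barrier_loss)
```

Because the loss lives on the model class, nothing callable has to be written to or read from the config file.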

Collaborator Author

Done ✅

  mode="json",
  exclude_none=True,
- exclude={"case_study": {"output_path", "data_path", "root", "init_root", "default_settings_path"}}
+ exclude={"case_study": {"output_path", "data_path", "root", "init_root", "default_settings_path"}, "inference_optax": {"loss_function"}}
Owner

Remove as mentioned earlier.

Comment on lines +242 to +252

model_solver = self.model

equinox = import_optional_dependency(
    "equinox", errors="ignore"
)
if equinox is not None:
    from pymob.solvers.diffrax import UDESolver
    import equinox as eqx
    if solver == UDESolver:
        model_params, model_solver = eqx.partition(self.model, eqx.is_array)
Owner

I think this is okay for now, but I would prefer a more general solution in which the model returned from the model class is hashable (static) by definition. I don't really know how this would work, though; let's keep it like this for now.

Collaborator Author

@merkuns merkuns Sep 11, 2025

I can do some research on that, but I honestly think there is no way to make the model class hashable inside of JIT. As far as I know, the JAX arrays are absolutely essential to Equinox's calculation of gradients, and I still haven't found a way to hash them when JITted.

On the other hand, eqx.filter_jit should normally treat every JAX and NumPy array as dynamic and everything else as static. In that case, the JAX arrays would not have to be hashable. And outside of JIT, hashing them is possible by transforming them to tuples, so the hash call in the evaluator post-init method should work if a hash function exists. But for some reason, the filtering done by eqx.filter_jit doesn't seem to work. Edit: I'm starting to suspect that this is because JaxSolver is not a PyTree. Do you think this is plausible? Anyway, fixing this would probably require bigger changes than I can still make before generating my results, so I'm concentrating on the idata thing now.
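A minimal sketch of the "transform to tuples" idea for hashing outside of JIT. ToyModel and to_hashable are hypothetical names, and nested Python lists stand in for the model's arrays so that the example needs no third-party packages. With real arrays, one would presumably convert them to nested lists first (for example via their tolist() method), which is an assumption about how this would be wired into the evaluator.

```python
from typing import Any


def to_hashable(value: Any) -> Any:
    """Recursively turn (nested) lists into tuples so they can be hashed."""
    if isinstance(value, (list, tuple)):
        return tuple(to_hashable(item) for item in value)
    return value


class ToyModel:
    """Hypothetical model holding array-like weights as nested lists."""

    def __init__(self, name: str, weights: list):
        self.name = name
        self.weights = weights

    def _key(self) -> tuple:
        # Everything that defines the model, in a hashable form.
        return (self.name, to_hashable(self.weights))

    def __hash__(self) -> int:
        return hash(self._key())

    def __eq__(self, other: object) -> bool:
        return isinstance(other, ToyModel) and self._key() == other._key()


model_a = ToyModel("lotka_volterra", [[0.1, 0.2], [0.3, 0.4]])
model_b = ToyModel("lotka_volterra", [[0.1, 0.2], [0.3, 0.4]])
model_c = ToyModel("lotka_volterra", [[0.1, 0.2], [0.3, 0.5]])

# Equal content gives equal hashes; the raw nested list itself is unhashable.
same_hash = hash(model_a) == hash(model_b)
print(same_hash, model_a == model_c)
```

This only addresses hashing outside of JIT; it does not resolve the open question in this thread of why the filtering by eqx.filter_jit does not behave as expected.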

Owner

I'm happy to have equinox.filter_jit in the JaxSolver as well if it works with the other tests. In that case, the UDESolver can be reduced quite a bit, because most functions can be reused, correct? I think this would be good, so that we don't maintain duplicate implementations here either.

Collaborator Author

@merkuns merkuns Sep 11, 2025

Yes, the UDESolver basically only contains copies of some of JaxSolver's functions with two changes being made:

  1. eqx.filter_jit was used instead of jax.jit. If we used eqx.filter_jit for the JaxSolver, equinox would not be an optional dependency anymore, though.
  2. The partitioned model is passed on to other functions or combined inside the functions by using eqx.combine. If we can somehow get rid of the need to partition the model, the UDESolver might not be necessary at all.

Collaborator Author

@merkuns merkuns Oct 1, 2025

Edit:
3. The UDESolver does not refer to the "mechanistic" model parameters, as those are stored as an attribute of the model. Also, the model always expects an x_in input even if there is no input data which necessitated another small change. But these things will resolve themselves if we make model classes the standard way of defining a class in Pymob.

Labels

enhancement New feature or request simulation Improvements and fixes of the Simulation class solver Solver enhancements for the simulation and models

Development

Successfully merging this pull request may close these issues.

Implement Universal Differential Equations (UDEs)

2 participants