UDE implementation: UDESolver to evaluate UDE models + OptaxBackend to find optimal solution #177
base: dev
Conversation
… sim.posterior_predictive_check() for UDEs
…e cosmetic changes to posterior_predictive_checks()
The UDE toolbox is now mostly complete but still has to be thoroughly tested. The unit tests on GitHub are failing, but this might be due to imports not working because of the "wrong" folder structure; on my computer, the tests in … pass. Some things that still need to be done (probably not a complete list):
But wait, there is more:
Plan for the next weeks/months:
Plan for the next year:
…ndalone solver, inferer and posterior predictive checks
…unction universal (hopefully)
Hi @merkuns, I see why the tests are currently failing: pymob/.github/workflows/python-test.yml, lines 121 to 128 in 03ed11d.
You need to add the following line in the … This installs the case study also in the testing environment of the GitHub runner, which should fix the testing error in the remote computing environment.

After this is done, I can have a look at the tracer test. I think, though, that this one is important, because it could be related to the problem you mentioned with the faulty repeated executions of the evaluator, which was solved in this PR (c11b4ac). What you need to do is hunt for the occurrence of a tracer outside of a jitted function; a very typical cause is the update of a dictionary. We can also look at it next week. I will also try to have a proper look at the code then. But first of all, have a nice weekend!
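For illustration only, a minimal sketch of the kind of tracer leak meant above, in plain JAX (nothing here is pymob code): a dictionary that lives outside the jitted function is updated inside it, so a tracer escapes the trace.

```python
import jax
import jax.numpy as jnp

cache = {}  # state living outside the jitted function

@jax.jit
def step_leaky(x):
    y = jnp.sin(x)
    cache["last"] = y  # side effect: a tracer escapes into the outer dict
    return y

step_leaky(jnp.array(1.0))
# cache["last"] now holds a leaked tracer, not a concrete array;
# using it later, e.g. jnp.sum(cache["last"]), raises an UnexpectedTracerError.

# Fix: return everything that is needed and update the dict outside of jit.
@jax.jit
def step(x):
    return jnp.sin(x)

cache["last"] = step(jnp.array(1.0))  # a concrete jax.Array
```

The context manager jax.checking_leaks() can also help to locate where such a leak happens.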
…ns() for UDE models
…olver test + removed small error from standalone solver
…ing evaluator() + small fix for dataloader
Hi Markus,
thanks again for your implementation. It's great that you have crafted a working implementation of this complex piece of software. There is of course still some need for refactoring: my main remark is that the current implementation often finds workarounds (necessitated by the complicated codebase and the requirements of the new feature) by creating parallel implementations next to the existing code, instead of integrating the changes with as few modifications as possible. But I also realize that this was hardly possible in the given time, considering the complexity of the task and the little information you had to start with :)
Nevertheless, I looked at the changes to the existing pymob codebase and tried to identify the sites where changes could be possible.
I have also quickly looked at the UDE.py file. Great work (and evidently not easy) to find a way to initialize an MLP with custom biases and weights. I still wonder whether it is possible to craft a UDE parameter which contains the UDE structure, weights, and biases, and from which the backend somewhere constructs an MLP that is placed at a specific site in the model; maybe we can discuss this at some point.
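Roughly what I have in mind, as a sketch only (the names UDEParam and build_mlp are hypothetical, nothing of this exists in pymob yet):

```python
from dataclasses import dataclass, field
from typing import List

import jax
import equinox as eqx

@dataclass
class UDEParam:
    # structure of the network
    in_size: int
    out_size: int
    width: int = 16
    depth: int = 2
    # optional custom weights/biases, one entry per layer
    weights: List[jax.Array] = field(default_factory=list)
    biases: List[jax.Array] = field(default_factory=list)

def build_mlp(p: UDEParam, key: jax.Array) -> eqx.nn.MLP:
    mlp = eqx.nn.MLP(p.in_size, p.out_size, p.width, p.depth, key=key)
    # overwrite the default initialization with the stored weights/biases
    for i, (w, b) in enumerate(zip(p.weights, p.biases)):
        mlp = eqx.tree_at(lambda m: m.layers[i].weight, mlp, w)
        mlp = eqx.tree_at(lambda m: m.layers[i].bias, mlp, b)
    return mlp
```

The backend could then call build_mlp(param, key) and place the resulting MLP at the requested site in the model.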
I have not looked at the optax backend yet, because most of it is completely new to me and I didn't have time for an in depth review.
As discussed, the refactor, which I think is still necessary, has the lowest priority for your thesis.
Before you go on to implement my change requests, I think it is my turn to see if what I have in mind is at all possible.
Again, thanks for implementing this complicated feature. Considering what you had to start with and the complexities of both the framework and the new feature, I think you did a great job!
Flo
This file needs to be deleted
I'll be leaving some comments here to preserve my current knowledge for later when we deal with the refactoring.
And you are completely right, this has to be deleted.
    from . import prob
    from . import sim

    __version__ = "1.0.0"
It might be useful to make this version 0.1.0 (according to semantic versioning, 1.0.0 usually indicates a package that has matured a bit).
    name = "lotka_volterra_UDE_case_study"
    version = "1.0.0"
    authors = [
        { name="Florian Schunck", email="[email protected]" },
This should be you :-)
    authors = [
        { name="Florian Schunck", email="[email protected]" },
    ]
    description = "Lotka Volterra Predator-Prey case study"
Also add the UDE component here
| "Homepage" = "https://github.com/flo-schu/lotka_volterra_case_study" | ||
| "Issue Tracker" = "https://github.com/flo-schu/lotka_volterra_case_study/issues" |
Update once the UDE case study is in a separate repo
pymob/sim/config.py
Outdated
    class Optax(PymobModel):
        model_config = ConfigDict(validate_assignment=True, extra="ignore")

        UDE_parameters: Annotated[Modelparameters, BeforeValidator(string_to_modelparams), serialize_modelparams_to_string] = Modelparameters()
This is opening a second implementation line for defining model parameters. The optax section should only be used for defining how the algorithm operates, not for the model description.
The reason the UDE parameters were separated from the "normal" model parameters was that Pymob expects the model parameters to appear as input variables of the model, which is not the case for the current UDE model formulation.
If we make the change suggested by you above, this distinction might not be necessary anymore.
Done ✅
The model parameters in the config file were never the problem, sim.model_parameters["parameters"] was. So the UDE_parameters don't exist anymore.
pymob/sim/config.py
Outdated
    UDE_parameters: Annotated[Modelparameters, BeforeValidator(string_to_modelparams), serialize_modelparams_to_string] = Modelparameters()
    MLP_weight_dist: OptionRV = to_rv("normal()")
    MLP_bias_dist: OptionRV = to_rv("normal()")
    loss_function: Callable = lambda x_obs, x_pred: (x_obs - x_pred)**2
I don't think callables can be defined from a file, but you noticed that and excluded the loss function in the save method. This is a bit tricky. It is not good form that the loss function can only be defined in script mode; this makes it unusable for the config -> init -> execution workflow, which is used for the command-line execution. Either parse the loss function from a string to a callable (but this is difficult), or implement some default loss functions and choose from them by name, or define them in the Model, which is possibly the most straightforward option.
I would prefer the last option. It was a conscious decision to let the user define their own loss function, because it is impossible to cover every case with a set of pre-defined names. As far as I know, the loss function is a pretty powerful tool to improve training, and it seems to be absolutely common to include some prior knowledge about the expected outcome in the loss function. I actually did that as well in my early exploration, by including a barrier function to prevent the state variables from getting too close to zero. So I really want to give the users maximum freedom in their choice of a loss function. If you say that the model class is a good spot for that, then we can easily have the users define it there. We can also set a default loss function by defining it in the UDEBase.
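Something along these lines, as a sketch only (the subclass name and the barrier weight are made up; only UDEBase and the loss-function idea are from the discussion above):

```python
import jax.numpy as jnp

class UDEBase:
    # default loss used when the user does not override it
    def loss_function(self, x_obs, x_pred):
        return jnp.mean((x_obs - x_pred) ** 2)

class LotkaVolterraUDE(UDEBase):
    # user-defined loss that adds prior knowledge: a log-barrier term
    # penalizing state variables that get too close to zero
    def loss_function(self, x_obs, x_pred, barrier_weight=0.01, eps=1e-6):
        mse = jnp.mean((x_obs - x_pred) ** 2)
        barrier = -jnp.mean(jnp.log(jnp.maximum(x_pred, eps)))
        return mse + barrier_weight * barrier
```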
Done ✅
| mode="json", | ||
| exclude_none=True, | ||
| exclude={"case_study": {"output_path", "data_path", "root", "init_root", "default_settings_path"}} | ||
| exclude={"case_study": {"output_path", "data_path", "root", "init_root", "default_settings_path"}, "inference_optax": {"loss_function"}} |
Remove as mentioned earlier.
    model_solver = self.model

    equinox = import_optional_dependency(
        "equinox", errors="ignore"
    )
    if equinox is not None:
        from pymob.solvers.diffrax import UDESolver
        import equinox as eqx
        if solver == UDESolver:
            model_params, model_solver = eqx.partition(self.model, eqx.is_array)
I think this is okay for now, but I would prefer if we found a more general solution where the model returned from the model class is hashable (static) by definition. I don't really know how this would work, though; let's keep it like that for now.
I can do some research on that, but I honestly think that there is no way to make the model class hashable inside of JIT. As far as I know, the JAX arrays are absolutely essential to equinox's calculation of gradients, and I have still found no way to hash them when JITted.

On the other hand, eqx.filter_jit should normally treat every JAX and NumPy array as dynamic and everything else as static; in that case, the JAX arrays would not have to be hashable. And outside of JIT, hashing them is possible by transforming them to tuples, so the hash call in the evaluator post-init method should work if a hash function exists. But for some reason, the filtering done by eqx.filter_jit doesn't seem to work.

Edit: I'm starting to suspect that this is due to JaxSolver not being a PyTree. Do you think this is plausible? Anyway, fixing this would probably require bigger changes than I can still make before generating my results, so I'm concentrating on the idata thing now.
I'm happy to have equinox.filter_jit also in the JaxSolver if it works with the other tests. In that case, the UDESolver can be reduced quite a bit, because most functions can be reused, correct? I think this would be good, so that here, too, we don't maintain duplicate implementations.
Yes, the UDESolver basically only contains copies of some of JaxSolver's functions, with two changes being made:
1. eqx.filter_jit was used instead of jax.jit. If we used eqx.filter_jit for the JaxSolver, equinox would not be an optional dependency anymore, though.
2. The partitioned model is passed on to other functions or combined inside the functions by using eqx.combine. If we can somehow get rid of the need to partition the model, the UDESolver might not be necessary at all.
Edit:
3. The UDESolver does not refer to the "mechanistic" model parameters, as those are stored as an attribute of the model. Also, the model always expects an x_in input even if there is no input data, which necessitated another small change. But these things will resolve themselves if we make model classes the standard way of defining a model in Pymob.
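For reference, a minimal sketch of the two routes discussed above, in plain equinox/JAX and independent of pymob (the example model is just an MLP): eqx.partition/eqx.combine split the model into dynamic array leaves and a static remainder by hand, while eqx.filter_jit does the same filtering automatically.

```python
import jax
import jax.numpy as jnp
import equinox as eqx

model = eqx.nn.MLP(in_size=2, out_size=2, width_size=8, depth=2,
                   key=jax.random.PRNGKey(0))

# Manual route (current UDESolver approach): split into array leaves
# (dynamic) and everything else (static, hashable), jit over the dynamic
# part only, and recombine inside the jitted function.
params, static = eqx.partition(model, eqx.is_array)

@jax.jit
def evaluate_manual(params, x):
    mlp = eqx.combine(params, static)
    return mlp(x)

# Route via eqx.filter_jit: arrays are traced, everything else is treated
# as static automatically, so no explicit partitioning is needed.
@eqx.filter_jit
def evaluate(model, x):
    return model(x)

x = jnp.ones(2)
print(evaluate_manual(params, x))
print(evaluate(model, x))
```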
…models, and deleted some redundant code
The relevant changes are the ones to...
The rest is just there for testing purposes. As I said before, I intend to write a proper testing file, but this will not be ready before your vacation. I still wanted to create a pull request even in this state, to at least show you the most relevant changes.