Description
I'm solving an ODE system for viral dynamics using Kvaerno4 with a PID controller. The state variables need to stay non-negative.
I searched the docs and issues but didn't find anything about handling non-negativity constraints. Currently I force the states to be non-negative by clamping them at zero:
```python
import jax.numpy as jnp

def ode_system(t, state, params):
    state = jnp.maximum(state, 0.0)  # <-- this is my current workaround
    # ..... compute derivatives ....
    return derivatives
```

But this is generally a bad approach: it interferes with error control and introduces discontinuities, which is why I'm asking for guidance.
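A toy illustration of the discontinuity, assuming a hypothetical one-component decay model `decay_clamped` (not the actual viral-dynamics system): with the clamp in place, the Jacobian of the right-hand side jumps from `-k` to `0` as the state crosses zero. Newton-type stage solves inside implicit methods such as Kvaerno4 typically rely on this Jacobian, so the jump can affect convergence as well as error control.

```python
import jax
import jax.numpy as jnp


def decay_clamped(y, k=2.0):
    # Hypothetical toy right-hand side with the clamping workaround applied.
    return -k * jnp.maximum(y, 0.0)


# Jacobian (here a scalar derivative) of the clamped right-hand side.
jac = jax.jacfwd(decay_clamped)
slope_pos = jac(jnp.array(0.5))   # -k for positive states
slope_neg = jac(jnp.array(-0.5))  # 0 for negative states: the clamp flattens f
```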
MATLAB's ode15s supports this through the NonNegative option of odeset. The solver does two things: first, it wraps the ODE function to modify the derivatives; second, it incorporates constraint violations into the error estimate used for step acceptance/rejection. Inspecting odenonnegative.m reveals the derivative modification:
```matlab
function yp = local_odeFcn_nonnegative(idxNonNegative, ode, t, y, varargin)
  % Evaluate the user's ODE function
  yp = feval(ode, t, y, varargin{:});
  % Constrained components that are currently at or below zero
  ndx = idxNonNegative(find(y(idxNonNegative) <= 0));
  yp(ndx) = max(yp(ndx), 0);  % <-- here
end
```

Then, during the main integration loop, after computing a candidate step, the solver checks whether any constrained variable went negative and, if so, computes an additional error term that can trigger step rejection.
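For reference, a rough JAX analogue of both pieces might look like the sketch below. `nonnegative_rhs` and `nonnegativity_error` are hypothetical helpers, not Diffrax API. The first mirrors the MATLAB wrapper line by line. The second is only an assumed stand-in for ode15s's extra error term: the scaling `atol + rtol * max(|y_prev|, |y_candidate|)` and the RMS norm are common adaptive-step conventions, not the exact formula ode15s uses.

```python
import jax.numpy as jnp


def nonnegative_rhs(vector_field, idx_nonneg):
    """Hypothetical wrapper for f(t, y, args), mirroring odenonnegative.m.

    For each constrained index whose current value is <= 0, the derivative
    is clipped at zero so that component is not pushed further below zero.
    """
    idx_nonneg = jnp.asarray(idx_nonneg)

    def wrapped(t, y, args):
        dy = vector_field(t, y, args)
        # Boolean mask: True for constrained components at or below zero.
        mask = jnp.zeros(y.shape, dtype=bool).at[idx_nonneg].set(
            y[idx_nonneg] <= 0
        )
        # Equivalent of yp(ndx) = max(yp(ndx), 0) in the MATLAB code.
        return jnp.where(mask, jnp.maximum(dy, 0.0), dy)

    return wrapped


def nonnegativity_error(y_candidate, y_prev, idx_nonneg, rtol, atol):
    """Hypothetical extra error term for a candidate step (assumed scaling).

    Measures how far constrained components went below zero, scaled by a
    mixed absolute/relative tolerance and reduced with an RMS norm. Under the
    usual convention (error norm <= 1 is acceptable), a value > 1 would mean
    the step should be rejected.
    """
    idx_nonneg = jnp.asarray(idx_nonneg)
    y_c = y_candidate[idx_nonneg]
    scale = atol + rtol * jnp.maximum(jnp.abs(y_c), jnp.abs(y_prev[idx_nonneg]))
    violation = jnp.maximum(-y_c, 0.0) / scale
    return jnp.sqrt(jnp.mean(violation**2))
```

The wrapped function could be passed to `diffrax.ODETerm` in place of `ode_system`, but neither helper is connected to Diffrax's step-size controller here; feeding the second one into the accept/reject decision is the part this sketch leaves open.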
Is there a recommended pattern for this in Diffrax? I'm still fairly new to JAX and Diffrax, so I don't understand the internals well enough to implement something similar myself. I'd appreciate any pointers, or an existing approach I'm missing.