
How to enforce non-negativity constraints? #697

@Mycroft-47


I'm solving an ODE system for viral dynamics using Kvaerno4 with a PID controller. The state variables need to stay non-negative.

I searched the docs and issues but didn't find anything about handling non-negativity constraints. Currently I'm forcing the states to be non-negative by clamping them to zero:

```python
import jax.numpy as jnp

def ode_system(t, state, params):
    state = jnp.maximum(state, 0.0)  # <-- this is my current workaround
    # ... compute `derivatives` from the clamped state ...
    return derivatives
```

But this is generally a bad approach: it introduces discontinuities and interferes with the solver's error control. That's why I'm asking for guidance.

MATLAB's ode15s supports non-negativity through odeset with the NonNegative option. The solver does two things: first, it wraps the ODE function to modify the derivatives; second, it incorporates constraint violations into the error estimate used for step acceptance/rejection. Inspecting odenonnegative.m reveals the derivative modification:

```matlab
function yp = local_odeFcn_nonnegative(idxNonNegative, ode, t, y, varargin)
    yp = feval(ode, t, y, varargin{:});
    ndx = idxNonNegative(find(y(idxNonNegative) <= 0));
    yp(ndx) = max(yp(ndx), 0); % <-- here
end
```
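For comparison, here is a minimal sketch of how that MATLAB wrapper might translate to plain JAX. The names `make_nonnegative`, `idx_nonneg` and the toy `drain` model are hypothetical, not part of Diffrax. The wrapped function keeps a `(t, y, args)` signature so it could stand in for the original vector field. Unlike the clamping workaround, it leaves the state untouched and only stops constrained components that are already at or below zero from decreasing further.

```python
import jax.numpy as jnp


def make_nonnegative(vector_field, idx_nonneg):
    """Hypothetical helper mirroring local_odeFcn_nonnegative.

    For each index in idx_nonneg whose state is <= 0, the derivative is
    replaced by max(dy, 0); all other derivatives are returned unchanged.
    """
    idx_nonneg = jnp.asarray(idx_nonneg)

    def wrapped(t, y, args):
        dy = vector_field(t, y, args)
        y_c = y[idx_nonneg]
        dy_c = dy[idx_nonneg]
        # jnp.where (rather than boolean masking) keeps shapes static under jit.
        dy_c = jnp.where(y_c <= 0, jnp.maximum(dy_c, 0.0), dy_c)
        return dy.at[idx_nonneg].set(dy_c)

    return wrapped


# Toy model (hypothetical): every component drains at a constant rate.
def drain(t, y, rate):
    return -rate * jnp.ones_like(y)


f = make_nonnegative(drain, [0, 1])  # components 0 and 1 are constrained
print(f(0.0, jnp.array([0.0, 1.0, 0.0]), 1.0).tolist())  # → [0.0, -1.0, -1.0]
```

Component 0 sits at zero and is constrained, so its derivative is lifted to 0; component 1 is still positive and component 2 is unconstrained, so both keep their negative derivative. Like the MATLAB version, this makes the vector field non-smooth at the boundary, and on its own it does not provide the step-rejection half of the ode15s behaviour.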

Then, during ode15s's main integration loop, after computing a candidate solution step, the solver checks whether any constrained variables went negative and computes an additional error term that can trigger step rejection.
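The extra error term can be illustrated in isolation. The function below is a simplified, hypothetical sketch loosely modeled on that idea. It is not MATLAB's exact formula and not an existing Diffrax API. It measures how far the constrained components of a candidate step went negative, scales that by the usual mixed `atol + rtol * |y|` weights, and takes the maximum with the solver's own scaled error estimate. A step is accepted only if the combined value is at most 1.

```python
import jax.numpy as jnp


def scaled_error_with_nonneg(y_prev, y_cand, err_est, idx_nonneg, rtol, atol):
    """Hypothetical helper: RMS error norm combined with a non-negativity
    penalty on the components listed in idx_nonneg (must be non-empty)."""
    # Mixed absolute/relative tolerance weights, one per component.
    scale = atol + rtol * jnp.maximum(jnp.abs(y_prev), jnp.abs(y_cand))
    err_norm = jnp.sqrt(jnp.mean((err_est / scale) ** 2))
    idx = jnp.asarray(idx_nonneg)
    # Scaled size of any negative excursion in the constrained components.
    violation = jnp.maximum(-y_cand[idx], 0.0) / scale[idx]
    nn_norm = jnp.max(violation)
    return jnp.maximum(err_norm, nn_norm)


def accept_step(y_prev, y_cand, err_est, idx_nonneg, rtol=1e-6, atol=1e-8):
    """Hypothetical accept/reject decision based on the combined norm."""
    return scaled_error_with_nonneg(y_prev, y_cand, err_est, idx_nonneg, rtol, atol) <= 1.0


y_prev = jnp.array([1.0, 1.0])
zero_err = jnp.zeros(2)
# Constrained component 1 stays positive: accepted.
print(bool(accept_step(y_prev, jnp.array([0.5, 0.2]), zero_err, [1], rtol=1e-3, atol=1e-6)))
# Constrained component 1 overshoots to -0.01: rejected, even with zero error estimate.
print(bool(accept_step(y_prev, jnp.array([0.5, -0.01]), zero_err, [1], rtol=1e-3, atol=1e-6)))
```

With `rtol=1e-3`, the overshoot of `-0.01` is roughly ten times the allowed scale, so the second step is rejected. How, or whether, such a check can be hooked into Diffrax's step-size controller is exactly the part I'm unsure about.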

Is there a recommended pattern for this in Diffrax? I'm still fairly new to JAX and Diffrax, so I don't understand the internals well enough to implement something similar myself. I'd appreciate any pointers, or to hear about an existing approach I'm missing.
