Skip to content

Commit b91138f

Browse files
philipwijesinghepatrick-kidger
authored andcommitted
avoids accumulation of float precision errors in dt
this solution makes sure that dt is reset to the desired dtmin value if the previous step was at dtmin and dt is unchanged (factor=1) if we do not reset dt then the recalculation of prev_dt = t1 - t0 will keep accumulating float precision errors with potential to drift away from the desired dtmin until a step that warrant a relaxation of step size (factor>1) these errors are likely to be minor, but i believe this is the intended behaviour
1 parent 02d6b8a commit b91138f

File tree

1 file changed

+9
-7
lines changed
  • diffrax/_step_size_controller

1 file changed

+9
-7
lines changed

diffrax/_step_size_controller/pid.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def intermediate(carry):
8181
return jnp.minimum(100 * h0, h1)
8282

8383

84-
# _PidState = (prev_inv_scaled_error, prev_prev_inv_scaled_error, keep_next_step)
84+
# _PidState = (prev_inv_scaled_error, prev_prev_inv_scaled_error, at_dtmin)
8585
_PidState = tuple[RealScalarLike, RealScalarLike, BoolScalarLike]
8686

8787

@@ -470,7 +470,7 @@ def adapt_step_size(
470470
(
471471
prev_inv_scaled_error,
472472
prev_prev_inv_scaled_error,
473-
keep_next_step,
473+
at_dtmin,
474474
) = controller_state
475475
error_order = self._get_error_order(error_order)
476476
prev_dt = t1 - t0
@@ -493,7 +493,7 @@ def _scale(_y0, _y1_candidate, _y_error):
493493
keep_step = scaled_error < 1
494494
# Automatically keep the step if it was at dtmin.
495495
if self.dtmin is not None:
496-
keep_step = keep_step | keep_next_step
496+
keep_step = keep_step | at_dtmin
497497
# Make sure it's not a Python scalar and thus getting a ZeroDivisionError.
498498
inv_scaled_error = 1 / jnp.asarray(scaled_error)
499499
inv_scaled_error = lax.stop_gradient(
@@ -547,9 +547,11 @@ def _scale(_y0, _y1_candidate, _y_error):
547547
if self.dtmin is not None:
548548
if not self.force_dtmin:
549549
result = RESULTS.where(dt < self.dtmin, RESULTS.dt_min_reached, result)
550-
# flag next step to be kept if dtmin is reached
551-
# or if it was reached previously and dt is unchanged
552-
keep_next_step = (dt <= self.dtmin) | (keep_next_step & (factor == 1))
550+
# if we are already at dtmin and dt is unchanged (factor == 1),
551+
# reset dt to dtmin to avoid accumulating float precision errors
552+
dt = jnp.where(at_dtmin & (factor == 1), self.dtmin, dt)
553+
# this flags the next loop to accept step
554+
at_dtmin = dt <= self.dtmin
553555
dt = jnp.maximum(dt, self.dtmin)
554556

555557
next_t0 = jnp.where(keep_step, t1, t0)
@@ -559,7 +561,7 @@ def _scale(_y0, _y1_candidate, _y_error):
559561
prev_inv_scaled_error = jnp.where(
560562
keep_step, prev_inv_scaled_error, prev_prev_inv_scaled_error
561563
)
562-
controller_state = inv_scaled_error, prev_inv_scaled_error, keep_next_step
564+
controller_state = inv_scaled_error, prev_inv_scaled_error, at_dtmin
563565
# made_jump is handled by ClipStepSizeController, so we automatically set it to
564566
# False
565567
return keep_step, next_t0, next_t1, False, controller_state, result

0 commit comments

Comments
 (0)