Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 54 additions & 70 deletions diffrax/_integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,12 +582,7 @@ def _outer_cond_fn(cond_fn_i, old_event_value_i):
jtu.tree_structure((0, 0)),
event_values__mask,
)
had_event = False
event_mask_leaves = []
for event_mask_i in jtu.tree_leaves(event_mask):
event_mask_leaves.append(event_mask_i & jnp.invert(had_event))
had_event = event_mask_i | had_event
event_mask = jtu.tree_unflatten(event_structure, event_mask_leaves)
had_event = jnp.any(jnp.stack(jtu.tree_leaves(event_mask), axis=0))
result = RESULTS.where(
had_event,
RESULTS.event_occurred,
Expand Down Expand Up @@ -653,6 +648,7 @@ def body_fun(state):
if event is None or event.root_finder is None:
tfinal = final_state.tprev
yfinal = final_state.y
first_event_mask = final_state.event_mask
else:
# If we're on this branch, it means that an event may have triggered, and now we
# may need to do a root find, in order to locate the event time.
Expand All @@ -663,19 +659,19 @@ def body_fun(state):
event_happened = jnp.max(float_mask) > 0.0

def _root_find():
_interpolator = solver.interpolation_cls(
interp = solver.interpolation_cls(
t0=final_state.event_tprev,
t1=final_state.event_tnext,
**final_state.event_dense_info,
)

def _to_root_find(_t, _):
_distance_from_t_end = final_state.event_tnext - _t
flat_fns, fn_tree = jtu.tree_flatten(event.cond_fn)
flat_masks, _ = jtu.tree_flatten(event_mask)

def _call_real(_event_mask_i, _cond_fn_i):
def _call_real_impl():
# First evaluate the triggered event.
_y = _interpolator.evaluate(_t)
def _call_real(_event_mask_i, _cond_fn_i):
def _find():
def f(_t, _):
_y = interp.evaluate(_t)
_value = _cond_fn_i(
t=_t,
y=_y,
Expand All @@ -689,67 +685,56 @@ def _call_real_impl():
stepsize_controller=stepsize_controller,
max_steps=max_steps,
)
# Second: if this is a boolean event, then normalise to a
# floating point number by having the root occur at the end of
# the last step, i.e. `event_tnext`.
_value_dtype = jnp.result_type(_value)
if jnp.issubdtype(_value_dtype, jnp.bool_):
_value = _distance_from_t_end
else:
assert jnp.issubdtype(_value_dtype, jnp.floating)
return _value

# Only the triggered event actually gets to the decide what time the
# event occurs; everything else is zeroed out to automatically give
# a root.
#
# We allow this `lax.cond` to be inefficiently transformed into a
# `lax.select` when `_event_mask_i` is batched. There isn't any way
# to avoid this, I think.
_value = lax.cond(_event_mask_i, _call_real_impl, lambda: 0.0)

# Third: if no events triggered at all, then have the root occur at
# the end of the last step (which will be the `t1` of the overall
# solve).
_value = jnp.where(event_happened, _value, _distance_from_t_end)
return _value

return jtu.tree_map(
_call_real,
event_mask,
event.cond_fn,
)
return (
(final_state.event_tnext - _t)
if jnp.issubdtype(_value.dtype, jnp.bool_)
else _value
)

opts = {
"lower": final_state.event_tprev,
"upper": final_state.event_tnext,
}
res = optx.root_find(
f,
event.root_finder,
y0=final_state.event_tnext,
options=opts,
throw=False,
)
return res.value

_options = {
"lower": final_state.event_tprev,
"upper": final_state.event_tnext,
}
_event_root_find = optx.root_find(
_to_root_find,
event.root_finder,
y0=final_state.event_tnext,
options=_options,
throw=False,
return lax.cond(_event_mask_i, _find, lambda: jnp.inf)

candidates = jnp.stack(
[_call_real(m, fn) for m, fn in zip(flat_masks, flat_fns)]
)
_tfinal = _event_root_find.value
# TODO: we might need to change the way we evaluate `_yfinal` in order to
# get more accurate derivatives?
_yfinal = _interpolator.evaluate(_tfinal)
_result = RESULTS.where(
_event_root_find.result == optx.RESULTS.successful,

t_event = jnp.min(candidates)
t_event = jnp.where(jnp.isfinite(t_event), t_event, final_state.event_tnext)

y_event = interp.evaluate(t_event)

first_idx = jnp.argmin(candidates)
first_mask_arr = jnp.arange(candidates.shape[0]) == first_idx
first_event_mask = jtu.tree_unflatten(fn_tree, list(first_mask_arr))

new_result = RESULTS.where(
jnp.any(jnp.stack(flat_masks)),
RESULTS.event_occurred,
result,
RESULTS.promote(_event_root_find.result),
)
return _tfinal, _yfinal, _result

return t_event, y_event, new_result, first_event_mask

# Fastpath: if no event happened anywhere at all, then skip the root-find
# altogether.
# Note that `_root_find` might still be called on batch elements which did not
# have an event, so we still need to access `event_happened` inside of it.
tfinal, yfinal, result = lax.cond(
tfinal, yfinal, result, first_event_mask = lax.cond(
eqxi.unvmap_any(event_happened),
_root_find,
lambda: (final_state.tprev, final_state.y, result),
lambda: (final_state.tprev, final_state.y, result, final_state.event_mask),
)

# We delete all the saved values after the event time.
Expand Down Expand Up @@ -824,9 +809,13 @@ def _save_t1(subsaveat, save_state):
final_state = eqx.tree_at(
lambda s: s.save_state, final_state, save_state, is_leaf=_is_none
)

final_state = _handle_static(final_state)
result = RESULTS.where(cond_fun(final_state), RESULTS.max_steps_reached, result)
aux_stats = dict() # TODO: put something in here?

# override event mask with first found event
final_state = eqx.tree_at(lambda s: s.event_mask, final_state, first_event_mask)
return eqx.tree_at(lambda s: s.result, final_state, result), aux_stats


Expand Down Expand Up @@ -1339,18 +1328,13 @@ def _outer_cond_fn(cond_fn_i):
jtu.tree_structure((0, 0)),
event_values__mask,
)
had_event = False
event_mask_leaves = []
for event_mask_i in jtu.tree_leaves(event_mask):
event_mask_leaves.append(event_mask_i & jnp.invert(had_event))
had_event = event_mask_i | had_event
event_mask = jtu.tree_unflatten(event_structure, event_mask_leaves)
had_event = jnp.any(jnp.stack(jtu.tree_leaves(event_mask), axis=0))
result = RESULTS.where(
had_event,
RESULTS.event_occurred,
result,
)
del had_event, event_structure, event_mask_leaves, event_values__mask
del had_event, event_structure, event_values__mask

# Initialise state
init_state = State(
Expand Down
Loading