Closed
Labels: enhancement (New feature or request)
Description
TLDR
There's a faster and equally accurate way to compute jnp.trapezoid when x is None and dx is a scalar.
Is there any reason not to specialize for that case?
Motivation
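For uniform x (x=None with a scalar dx), the composite trapezoidal rule reduces to a plain sum plus an endpoint correction, dx * (sum(y) - (y[0] + y[-1]) / 2). That needs a single reduction over y and no shifted slices y[..., 1:] and y[..., :-1], so it can be implemented more efficiently than the general trapezoid rule.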
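The identity behind the suggested form follows from splitting the general sum into two shifted sums. Writing $y_0, \dots, y_{n-1}$ for the samples along the integration axis and $\Delta x$ for the scalar dx, it holds for any $n \ge 1$:

```latex
\sum_{i=0}^{n-2} \frac{\Delta x}{2}\,(y_i + y_{i+1})
  = \frac{\Delta x}{2}\left(\sum_{i=0}^{n-2} y_i + \sum_{i=1}^{n-1} y_i\right)
  = \frac{\Delta x}{2}\left(2\sum_{i=0}^{n-1} y_i - y_0 - y_{n-1}\right)
  = \Delta x\left(\sum_{i=0}^{n-1} y_i - \frac{y_0 + y_{n-1}}{2}\right)
```

The two forms are equal in exact arithmetic; in floating point they accumulate rounding in a different order, so results may differ in the last few bits.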
jnp.trapezoid is used in lots of scientific code, and this change translates to measurable improvements in the overall execution time of simulations on both CPU and GPU (in one of our examples, 1.42 s -> 1.18 s). In microbenchmarks of jnp.trapezoid alone, it can be up to 5x faster.
Current:

```python
return 0.5 * (dx_array * (y_arr[..., 1:] + y_arr[..., :-1])).sum(-1)
```

Suggested (placed before the existing return, so a non-scalar dx still takes the general path):

```python
if dx_array.ndim == 0:  # scalar dx, i.e. uniformly spaced samples
  return dx_array * (y_arr.sum(-1) - 0.5 * (y_arr[..., 0] + y_arr[..., -1]))
```
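For trying the idea without patching JAX, here is a standalone sketch of the same specialization as a hypothetical helper, `trapezoid_uniform` (not a jax.numpy API). It assumes a scalar `dx`, promotes integer inputs to float32 (a simplification of this sketch, not necessarily JAX's own promotion rule), and returns zero when there are fewer than two samples, which is what the general rule yields in that case.

```python
import jax
import jax.numpy as jnp


def trapezoid_uniform(y, dx=1.0, axis=-1):
    """Trapezoidal integral of y sampled with constant spacing dx.

    Hypothetical helper for illustration only; it is not part of jax.numpy.
    Assumes dx is a Python scalar or a 0-d array.
    """
    y_arr = jnp.asarray(y)
    if not jnp.issubdtype(y_arr.dtype, jnp.inexact):
        # Assumption of this sketch: integer inputs are promoted to float32.
        y_arr = y_arr.astype(jnp.float32)
    dx_array = jnp.asarray(dx, dtype=y_arr.dtype)
    if dx_array.ndim != 0:
        raise ValueError("trapezoid_uniform expects a scalar dx")
    # Move the integration axis last so the slicing below is uniform.
    y_arr = jnp.moveaxis(y_arr, axis, -1)
    if y_arr.shape[-1] < 2:
        # Fewer than two samples: the integral is zero, as with the general rule.
        return jnp.zeros(y_arr.shape[:-1], dtype=y_arr.dtype)
    # One reduction over y, then subtract half of each endpoint.
    return dx_array * (y_arr.sum(-1) - 0.5 * (y_arr[..., 0] + y_arr[..., -1]))
```

Under `jax.jit`, `axis` has to be static, e.g. `jax.jit(trapezoid_uniform, static_argnames="axis")`; `dx` can stay traced because only its rank is inspected, and the rank of a traced value is known at trace time.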