Faster jnp.trapezoid when dx is a scalar #34915

@jurasic-pf

Description

TLDR

There's a faster and equally accurate way to compute jnp.trapezoid when dx is a scalar.
Is there any reason not to specialize for that case?

Motivation

For uniform spacing (when x=None and a scalar dx is passed), the trapezoid rule boils down to a single sum plus endpoint handling, which can be implemented more efficiently than the general pairwise formula.
jnp.trapezoid is used in a lot of scientific code, and this change translates to measurable improvements in the overall execution time of simulations on both CPU and GPU (1.42s -> 1.18s in one of our examples); in micro-benchmarks of jnp.trapezoid alone it can be up to 5x faster.

Current:

0.5 * (dx_array * (y_arr[..., 1:] + y_arr[..., :-1])).sum(-1)

Suggested:

if dx_array.ndim == 0:
    dx_array * (y_arr.sum(-1) - 0.5 * (y_arr[..., 0] + y_arr[..., -1]))
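The two expressions are algebraically identical for scalar dx: every interior sample appears in exactly two trapezoids, so summing pairwise double-counts the interior and counts each endpoint once, which is the same as one full sum minus half of each endpoint. A minimal self-contained check (using NumPy rather than JAX, with made-up sample data, so it is a sketch of the equivalence rather than the proposed implementation):

```python
import numpy as np

rng = np.random.default_rng(0)
y = rng.standard_normal((4, 1001))  # batch of sampled functions, last axis is the integration axis
dx = 0.01                           # uniform scalar spacing

# General rule: sum the area of each trapezoid over every interval.
general = 0.5 * (dx * (y[..., 1:] + y[..., :-1])).sum(-1)

# Scalar-dx shortcut: one sum, then subtract half of each endpoint once.
fast = dx * (y.sum(-1) - 0.5 * (y[..., 0] + y[..., -1]))

assert np.allclose(general, fast)
```

The shortcut replaces an elementwise add plus multiply over the whole array with a single reduction, which is where the speedup comes from.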

Metadata

Labels

enhancement (New feature or request)
