|
1 | 1 | from typing import Literal |
2 | 2 | import casadi as _cas |
3 | 3 | import numpy as _onp |
4 | | -from aerosandbox.numpy.array import length, concatenate |
| 4 | +from aerosandbox.numpy.array import length, concatenate, asarray |
| 5 | +from aerosandbox.numpy.typing import ArrayLike |
5 | 6 |
|
6 | 7 |
|
7 | 8 | def integrate_discrete_intervals( |
8 | | - f: _onp.ndarray | _cas.MX, |
9 | | - x: _onp.ndarray | _cas.MX | None = None, |
| 9 | + f: ArrayLike, |
| 10 | + x: ArrayLike | None = None, |
10 | 11 | multiply_by_dx: bool = True, |
11 | 12 | method: Literal[ |
12 | 13 | "forward_euler", |
@@ -50,10 +51,15 @@ def integrate_discrete_intervals( |
50 | 51 | - "periodic" |
51 | 52 |
|
52 | 53 | """ |
| 54 | + # Convert inputs to arrays for subscripting |
| 55 | + f = asarray(f) |
| 56 | + |
53 | 57 | # Determine if an x-array was specified, and calculate dx. |
54 | 58 | x_is_specified = x is not None |
55 | 59 | if not x_is_specified: |
56 | 60 | x = _onp.arange(length(f)) |
| 61 | + else: |
| 62 | + x = asarray(x) |
57 | 63 |
|
58 | 64 | dx = x[1:] - x[:-1] |
59 | 65 |
|
@@ -266,8 +272,8 @@ def integrate_discrete_intervals( |
266 | 272 |
|
267 | 273 |
|
268 | 274 | def integrate_discrete_squared_curvature( |
269 | | - f: _onp.ndarray | _cas.MX, |
270 | | - x: _onp.ndarray | _cas.MX | None = None, |
| 275 | + f: ArrayLike, |
| 276 | + x: ArrayLike | None = None, |
271 | 277 | method: Literal[ |
272 | 278 | "cubic", "simpson", "hybrid_simpson_cubic" |
273 | 279 | ] = "hybrid_simpson_cubic", |
@@ -325,10 +331,15 @@ def integrate_discrete_squared_curvature( |
325 | 331 | well as a regularization strategy. (It is still convergent to the true value in the high-sample-rate limit.) |
326 | 332 |
|
327 | 333 | """ |
| 334 | + # Convert inputs to arrays for subscripting |
| 335 | + f = asarray(f) |
| 336 | + |
328 | 337 | # Determine if an x-array was specified, and calculate dx. |
329 | 338 | x_is_specified = x is not None |
330 | 339 | if not x_is_specified: |
331 | 340 | x = _onp.arange(length(f)) |
| 341 | + else: |
| 342 | + x = asarray(x) |
332 | 343 |
|
333 | 344 | if method in ["cubic", "cubic_spline"]: |
334 | 345 | x1 = x[:-3] |
|
0 commit comments