Skip to content

Commit 0fe12ba

Browse files
committed
Validated type-hint changes
1 parent 552a886 commit 0fe12ba

File tree

9 files changed

+38
-12
lines changed

9 files changed

+38
-12
lines changed

aerosandbox/numpy/arithmetic_dyadic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def _make_casadi_types_broadcastable(
2929
tuple[Array, Array]
3030
Both arrays tiled to have the same (broadcast) shape.
3131
"""
32-
def shape_2D(object: float | int | Iterable | _onp.ndarray) -> tuple:
32+
def shape_2D(object: Vectorizable) -> tuple:
3333
shape = _onp.shape(object)
3434
if len(shape) == 0:
3535
return (1, 1)

aerosandbox/numpy/arithmetic_monadic.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import numpy as _onp
77
import casadi as _cas
88
from aerosandbox.numpy.determine_type import is_casadi_type
9+
from aerosandbox.numpy.array import asarray
910
from aerosandbox.numpy.typing import Array, ArrayLike, Scalar
1011

1112

@@ -80,6 +81,7 @@ def mean(x: ArrayLike, axis: int | None = None) -> Scalar | Array:
8081
return _onp.mean(x, axis=axis)
8182

8283
else:
84+
x = asarray(x) # Ensure x is Array for .shape access
8385
if axis == 0:
8486
return sum(x, axis=0) / x.shape[0]
8587
elif axis == 1:

aerosandbox/numpy/array.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,7 @@ def length(array: ArrayLike | Scalar) -> int:
320320
return 1
321321

322322
else:
323+
array = asarray(array) # Ensure array is Array for .shape access
323324
if array.shape[0] != 1:
324325
return array.shape[0]
325326
else:
@@ -359,6 +360,7 @@ def diag(v: ArrayLike, k: int = 0) -> Array:
359360
return _onp.diag(v, k=k)
360361

361362
else:
363+
v = asarray(v) # Ensure v is Array for .shape access
362364
if 1 in v.shape: # If v is a 1D array, construct a diagonal matrix
363365
if v.shape[0] == 1:
364366
v = v.T
@@ -421,6 +423,7 @@ def roll(a: ArrayLike, shift: int | tuple[int, ...], axis: int | tuple[int, ...]
421423
if not is_casadi_type(a, recursive=False):
422424
return _onp.roll(a, shift, axis=axis)
423425
else:
426+
a = asarray(a) # Ensure a is Array for .shape access
424427
if axis is None:
425428
a_flat = reshape(a, -1)
426429
result = roll(a_flat, shift, axis=0)
@@ -480,6 +483,7 @@ def max(a: ArrayLike, axis: int | None = None) -> Scalar | Array:
480483
)
481484

482485
else:
486+
a = asarray(a) # Ensure a is Array for .shape access
483487
if axis is None:
484488
return _cas.mmax(a)
485489

@@ -532,6 +536,7 @@ def min(a: ArrayLike, axis: int | None = None) -> Scalar | Array:
532536
)
533537

534538
else:
539+
a = asarray(a) # Ensure a is Array for .shape access
535540
if axis is None:
536541
return _cas.mmin(a)
537542

aerosandbox/numpy/calculus.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,13 +290,15 @@ def trapz(x: ArrayLike, modify_endpoints: bool = False) -> Array:
290290
integrate_discrete_intervals : Preferred alternative.
291291
"""
292292
import warnings
293+
from aerosandbox.numpy.array import asarray
293294

294295
warnings.warn(
295296
"trapz() will eventually be deprecated, since NumPy plans to remove it in the upcoming NumPy 2.0 release (2024). \n"
296297
'For discrete intervals, use asb.numpy.integrate_discrete_intervals(f, method="trapz") instead.',
297298
PendingDeprecationWarning,
298299
)
299300

301+
x = asarray(x) # Convert to Array for subscripting
300302
integral = (x[1:] + x[:-1]) / 2
301303
if modify_endpoints:
302304
integral[0] = integral[0] + x[0] * 0.5

aerosandbox/numpy/conditionals.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,15 @@
1010

1111

1212
def where(
13-
condition: Array,
13+
condition: Array | bool | _onp.bool_,
1414
value_if_true: Vectorizable,
1515
value_if_false: Vectorizable,
1616
) -> Array:
1717
"""Return elements chosen from x or y depending on condition.
1818
1919
Parameters
2020
----------
21-
condition : Array
21+
condition : Array | bool
2222
Where True, yield ``value_if_true``, otherwise yield ``value_if_false``.
2323
value_if_true : Vectorizable
2424
Values from which to choose where ``condition`` is True.

aerosandbox/numpy/integrate.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,7 @@ def solve_ivp(
251251
* `t_variable` is a CasADi variable (cas.MX)
252252
* `y_variables` is a CasADi variable (cas.MX), possibly a vector of variables
253253
"""
254+
assert y_variables is not None # Type narrowing for type checker
254255

255256
t0 = t_span[0]
256257
tf = t_span[1]

aerosandbox/numpy/integrate_discrete.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
from typing import Literal
22
import casadi as _cas
33
import numpy as _onp
4-
from aerosandbox.numpy.array import length, concatenate
4+
from aerosandbox.numpy.array import length, concatenate, asarray
5+
from aerosandbox.numpy.typing import ArrayLike
56

67

78
def integrate_discrete_intervals(
8-
f: _onp.ndarray | _cas.MX,
9-
x: _onp.ndarray | _cas.MX | None = None,
9+
f: ArrayLike,
10+
x: ArrayLike | None = None,
1011
multiply_by_dx: bool = True,
1112
method: Literal[
1213
"forward_euler",
@@ -50,10 +51,15 @@ def integrate_discrete_intervals(
5051
- "periodic"
5152
5253
"""
54+
# Convert inputs to arrays for subscripting
55+
f = asarray(f)
56+
5357
# Determine if an x-array was specified, and calculate dx.
5458
x_is_specified = x is not None
5559
if not x_is_specified:
5660
x = _onp.arange(length(f))
61+
else:
62+
x = asarray(x)
5763

5864
dx = x[1:] - x[:-1]
5965

@@ -266,8 +272,8 @@ def integrate_discrete_intervals(
266272

267273

268274
def integrate_discrete_squared_curvature(
269-
f: _onp.ndarray | _cas.MX,
270-
x: _onp.ndarray | _cas.MX | None = None,
275+
f: ArrayLike,
276+
x: ArrayLike | None = None,
271277
method: Literal[
272278
"cubic", "simpson", "hybrid_simpson_cubic"
273279
] = "hybrid_simpson_cubic",
@@ -325,10 +331,15 @@ def integrate_discrete_squared_curvature(
325331
well as a regularization strategy. (It is still convergent to the true value in the high-sample-rate limit.)
326332
327333
"""
334+
# Convert inputs to arrays for subscripting
335+
f = asarray(f)
336+
328337
# Determine if an x-array was specified, and calculate dx.
329338
x_is_specified = x is not None
330339
if not x_is_specified:
331340
x = _onp.arange(length(f))
341+
else:
342+
x = asarray(x)
332343

333344
if method in ["cubic", "cubic_spline"]:
334345
x1 = x[:-3]

aerosandbox/numpy/linalg.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import casadi as _cas
88
from aerosandbox.numpy.arithmetic_monadic import sum, abs
99
from aerosandbox.numpy.determine_type import is_casadi_type
10+
from aerosandbox.numpy.array import asarray
1011
from aerosandbox.numpy.typing import ArrayLike, Array, Scalar, VectorLike
1112
from numpy.linalg import *
1213

@@ -71,6 +72,7 @@ def outer(x: VectorLike, y: VectorLike, manual: bool = False) -> Array:
7172
return _onp.outer(x, y)
7273

7374
else:
75+
y = asarray(y) # Ensure y is Array for .shape access
7476
if len(y.shape) == 1: # Force y to be transposable if it's not.
7577
y = _onp.expand_dims(y, 1)
7678
return x @ y.T
@@ -214,6 +216,7 @@ def norm(
214216
return _onp.linalg.norm(x, ord=ord, axis=axis, keepdims=keepdims)
215217

216218
else:
219+
x = asarray(x) # Ensure x is Array for .shape access
217220
# Figure out which axis, if any, to take a vector norm about.
218221
if axis is not None:
219222
if not (axis == 0 or axis == 1 or axis == -1):

aerosandbox/numpy/typing.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -122,16 +122,18 @@
122122
# INPUT TYPES - Permissive types for function parameters (hybrid)
123123
# =============================================================================
124124

125-
VectorLike = Sequence[float] | Sequence[int] | _onp.ndarray | CasADiType
126-
"""Permissive input for vector parameters: sequence, ndarray, or CasADi.
125+
VectorLike = int | float | _onp.integer | _onp.floating | Sequence[int | float] | _onp.ndarray | CasADiType
126+
"""Permissive input for vector parameters: scalar, sequence, ndarray, or CasADi.
127127
128128
Use for function INPUTS that will be converted to Vector internally.
129+
Includes scalars so that `Scalar | Array` return values can be passed directly.
129130
"""
130131

131-
ArrayLike = Sequence[float] | Sequence[int] | _onp.ndarray | CasADiType
132-
"""Permissive input for array parameters: sequence, ndarray, or CasADi.
132+
ArrayLike = int | float | _onp.integer | _onp.floating | Sequence[int | float] | _onp.ndarray | CasADiType
133+
"""Permissive input for array parameters: scalar, sequence, ndarray, or CasADi.
133134
134135
Use for function INPUTS that will be converted to Array internally.
136+
Includes scalars so that `Scalar | Array` return values can be passed directly.
135137
"""
136138

137139
PointLike = Sequence[float] | _onp.ndarray | CasADiType

0 commit comments

Comments
 (0)