Skip to content

Commit a373497

Browse files
committed
verified type changes
1 parent 5964b53 commit a373497

File tree

1 file changed

+12
-10
lines changed

1 file changed

+12
-10
lines changed

aerosandbox/optimization/opti.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1+
from pathlib import Path
12
from typing import Callable, Any, Literal, Sequence
23
import json
34
import casadi as cas
45
import aerosandbox.numpy as np
56
from aerosandbox.tools import inspect_tools
67
from sortedcontainers import SortedDict
78
import copy
9+
from aerosandbox.numpy.typing import ArrayLike
810

911

1012
class Opti(cas.Opti):
@@ -71,14 +73,14 @@ def __init__(
7173

7274
def variable(
7375
self,
74-
init_guess: float | np.ndarray | None = None,
75-
n_vars: int = None,
76-
scale: float = None,
76+
init_guess: float | int | np.ndarray | None = None,
77+
n_vars: int | None = None,
78+
scale: float | int | np.ndarray | None = None,
7779
freeze: bool = False,
7880
log_transform: bool = False,
7981
category: str = "Uncategorized",
80-
lower_bound: float = None,
81-
upper_bound: float = None,
82+
lower_bound: float | int | np.ndarray | None = None,
83+
upper_bound: float | int | np.ndarray | None = None,
8284
_stacklevel: int = 1,
8385
) -> cas.MX | float | np.ndarray:
8486
"""
@@ -369,7 +371,7 @@ def subject_to(
369371
self,
370372
constraint: cas.MX | bool | list, # TODO add scale
371373
_stacklevel: int = 1,
372-
) -> cas.MX | None | list[cas.MX]:
374+
) -> cas.MX | None | list[cas.MX | None]:
373375
"""
374376
Initialize a new equality or inequality constraint(s).
375377
@@ -1007,7 +1009,7 @@ def derivative_of(
10071009
method: str = "trapezoidal",
10081010
explicit: bool = False, # TODO implement explicit
10091011
_stacklevel: int = 1,
1010-
) -> cas.MX:
1012+
) -> cas.MX | float | np.ndarray:
10111013
"""
10121014
Returns a quantity that is either defined or constrained to be a derivative of an existing variable.
10131015
@@ -1138,9 +1140,9 @@ def derivative_of(
11381140

11391141
def constrain_derivative(
11401142
self,
1141-
derivative: cas.MX,
1142-
variable: cas.MX,
1143-
with_respect_to: np.ndarray | cas.MX,
1143+
derivative: ArrayLike,
1144+
variable: ArrayLike,
1145+
with_respect_to: ArrayLike,
11441146
method: str = "trapezoidal",
11451147
_stacklevel: int = 1,
11461148
) -> cas.MX | None | list[cas.MX]:

0 commit comments

Comments
 (0)