Skip to content

Commit c0de217

Browse files
committed
improve error handling in bopdmd fit_econ
1 parent fbe8032 commit c0de217

File tree

1 file changed

+18
-11
lines changed

1 file changed

+18
-11
lines changed

pydmd/bopdmd.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1489,9 +1489,19 @@ def fit_econ(self, s, V, t):
14891489
self._reset()
14901490
self._time = np.array(t).squeeze()
14911491

1492-
if self._proj_basis is None:
1493-
msg = "proj_basis must be provided when using fit_econ."
1492+
if self._proj_basis is None or not self._use_proj:
1493+
msg = """
1494+
proj_basis must be provided when using fit_econ,
1495+
and use_proj must be set to True.
1496+
"""
14941497
raise ValueError(msg)
1498+
elif (
1499+
not isinstance(self._proj_basis, np.ndarray)
1500+
or self._proj_basis.ndim != 2
1501+
or self._proj_basis.shape[1] != self._svd_rank
1502+
):
1503+
msg = "proj_basis must be a 2D np.ndarray with {} columns."
1504+
raise ValueError(msg.format(self._svd_rank))
14951505

14961506
# Check that input time vector is one-dimensional.
14971507
if self._time.ndim > 1:
@@ -1502,27 +1512,24 @@ def fit_econ(self, s, V, t):
15021512
if (
15031513
not isinstance(s, np.ndarray)
15041514
or s.ndim != 1
1505-
or len(s) != self._proj_basis.shape[1]
1515+
or len(s) != self._svd_rank
15061516
):
15071517
msg = """
1508-
s must be a 1D numpy.ndarray with length equal to the number
1509-
of columns in the projection basis (U).
1518+
s must be a 1D numpy.ndarray of length {}.
15101519
"""
1511-
raise ValueError(msg)
1520+
raise ValueError(msg.format(self._svd_rank))
15121521

15131522
# Check that V is a 2D numpy.ndarray.
15141523
if (
15151524
not isinstance(V, np.ndarray)
15161525
or V.ndim != 2
1517-
or V.shape[0] != self._proj_basis.shape[1]
1526+
or V.shape[0] != self._svd_rank
15181527
or V.shape[1] != len(self._time)
15191528
):
15201529
msg = """
1521-
V must be a 2D numpy.ndarray with the same number of rows
1522-
as the projection basis (U) has columns and the same number of columns
1523-
as the length of the time vector t.
1530+
V must be a 2D numpy.ndarray with shape ({}, {}).
15241531
"""
1525-
raise ValueError(msg)
1532+
raise ValueError(msg.format(self._svd_rank, len(self._time)))
15261533

15271534
# Set/check the initial guess for the continuous-time DMD eigenvalues.
15281535
if self._init_alpha is None:

0 commit comments

Comments
 (0)