Skip to content

Commit 66563fb

Browse files
committed
Improved fit checking
1 parent d2fbcc0 commit 66563fb

File tree

2 files changed

+77
-12
lines changed

2 files changed

+77
-12
lines changed

pydmd/bopdmd.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -214,8 +214,6 @@ def A(self):
214214
"Set parameter compute_A = True to compute A."
215215
)
216216
raise ValueError(msg)
217-
if self._A is None:
218-
raise ValueError("You need to call fit before")
219217
return self._A
220218

221219
@property
@@ -1126,7 +1124,7 @@ def init_alpha(self):
11261124
"fit() hasn't been called "
11271125
"and no initial value for alpha has been given."
11281126
)
1129-
raise RuntimeError(msg)
1127+
raise ValueError(msg)
11301128
return self._init_alpha
11311129

11321130
@init_alpha.setter
@@ -1148,7 +1146,7 @@ def proj_basis(self):
11481146
"fit() hasn't been called "
11491147
"and no projection basis has been given."
11501148
)
1151-
raise RuntimeError(msg)
1149+
raise ValueError(msg)
11521150
return self._proj_basis
11531151

11541152
@proj_basis.setter
@@ -1183,8 +1181,9 @@ def time(self):
11831181
:return: the vector that contains the original time points.
11841182
:rtype: numpy.ndarray
11851183
"""
1186-
if self._time is None:
1187-
raise RuntimeError("fit() hasn't been called.")
1184+
if not self.fitted:
1185+
raise ValueError("You need to call fit() before.")
1186+
11881187
return self._time
11891188

11901189
@property
@@ -1195,6 +1194,9 @@ def atilde(self):
11951194
:return: the reduced Koopman operator A.
11961195
:rtype: numpy.ndarray
11971196
"""
1197+
if not self.fitted:
1198+
raise ValueError("You need to call fit() before.")
1199+
11981200
return self.operator.as_numpy_array
11991201

12001202
@property
@@ -1205,6 +1207,9 @@ def A(self):
12051207
:return: the full Koopman operator A.
12061208
:rtype: numpy.ndarray
12071209
"""
1210+
if not self.fitted:
1211+
raise ValueError("You need to call fit() before.")
1212+
12081213
return self.operator.A
12091214

12101215
@property
@@ -1215,7 +1220,7 @@ def dynamics(self):
12151220
:return: matrix that contains all the time evolution, stored by row.
12161221
:rtype: numpy.ndarray
12171222
"""
1218-
t_omega = np.exp(np.outer(self.eigs, self._time))
1223+
t_omega = np.exp(np.outer(self.eigs, self.time))
12191224
return np.diag(self.amplitudes).dot(t_omega)
12201225

12211226
@property
@@ -1226,6 +1231,9 @@ def amplitudes_std(self):
12261231
:return: amplitudes standard deviation.
12271232
:rtype: numpy.ndarray
12281233
"""
1234+
if not self.fitted:
1235+
raise ValueError("You need to call fit() before.")
1236+
12291237
return self.operator.amplitudes_std
12301238

12311239
@property
@@ -1236,6 +1244,9 @@ def eigenvalues_std(self):
12361244
:return: eigenvalues standard deviation.
12371245
:rtype: numpy.ndarray
12381246
"""
1247+
if not self.fitted:
1248+
raise ValueError("You need to call fit() before.")
1249+
12391250
return self.operator.eigenvalues_std
12401251

12411252
@property
@@ -1246,6 +1257,9 @@ def modes_std(self):
12461257
:return: modes standard deviation.
12471258
:rtype: numpy.ndarray
12481259
"""
1260+
if not self.fitted:
1261+
raise ValueError("You need to call fit() before.")
1262+
12491263
return self.operator.modes_std
12501264

12511265
@property
@@ -1263,8 +1277,8 @@ def print_varpro_opts(self):
12631277
Prints a formatted information string that displays all chosen
12641278
variable projection parameter values.
12651279
"""
1266-
if self._Atilde is None:
1267-
raise ValueError("You need to call fit before")
1280+
if not self.fitted:
1281+
raise ValueError("You need to call fit() before.")
12681282

12691283
opt_names = [
12701284
"init_lambda",

tests/test_bopdmd.py

Lines changed: 54 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def test_A():
125125
bopdmd.fit(Z_uneven, t_uneven)
126126
assert compute_error(bopdmd.A, expected_A) < 1e-3
127127

128-
bopdmd = BOPDMD(svd_rank=2, compute_A=True, varpro_opts_dict={"tol": 0.05})
128+
bopdmd = BOPDMD(svd_rank=2, compute_A=True)
129129
bopdmd.fit(Z_noisy, t)
130130
assert compute_error(bopdmd.A, expected_A) < 1e-3
131131

@@ -444,6 +444,19 @@ def test_bag_int():
444444
bopdmd.fit(Z, t)
445445

446446

447+
def test_bag_getters():
448+
"""
449+
Test calls to the num_trials and trial_size parameters.
450+
"""
451+
bopdmd = BOPDMD(svd_rank=2, num_trials=0, trial_size=0.8)
452+
assert bopdmd.num_trials == 0
453+
assert bopdmd.trial_size == 0.8
454+
455+
bopdmd = BOPDMD(svd_rank=2, num_trials=100, trial_size=3200)
456+
assert bopdmd.num_trials == 100
457+
assert bopdmd.trial_size == 3200
458+
459+
447460
def test_bag_error():
448461
"""
449462
Test that errors are thrown if invalid bagging parameters are given.
@@ -498,7 +511,7 @@ def test_init_alpha_initializer():
498511
bopdmd = BOPDMD(svd_rank=2)
499512

500513
# Initial eigs shouldn't be defined yet.
501-
with raises(RuntimeError):
514+
with raises(ValueError):
502515
_ = bopdmd.init_alpha
503516

504517
# After fitting, the initial eigs used should be fairly accurate.
@@ -517,7 +530,7 @@ def test_proj_basis_initializer():
517530
bopdmd = BOPDMD(svd_rank=2)
518531

519532
# Projection basis shouldn't be defined yet.
520-
with raises(RuntimeError):
533+
with raises(ValueError):
521534
_ = bopdmd.proj_basis
522535

523536
# After fitting, the projection basis used should be accurate.
@@ -587,3 +600,41 @@ def test_std_shape():
587600
assert bopdmd.eigenvalues_std.shape == bopdmd.eigs.shape
588601
assert bopdmd.modes_std.shape == bopdmd.modes.shape
589602
assert bopdmd.amplitudes_std.shape == bopdmd.amplitudes.shape
603+
604+
605+
def test_std_nobags():
606+
"""
607+
Test that std attributes are simply None if no bags were used.
608+
"""
609+
bopdmd = BOPDMD(svd_rank=2, num_trials=0)
610+
bopdmd.fit(Z, t)
611+
612+
assert bopdmd.eigenvalues_std is None
613+
assert bopdmd.modes_std is None
614+
assert bopdmd.amplitudes_std is None
615+
616+
617+
def test_getter_errors():
618+
"""
619+
Test that error occurs if the following properties are called before fit:
620+
time, atilde, A, eigenvalues_std, modes_std, amplitudes_std.
621+
"""
622+
bopdmd = BOPDMD(compute_A=True)
623+
624+
with raises(ValueError):
625+
_ = bopdmd.time
626+
627+
with raises(ValueError):
628+
_ = bopdmd.atilde
629+
630+
with raises(ValueError):
631+
_ = bopdmd.A
632+
633+
with raises(ValueError):
634+
_ = bopdmd.eigenvalues_std
635+
636+
with raises(ValueError):
637+
_ = bopdmd.modes_std
638+
639+
with raises(ValueError):
640+
_ = bopdmd.amplitudes_std

0 commit comments

Comments
 (0)