Skip to content

Commit 8937bca

Browse files
committed
Added new parameter
1 parent 0a6942d commit 8937bca

File tree

1 file changed

+32
-14
lines changed

1 file changed

+32
-14
lines changed

pydmd/bopdmd.py

Lines changed: 32 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -79,13 +79,20 @@ class BOPDMDOperator(DMDOperator):
7979
routine after the modes have been projected back to the space of the
8080
full input data.
8181
:type mode_prox: function
82+
:param remove_bad_bags: Whether or not to exclude results from bagging
83+
trials that didn't converge according to the tolerance used for
84+
variable projection. Default is False, all trial results are kept
85+
regardless of convergence.
86+
:type remove_bad_bags: bool
8287
:param bag_warning: Number of consecutive non-converged trials of BOP-DMD
8388
at which to produce a warning message for the user. Default is 100.
84-
Use arguments less than zero for no warning condition.
89+
This parameter becomes active only when `remove_bad_bags=True`. Use
90+
negative arguments for no warning condition.
8591
:type bag_warning: int
8692
:param bag_maxfail: Number of consecutive non-converged trials of BOP-DMD
87-
at which to terminate the fit. Default is 100. Use arguments less than
88-
zero for no stopping condition.
93+
at which to terminate the fit. Default is 100. This parameter becomes
94+
active only when `remove_bad_bags=True`. Use negative arguments for no
95+
stopping condition.
8996
:type bag_maxfail: int
9097
:param init_lambda: Initial value used for the regularization parameter in
9198
the Levenberg method. Default is 1.0.
@@ -137,6 +144,7 @@ def __init__(
137144
eig_sort,
138145
eig_constraints,
139146
mode_prox,
147+
remove_bad_bags,
140148
bag_warning,
141149
bag_maxfail,
142150
init_lambda=1.0,
@@ -158,6 +166,7 @@ def __init__(
158166
self._eig_sort = eig_sort
159167
self._eig_constraints = eig_constraints
160168
self._mode_prox = mode_prox
169+
self._remove_bad_bags = remove_bad_bags
161170
self._bag_warning = bag_warning
162171
self._bag_maxfail = bag_maxfail
163172
self._varpro_opts = [
@@ -314,7 +323,7 @@ def _push_eigenvalues(self, eigenvalues):
314323
unassigned_inds.remove(ind_2)
315324
# Average their real and imaginary components together.
316325
a = 0.5 * (eig_1.real + eig_2.real)
317-
b = 0.5 * (abs(eig_1.imag) + abs(eig_2.imag))
326+
b = 0.5 * (np.abs(eig_1.imag) + np.abs(eig_2.imag))
318327
new_eigs[ind_1] = a + 1j * (b * np.sign(eig_1.imag))
319328
new_eigs[ind_2] = a + 1j * (b * np.sign(eig_2.imag))
320329

@@ -800,13 +809,12 @@ def compute_operator(self, H, t):
800809
print(msg.format(num_trial_print))
801810

802811
# We'll consider non-converged trials successful if the user didn't
803-
# request a positive amount of bagging trials at which to terminate.
804-
keep_bad_bags = self._bag_maxfail <= 0.0
812+
# request to remove bad bags.
805813
if verbose:
806-
if keep_bad_bags:
807-
print("Using all bag trial results...\n")
808-
else:
814+
if self._remove_bad_bags:
809815
print("Non-converged trial results will be removed...\n")
816+
else:
817+
print("Using all bag trial results...\n")
810818

811819
# Initialize storage for values needed for stat computations.
812820
w_sum = np.zeros(w_0.shape, dtype="complex")
@@ -834,7 +842,7 @@ def compute_operator(self, H, t):
834842
self._varpro_opts[-1] = verbose
835843

836844
# Incorporate trial results into the running average if successful.
837-
if converged or keep_bad_bags:
845+
if converged or not self._remove_bad_bags:
838846
sorted_inds = self._argsort_eigenvalues(e_i)
839847

840848
# Add to iterative sums.
@@ -868,7 +876,7 @@ def compute_operator(self, H, t):
868876
print(msg.format(num_consecutive_fails))
869877
runtime_warning_given = True
870878

871-
if not keep_bad_bags and num_consecutive_fails == self._bag_maxfail:
879+
if self._remove_bad_bags and num_consecutive_fails == self._bag_maxfail:
872880
msg = (
873881
"Terminating the bagging routine due to "
874882
"{} many trials without convergence."
@@ -975,13 +983,20 @@ class BOPDMD(DMDBase):
975983
routine after the modes have been projected back to the space of the
976984
full input data.
977985
:type mode_prox: function
986+
:param remove_bad_bags: Whether or not to exclude results from bagging
987+
trials that didn't converge according to the tolerance used for
988+
variable projection. Default is False, all trial results are kept
989+
regardless of convergence.
990+
:type remove_bad_bags: bool
978991
:param bag_warning: Number of consecutive non-converged trials of BOP-DMD
979992
at which to produce a warning message for the user. Default is 100.
980-
Use arguments less than zero for no warning condition.
993+
This parameter becomes active only when `remove_bad_bags=True`. Use
994+
negative arguments for no warning condition.
981995
:type bag_warning: int
982996
:param bag_maxfail: Number of consecutive non-converged trials of BOP-DMD
983-
at which to terminate the fit. Default is 100. Use arguments less than
984-
zero for no stopping condition.
997+
at which to terminate the fit. Default is 100. This parameter becomes
998+
active only when `remove_bad_bags=True`. Use negative arguments for no
999+
stopping condition.
9851000
:type bag_maxfail: int
9861001
:param varpro_opts_dict: Dictionary containing the desired parameter values
9871002
for variable projection. The following parameters may be specified:
@@ -1005,6 +1020,7 @@ def __init__(
10051020
eig_sort="auto",
10061021
eig_constraints=None,
10071022
mode_prox=None,
1023+
remove_bad_bags=False,
10081024
bag_warning=100,
10091025
bag_maxfail=100,
10101026
varpro_opts_dict=None,
@@ -1025,6 +1041,7 @@ def __init__(
10251041
"or stopping condition is desired."
10261042
)
10271043
raise TypeError(msg)
1044+
self._remove_bad_bags = remove_bad_bags
10281045
self._bag_warning = bag_warning
10291046
self._bag_maxfail = bag_maxfail
10301047

@@ -1397,6 +1414,7 @@ def fit(self, X, t):
13971414
self._eig_sort,
13981415
self._eig_constraints,
13991416
self._mode_prox,
1417+
self._remove_bad_bags,
14001418
self._bag_warning,
14011419
self._bag_maxfail,
14021420
**self._varpro_opts_dict,

0 commit comments

Comments
 (0)