Skip to content

Commit 91b192f

Browse files
committed
Added option to ignore failed bags
1 parent 57c91c9 commit 91b192f

File tree

1 file changed

+26
-23
lines changed

1 file changed

+26
-23
lines changed

pydmd/bopdmd.py

Lines changed: 26 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,10 @@ class BOPDMDOperator(DMDOperator):
7676
modes at every iteration of variable projection routine.
7777
:type mode_prox: function
7878
:param bag_maxfail: Number of consecutive non-converged trials of BOP-DMD
79-
at which to terminate the fit. Default is -1, no stopping condition.
80-
:type bag_maxfail: int
79+
at which to terminate the fit. Set this parameter to infinity for no
80+
stopping condition. Set to a non-positive value to simply use the
81+
results of the non-converged trials. This is the default behavior.
82+
:type bag_maxfail: int or float
8183
:param init_lambda: Initial value used for the regularization parameter in
8284
the Levenberg method. Default is 1.0.
8385
Note: Larger lambda values make the method more like gradient descent.
@@ -764,13 +766,22 @@ def compute_operator(self, H, t):
764766
self._A = A_0
765767
return b_0
766768

767-
# Otherwise, perform BOP-DMD.
769+
# Otherwise, perform BOP-DMD:
768770
verbose = self._varpro_opts[-1]
769771
if verbose:
770772
num_trial_print = 5
771773
msg = "\nDisplaying the results of the next {} trials...\n"
772774
print(msg.format(num_trial_print))
773775

776+
# We'll consider non-converged trials successful if the user didn't
777+
# request a positive amount of bagging trials at which to terminate.
778+
use_bad_bags = self._bag_maxfail <= 0.0
779+
if verbose:
780+
if use_bad_bags:
781+
print("Using all bag trial results...\n")
782+
else:
783+
print("Non-converged trial results will be removed...\n")
784+
774785
# Initialize storage for values needed for stat computations.
775786
w_sum = np.zeros(w_0.shape, dtype="complex")
776787
e_sum = np.zeros(e_0.shape, dtype="complex")
@@ -796,9 +807,8 @@ def compute_operator(self, H, t):
796807
verbose = num_trial_print > 0
797808
self._varpro_opts[-1] = verbose
798809

799-
# Incorporate results into the running average
800-
# ONLY IF the trial successfully converged.
801-
if converged:
810+
# Incorporate trial results into the running average if successful.
811+
if converged or use_bad_bags:
802812
sorted_inds = self._argsort_eigenvalues(e_i)
803813

804814
# Add to iterative sums.
@@ -815,13 +825,12 @@ def compute_operator(self, H, t):
815825
# and reset the consecutive fails counter.
816826
num_successful_trials += 1
817827
num_consecutive_fails = 0
828+
829+
# Trial did not converge, and we are throwing away bad bags.
818830
else:
819831
num_consecutive_fails += 1
820832

821-
if (
822-
not runtime_warning_given
823-
and num_consecutive_fails == 100
824-
):
833+
if not runtime_warning_given and num_consecutive_fails == 100:
825834
msg = (
826835
"100 trials without convergence. "
827836
"Consider loosening the tol requirements "
@@ -830,7 +839,7 @@ def compute_operator(self, H, t):
830839
print(msg)
831840
runtime_warning_given = True
832841

833-
elif num_consecutive_fails == self._bag_maxfail:
842+
if num_consecutive_fails >= self._bag_maxfail and not use_bad_bags:
834843
msg = (
835844
"Terminating the bagging routine due to "
836845
"{} many trials without convergence."
@@ -933,8 +942,10 @@ class BOPDMD(DMDBase):
933942
modes at every iteration of variable projection routine.
934943
:type mode_prox: function
935944
:param bag_maxfail: Number of consecutive non-converged trials of BOP-DMD
936-
at which to terminate the fit. Default is -1, no stopping condition.
937-
:type bag_maxfail: int
945+
at which to terminate the fit. Set this parameter to infinity for no
946+
stopping condition. Set to a non-positive value to simply use the
947+
results of the non-converged trials. This is the default behavior.
948+
:type bag_maxfail: int or float
938949
:param varpro_opts_dict: Dictionary containing the desired parameter values
939950
for variable projection. The following parameters may be specified:
940951
`init_lambda`, `maxlam`, `lamup`, `use_levmarq`, `maxiter`, `tol`,
@@ -957,7 +968,7 @@ def __init__(
957968
eig_sort="auto",
958969
eig_constraints=None,
959970
mode_prox=None,
960-
bag_maxfail=-1,
971+
bag_maxfail=0,
961972
varpro_opts_dict=None,
962973
):
963974
self._svd_rank = svd_rank
@@ -969,15 +980,6 @@ def __init__(
969980
self._trial_size = trial_size
970981
self._eig_sort = eig_sort
971982

972-
if not isinstance(bag_maxfail, int):
973-
msg = (
974-
"bag_maxfail must be an integer. "
975-
"Please use a non-positive integer if no warning "
976-
"or stopping condition is desired."
977-
)
978-
raise TypeError(msg)
979-
self._bag_maxfail = bag_maxfail
980-
981983
if varpro_opts_dict is None:
982984
self._varpro_opts_dict = {}
983985
elif not isinstance(varpro_opts_dict, dict):
@@ -994,6 +996,7 @@ def __init__(
994996
self._check_eig_constraints(eig_constraints)
995997
self._eig_constraints = eig_constraints
996998
self._mode_prox = mode_prox
999+
self._bag_maxfail = bag_maxfail
9971000

9981001
self._snapshots_holder = None
9991002
self._time = None

0 commit comments

Comments
 (0)