Skip to content

Commit 839732b

Browse files
committed
Added mode prox + bag fail update
1 parent 9fa1e5e commit 839732b

File tree

1 file changed

+45
-38
lines changed

1 file changed

+45
-38
lines changed

pydmd/bopdmd.py

Lines changed: 45 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -72,13 +72,14 @@ class BOPDMDOperator(DMDOperator):
7272
function that will be applied to the computed eigenvalues at each step
7373
of the variable projection routine.
7474
:type eig_constraints: set(str) or function
75-
:param bag_warning: Number of consecutive non-converged trials of BOP-DMD
76-
at which to produce a warning message for the user. Default is 100.
77-
Use arguments less than or equal to zero for no warning condition.
78-
:type bag_warning: int
75+
:param mode_prox: Optional proximal operator function to apply to the DMD
76+
modes at every iteration of variable projection routine.
77+
:type mode_prox: function
7978
:param bag_maxfail: Number of consecutive non-converged trials of BOP-DMD
80-
at which to terminate the fit. Default is -1, no stopping condition.
81-
:type bag_maxfail: int
79+
at which to terminate the fit. Set this parameter to infinity for no
80+
stopping condition. Set to a non-positive value to simply use the
81+
results of the non-converged trials. This is the default behavior.
82+
:type bag_maxfail: int or float
8283
:param init_lambda: Initial value used for the regularization parameter in
8384
the Levenberg method. Default is 1.0.
8485
Note: Larger lambda values make the method more like gradient descent.
@@ -128,7 +129,7 @@ def __init__(
128129
trial_size,
129130
eig_sort,
130131
eig_constraints,
131-
bag_warning,
132+
mode_prox,
132133
bag_maxfail,
133134
init_lambda=1.0,
134135
maxlam=52,
@@ -148,7 +149,7 @@ def __init__(
148149
self._trial_size = trial_size
149150
self._eig_sort = eig_sort
150151
self._eig_constraints = eig_constraints
151-
self._bag_warning = bag_warning
152+
self._mode_prox = mode_prox
152153
self._bag_maxfail = bag_maxfail
153154
self._varpro_opts = [
154155
init_lambda,
@@ -503,9 +504,15 @@ def compute_residual(alpha):
503504
which is used as an error indicator.
504505
"""
505506
Phi_matrix = Phi(alpha, t)
507+
508+
# Update B matrix.
506509
B = np.linalg.lstsq(Phi_matrix, H, rcond=None)[0]
510+
if self._mode_prox is not None:
511+
B = self._mode_prox(B)
512+
507513
residual = H - Phi_matrix.dot(B)
508514
error = 0.5 * np.linalg.norm(residual) ** 2
515+
509516
return B, residual, error
510517

511518
# Define M, IS, and IA.
@@ -759,13 +766,22 @@ def compute_operator(self, H, t):
759766
self._A = A_0
760767
return b_0
761768

762-
# Otherwise, perform BOP-DMD.
769+
# Otherwise, perform BOP-DMD:
763770
verbose = self._varpro_opts[-1]
764771
if verbose:
765772
num_trial_print = 5
766773
msg = "\nDisplaying the results of the next {} trials...\n"
767774
print(msg.format(num_trial_print))
768775

776+
# We'll consider non-converged trials successful if the user didn't
777+
# request a positive amount of bagging trials at which to terminate.
778+
use_bad_bags = self._bag_maxfail <= 0.0
779+
if verbose:
780+
if use_bad_bags:
781+
print("Using all bag trial results...\n")
782+
else:
783+
print("Non-converged trial results will be removed...\n")
784+
769785
# Initialize storage for values needed for stat computations.
770786
w_sum = np.zeros(w_0.shape, dtype="complex")
771787
e_sum = np.zeros(e_0.shape, dtype="complex")
@@ -791,9 +807,8 @@ def compute_operator(self, H, t):
791807
verbose = num_trial_print > 0
792808
self._varpro_opts[-1] = verbose
793809

794-
# Incorporate results into the running average
795-
# ONLY IF the trial successfully converged.
796-
if converged:
810+
# Incorporate trial results into the running average if successful.
811+
if converged or use_bad_bags:
797812
sorted_inds = self._argsort_eigenvalues(e_i)
798813

799814
# Add to iterative sums.
@@ -810,22 +825,21 @@ def compute_operator(self, H, t):
810825
# and reset the consecutive fails counter.
811826
num_successful_trials += 1
812827
num_consecutive_fails = 0
828+
829+
# Trial did not converge, and we are throwing away bad bags.
813830
else:
814831
num_consecutive_fails += 1
815832

816-
if (
817-
not runtime_warning_given
818-
and num_consecutive_fails == self._bag_warning
819-
):
833+
if not runtime_warning_given and num_consecutive_fails == 100:
820834
msg = (
821-
"{} many trials without convergence. "
835+
"100 trials without convergence. "
822836
"Consider loosening the tol requirements "
823837
"of the variable projection routine."
824838
)
825-
print(msg.format(num_consecutive_fails))
839+
print(msg)
826840
runtime_warning_given = True
827841

828-
elif num_consecutive_fails == self._bag_maxfail:
842+
if num_consecutive_fails >= self._bag_maxfail and not use_bad_bags:
829843
msg = (
830844
"Terminating the bagging routine due to "
831845
"{} many trials without convergence."
@@ -924,13 +938,14 @@ class BOPDMD(DMDBase):
924938
function that will be applied to the computed eigenvalues at each step
925939
of the variable projection routine.
926940
:type eig_constraints: set(str) or function
927-
:param bag_warning: Number of consecutive non-converged trials of BOP-DMD
928-
at which to produce a warning message for the user. Default is 100.
929-
Use arguments less than or equal to zero for no warning condition.
930-
:type bag_warning: int
941+
:param mode_prox: Optional proximal operator function to apply to the DMD
942+
modes at every iteration of variable projection routine.
943+
:type mode_prox: function
931944
:param bag_maxfail: Number of consecutive non-converged trials of BOP-DMD
932-
at which to terminate the fit. Default is -1, no stopping condition.
933-
:type bag_maxfail: int
945+
at which to terminate the fit. Set this parameter to infinity for no
946+
stopping condition. Set to a non-positive value to simply use the
947+
results of the non-converged trials. This is the default behavior.
948+
:type bag_maxfail: int or float
934949
:param varpro_opts_dict: Dictionary containing the desired parameter values
935950
for variable projection. The following parameters may be specified:
936951
`init_lambda`, `maxlam`, `lamup`, `use_levmarq`, `maxiter`, `tol`,
@@ -952,8 +967,8 @@ def __init__(
952967
trial_size=0.6,
953968
eig_sort="auto",
954969
eig_constraints=None,
955-
bag_warning=100,
956-
bag_maxfail=-1,
970+
mode_prox=None,
971+
bag_maxfail=0,
957972
varpro_opts_dict=None,
958973
):
959974
self._svd_rank = svd_rank
@@ -965,16 +980,6 @@ def __init__(
965980
self._trial_size = trial_size
966981
self._eig_sort = eig_sort
967982

968-
if not isinstance(bag_warning, int) or not isinstance(bag_maxfail, int):
969-
msg = (
970-
"bag_warning and bag_maxfail must be integers. "
971-
"Please use a non-positive integer if no warning "
972-
"or stopping condition is desired."
973-
)
974-
raise TypeError(msg)
975-
self._bag_warning = bag_warning
976-
self._bag_maxfail = bag_maxfail
977-
978983
if varpro_opts_dict is None:
979984
self._varpro_opts_dict = {}
980985
elif not isinstance(varpro_opts_dict, dict):
@@ -990,6 +995,8 @@ def __init__(
990995
raise TypeError("eig_constraints must be a set or a function.")
991996
self._check_eig_constraints(eig_constraints)
992997
self._eig_constraints = eig_constraints
998+
self._mode_prox = mode_prox
999+
self._bag_maxfail = bag_maxfail
9931000

9941001
self._snapshots_holder = None
9951002
self._time = None
@@ -1342,7 +1349,7 @@ def fit(self, X, t):
13421349
self._trial_size,
13431350
self._eig_sort,
13441351
self._eig_constraints,
1345-
self._bag_warning,
1352+
self._mode_prox,
13461353
self._bag_maxfail,
13471354
**self._varpro_opts_dict,
13481355
)

0 commit comments

Comments
 (0)