1+ from copy import deepcopy
12import xarray as xr
23import numpy as np
34import tensorly as tl
@@ -23,7 +24,6 @@ def xr_unfold(data: xr.Dataset, mode: str):
def round_to_n(x, n):
    """Round ``x`` to ``n`` significant figures.

    Zero is returned unchanged (log10 is undefined at 0). For any other
    value the rounding digit is derived from the order of magnitude of
    ``abs(x)``, so both very large and very small magnitudes work.
    """
    if x == 0:
        return x
    # floor(log10(|x|)) is the exponent of the leading digit; shifting by
    # (n - 1) keeps exactly n significant figures.
    return round(x, -int(np.floor(np.log10(abs(x)))) + (n - 1))


# Element-wise "round to 2 significant figures" for numpy arrays.
vround2 = np.vectorize(lambda x: round_to_n(x, 2))
2526
26-
2727class CoupledTensor ():
2828 def __init__ (self , data : xr .Dataset , rank ):
2929 if not isinstance (data , xr .Dataset ):
@@ -63,7 +63,7 @@ def initialize(self, method="svd", verbose=False):
6363 # wipe off old values
6464 self .x ["_Weight_" ][:] = np .ones_like (self .x ["_Weight_" ])
6565 for mmode in self .modes :
66- self .x ["_" + mmode ][:] = np .zeros_like (self .x ["_" + mmode ])
66+ self .x ["_" + mmode ][:] = np .ones_like (self .x ["_" + mmode ])
6767
6868 if method == "ones" :
6969 for mmode in self .modes :
@@ -194,20 +194,52 @@ def fit(self, tol=1e-7, maxiter=500, nonneg=False, progress=True, verbose=False)
194194 old_R2X = - np .inf
195195 tq = tqdm (range (maxiter ), disable = (not progress ))
196196
197+ gamma = 1.1
198+ gamma_bar = 1.03
199+ eta = 1.5
200+ beta_i = 0.05
201+ beta_i_bar = 1.0
202+
197203 # missing value handling
198204 uniqueInfo = {}
199205 for mmode in self .modes :
200206 uniqueInfo [mmode ] = np .unique (np .isfinite (self .unfold [mmode ].T ), axis = 1 , return_inverse = True )
201207
202208 for i in tq :
209+ jump = beta_i + 1.0
210+ x_prev = deepcopy (self .x )
211+
203212 # Solve on each mode
204213 for mmode in self .modes :
205214 self .x ["_" + mmode ][:] = mlstsq (self .khatri_rao (mmode ), self .unfold [mmode ].T , uniqueInfo [mmode ], nonneg = nonneg ).T
206215 self .normalize_factors ("norm" )
216+
217+ # line search
218+ x_ls = deepcopy (self .x )
219+
220+ for mmode in self .modes :
221+ x_ls ["_" + mmode ][:] = x_prev ["_" + mmode ] + jump * (self .x ["_" + mmode ] - x_prev ["_" + mmode ])
222+
223+ x_cur = deepcopy (self .x )
207224 current_R2X = self .R2X ()
225+
226+ self .x = x_ls
227+ ls_R2X = self .R2X ()
228+
229+ if ls_R2X > current_R2X :
230+ current_R2X = ls_R2X
231+ self .x = x_ls
232+
233+ beta_i = min (beta_i_bar , gamma * beta_i )
234+ beta_i_bar = max (1.0 , gamma_bar * beta_i_bar )
235+ else :
236+ beta_i_bar = beta_i
237+ beta_i = beta_i / eta
238+ self .x = x_cur
239+
208240 if verbose :
209241 print (f"R2Xs at { i } : { [self .R2X (dvar ) for dvar in self .dvars ]} " )
210- tq .set_postfix (refresh = False , R2X = current_R2X , delta = current_R2X - old_R2X )
242+ tq .set_postfix (refresh = False , R2X = current_R2X , delta = current_R2X - old_R2X , jump = jump )
211243 if np .abs (current_R2X - old_R2X ) < tol :
212244 break
213245 old_R2X = current_R2X
0 commit comments