Skip to content

Commit 54d6ce1

Browse files
authored
Add Nesterov acceleration (#69)
2 parents dd39e2d + b1a1ed1 commit 54d6ce1

File tree

1 file changed

+35
-3
lines changed

1 file changed

+35
-3
lines changed

tensorpack/coupled.py

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from copy import deepcopy
12
import xarray as xr
23
import numpy as np
34
import tensorly as tl
@@ -23,7 +24,6 @@ def xr_unfold(data: xr.Dataset, mode: str):
2324
round_to_n = lambda x, n: x if x == 0 else round(x, -int(np.floor(np.log10(abs(x)))) + (n - 1))
2425
vround2 = np.vectorize(lambda x: round_to_n(x, 2))
2526

26-
2727
class CoupledTensor():
2828
def __init__(self, data: xr.Dataset, rank):
2929
if not isinstance(data, xr.Dataset):
@@ -63,7 +63,7 @@ def initialize(self, method="svd", verbose=False):
6363
# wipe off old values
6464
self.x["_Weight_"][:] = np.ones_like(self.x["_Weight_"])
6565
for mmode in self.modes:
66-
self.x["_" + mmode][:] = np.zeros_like(self.x["_" + mmode])
66+
self.x["_" + mmode][:] = np.ones_like(self.x["_" + mmode])
6767

6868
if method == "ones":
6969
for mmode in self.modes:
@@ -194,20 +194,52 @@ def fit(self, tol=1e-7, maxiter=500, nonneg=False, progress=True, verbose=False)
194194
old_R2X = -np.inf
195195
tq = tqdm(range(maxiter), disable=(not progress))
196196

197+
gamma = 1.1
198+
gamma_bar = 1.03
199+
eta = 1.5
200+
beta_i = 0.05
201+
beta_i_bar = 1.0
202+
197203
# missing value handling
198204
uniqueInfo = {}
199205
for mmode in self.modes:
200206
uniqueInfo[mmode] = np.unique(np.isfinite(self.unfold[mmode].T), axis=1, return_inverse=True)
201207

202208
for i in tq:
209+
jump = beta_i + 1.0
210+
x_prev = deepcopy(self.x)
211+
203212
# Solve on each mode
204213
for mmode in self.modes:
205214
self.x["_"+mmode][:] = mlstsq(self.khatri_rao(mmode), self.unfold[mmode].T, uniqueInfo[mmode], nonneg=nonneg).T
206215
self.normalize_factors("norm")
216+
217+
# line search
218+
x_ls = deepcopy(self.x)
219+
220+
for mmode in self.modes:
221+
x_ls["_"+mmode][:] = x_prev["_"+mmode] + jump * (self.x["_"+mmode] - x_prev["_"+mmode])
222+
223+
x_cur = deepcopy(self.x)
207224
current_R2X = self.R2X()
225+
226+
self.x = x_ls
227+
ls_R2X = self.R2X()
228+
229+
if ls_R2X > current_R2X:
230+
current_R2X = ls_R2X
231+
self.x = x_ls
232+
233+
beta_i = min(beta_i_bar, gamma * beta_i)
234+
beta_i_bar = max(1.0, gamma_bar * beta_i_bar)
235+
else:
236+
beta_i_bar = beta_i
237+
beta_i = beta_i / eta
238+
self.x = x_cur
239+
208240
if verbose:
209241
print(f"R2Xs at {i}: {[self.R2X(dvar) for dvar in self.dvars]}")
210-
tq.set_postfix(refresh=False, R2X=current_R2X, delta=current_R2X-old_R2X)
242+
tq.set_postfix(refresh=False, R2X=current_R2X, delta=current_R2X-old_R2X, jump=jump)
211243
if np.abs(current_R2X-old_R2X) < tol:
212244
break
213245
old_R2X = current_R2X

0 commit comments

Comments
 (0)