Skip to content

Commit 2c2f0ba

Browse files
Modif KLIEP original
1 parent 2480999 commit 2c2f0ba

File tree

1 file changed

+25
-2
lines changed

1 file changed

+25
-2
lines changed

adapt/instance_based/_kliep.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,7 @@ def _fit(self, Xs, Xt, kernel_params):
360360
else :
361361
raise ValueError("%s is not a valid value of algo"%self.algo)
362362

363+
363364
def _fit_PG(self, Xs, Xt, PG, kernel_params):
364365
alphas = []
365366
OBJs = []
@@ -371,12 +372,34 @@ def _fit_PG(self, Xs, Xt, PG, kernel_params):
371372
else:
372373
raise TypeError("invalid argument type for lr")
373374

374-
centers, A, b = self.centers_selection(Xs, Xt, kernel_params)
375+
# For original, no center selection
376+
if PG:
377+
centers, A, b = self.centers_selection(Xs, Xt, kernel_params)
378+
else:
379+
index_centers = np.random.choice(
380+
len(Xt),
381+
min(len(Xt), self.max_centers),
382+
replace=False)
383+
centers = Xt[index_centers]
384+
385+
A = pairwise.pairwise_kernels(Xt, centers, metric=self.kernel,
386+
**kernel_params)
387+
B = pairwise.pairwise_kernels(centers, Xs, metric=self.kernel,
388+
**kernel_params)
389+
b = np.mean(B, axis=1)
390+
b = b.reshape(-1, 1)
375391

376392
for lr in LRs:
377393
if self.verbose > 1:
378394
print("learning rate : %s"%lr)
379-
alpha = 1/(len(centers)*b)
395+
396+
# For original, init alpha = ones and project
397+
if PG:
398+
alpha = 1/(len(centers)*b)
399+
else:
400+
alpha = np.ones((len(centers), 1))
401+
alpha = self._projection_original(alpha, b)
402+
380403
alpha = alpha.reshape(-1,1)
381404
previous_objective = -np.inf
382405
objective = np.sum(np.log(np.dot(A, alpha) + EPS))

0 commit comments

Comments
 (0)