
Commit 896be13

ULSIF and RULSIF
Instance-based methods implementation; comments to be changed
1 parent 93a565d commit 896be13

File tree

2 files changed: +795 −0 lines changed

adapt/instance_based/_RULSIF.py

Lines changed: 399 additions & 0 deletions
@@ -0,0 +1,399 @@
"""
Relative Unconstrained Least-Squares Importance Fitting (RULSIF)
"""
import itertools
import warnings

import numpy as np
from sklearn.metrics import pairwise
from sklearn.exceptions import NotFittedError
from sklearn.utils import check_array
from sklearn.metrics.pairwise import KERNEL_PARAMS

from adapt.base import BaseAdaptEstimator, make_insert_doc
from adapt.utils import set_random_seed

EPS = np.finfo(float).eps


class RULSIF(BaseAdaptEstimator):
    """
    RULSIF: Relative Unconstrained Least-Squares Importance Fitting

    RULSIF is an instance-based method for domain adaptation.

    The purpose of the algorithm is to correct the difference between
    input distributions of the source and target domains. This is done
    by finding a **reweighting** of the source instances which minimizes
    the **relative Pearson divergence** between the source and target
    distributions.

    The source instance weights are given by the following formula:

    .. math::

        w(x) = \\sum_{x_i \\in X_T} \\theta_i K(x, x_i)

    Where:

    - :math:`x, x_i` are input instances.
    - :math:`X_T` is the target input data.
    - :math:`\\theta_i` are the basis function coefficients.
    - :math:`K(x, x_i) = \\text{exp}(-\\gamma ||x - x_i||^2)`
      for instance if ``kernel="rbf"``.

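    For illustration, the formula above can be evaluated with
    ``sklearn`` kernels once ``thetas_`` and ``centers_`` are known.
    A minimal sketch with toy arrays standing in for the fitted
    attributes::

        import numpy as np
        from sklearn.metrics import pairwise

        rng = np.random.RandomState(0)
        centers = rng.randn(5, 1)         # kernel centers (drawn from Xt)
        thetas = np.abs(rng.randn(5, 1))  # basis function coefficients
        X = rng.randn(10, 1)              # instances to reweight

        # w(x) = sum_i theta_i K(x, x_i)
        weights = np.dot(
            pairwise.pairwise_kernels(X, centers, metric="rbf", gamma=1.),
            thetas).ravel()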

    The RULSIF algorithm consists in finding the optimal :math:`\\theta`
    according to the quadratic problem

    .. math::

        \\min_{\\theta} \\frac{1}{2} \\theta^T H \\theta - h^T \\theta +
        \\frac{\\lambda}{2} \\theta^T \\theta

    where:

    .. math::

        H_{ll'} = \\frac{\\alpha}{n_T} \\sum_{x_i \\in X_T} K(x_i, x_l) K(x_i, x_{l'}) + \\frac{1-\\alpha}{n_S} \\sum_{x_i \\in X_S} K(x_i, x_l) K(x_i, x_{l'})

        h_{l} = \\frac{1}{n_T} \\sum_{x_i \\in X_T} K(x_i, x_l)

    Where:

    - :math:`X_S` is the source input data of size :math:`n_S`.
    - :math:`X_T` is the target input data of size :math:`n_T`.

    The above OP is solved through the closed-form expression:

    .. math::

        \\hat{\\theta} = (H + \\lambda I)^{-1} h

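    In code, this closed form reduces to a single linear solve. A
    minimal sketch on toy matrices, assuming ``H``, ``h`` and ``lamb``
    have already been computed::

        import numpy as np

        b = 5                        # number of kernel centers
        H = np.eye(b)                # toy (b, b) matrix standing in for H
        h = np.ones((b, 1))          # toy (b, 1) vector standing in for h
        lamb = 0.1

        # theta = (H + lambda I)^{-1} h, without forming the inverse
        theta = np.linalg.solve(H + lamb * np.eye(b), h)
        theta[theta < 0] = 0         # negative coefficients are clipped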

    Furthermore, the method admits a leave-one-out cross-validation
    score with a closed-form expression which can be used to select
    the appropriate parameters of the kernel function :math:`K`
    (typically, the parameter :math:`\\gamma` of the Gaussian kernel)
    and the regularization parameter :math:`\\lambda`. The parameters
    are then selected by cross-validation on the :math:`J` score
    defined as follows:

    .. math::

        J = -\\left( \\frac{1-\\alpha}{2 |X_S|} \\sum_{x \\in X_S} w(x)^2 + \\frac{\\alpha}{2 |X_T|} \\sum_{x \\in X_T} w(x)^2 - \\frac{1}{|X_T|} \\sum_{x \\in X_T} w(x) \\right)

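    The candidate grid explored by this LCV process is the Cartesian
    product of the kernel parameter lists (see ``fit_weights`` below).
    A minimal sketch of that expansion with toy candidate lists::

        import itertools

        options = {"gamma": [0.1, 1., 10.], "degree": [2, 3]}
        keys = options.keys()
        values = (options[key] for key in keys)

        # One dict per combination, e.g. {"gamma": 0.1, "degree": 2},
        # {"gamma": 0.1, "degree": 3}, ...
        params_comb = [dict(zip(keys, comb))
                       for comb in itertools.product(*values)]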

    Finally, an estimator is fitted using the reweighted labeled
    source instances.

    The RULSIF method has been originally introduced for
    **unsupervised** DA but it could be widened to **supervised** DA
    by simply adding labeled target data to the training set.

    Parameters
    ----------
    alpha : float (default=0.1)
        Trade-off parameter of the relative
        Pearson divergence.

    kernel : str (default="rbf")
        Kernel metric.
        Possible values: [‘additive_chi2’, ‘chi2’,
        ‘linear’, ‘poly’, ‘polynomial’, ‘rbf’,
        ‘laplacian’, ‘sigmoid’, ‘cosine’]

    sigmas : float or list of float (default=None)
        Deprecated, please use the ``gamma`` parameter
        instead. (See below).

    lambdas : float or list of float (default=None)
        Regularization parameter :math:`\\lambda` of the
        quadratic problem.
        If a list is given, the LCV process is performed to
        select the best parameter ``lambda``.

    max_centers : int (default=100)
        Maximal number of target instances used to
        compute kernels.

    Yields
    ------
    gamma : float or list of float
        Kernel parameter ``gamma``.

        - For kernel = chi2::

            k(x, y) = exp(-gamma Sum [(x - y)^2 / (x + y)])

        - For kernel = poly or polynomial::

            K(X, Y) = (gamma <X, Y> + coef0)^degree

        - For kernel = rbf::

            K(x, y) = exp(-gamma ||x-y||^2)

        - For kernel = laplacian::

            K(x, y) = exp(-gamma ||x-y||_1)

        - For kernel = sigmoid::

            K(X, Y) = tanh(gamma <X, Y> + coef0)

        If a list is given, the LCV process is performed to
        select the best parameter ``gamma``.

    coef0 : float or list of float
        Kernel parameter ``coef0``.
        Used for polynomial and sigmoid kernels.
        See ``gamma`` parameter above for the
        kernel formulas.
        If a list is given, the LCV process is performed to
        select the best parameter ``coef0``.

    degree : int or list of int
        Degree parameter for the polynomial
        kernel. (see formula in the ``gamma``
        parameter description).
        If a list is given, the LCV process is performed to
        select the best parameter ``degree``.

    Attributes
    ----------
    weights_ : numpy array
        Training instance weights.

    best_params_ : dict
        Best kernel and lambda params combination
        deduced from the LCV procedure.

    thetas_ : numpy array
        Basis function coefficients.

    centers_ : numpy array
        Center points for kernels.

    j_scores_ : dict
        Dict of J scores with the
        params combination as
        keys and the J scores as values.

    estimator_ : object
        Fitted estimator.

    Examples
    --------
    >>> import numpy as np
    >>> from adapt.instance_based import RULSIF
    >>> np.random.seed(0)
    >>> Xs = np.random.randn(50) * 0.1
    >>> Xs = np.concatenate((Xs, Xs + 1.))
    >>> Xt = np.random.randn(100) * 0.1
    >>> ys = np.array([-0.2 * x if x<0.5 else 1. for x in Xs])
    >>> yt = -0.2 * Xt
    >>> rulsif = RULSIF(alpha=0.1, kernel="rbf",
    ...                 lambdas=[0.1, 1., 10.], gamma=[0.1, 1., 10.],
    ...                 random_state=0)
    >>> rulsif.fit(Xs.reshape(-1,1), ys, Xt.reshape(-1,1))  # doctest: +SKIP
    >>> np.abs(rulsif.predict(Xt.reshape(-1,1)).ravel() - yt).mean()  # doctest: +SKIP

    See also
    --------
    ULSIF
    KMM

    References
    ----------
    .. [1] `[1] <https://proceedings.neurips.cc/paper/2011/file/\
d1f255a373a3cef72e03aa9d980c7eca-Paper.pdf>`_ \
M. Yamada, T. Suzuki, T. Kanamori, H. Hachiya and M. Sugiyama. \
"Relative Density-Ratio Estimation for Robust Distribution \
Comparison". In NIPS 2011.
    """
    def __init__(self,
                 estimator=None,
                 Xt=None,
                 alpha=0.1,
                 kernel="rbf",
                 sigmas=None,
                 lambdas=None,
                 max_centers=100,
                 copy=True,
                 verbose=1,
                 random_state=None,
                 **params):

        if sigmas is not None:
            warnings.warn("The `sigmas` argument is deprecated, "
                          "please use the `gamma` argument instead.",
                          DeprecationWarning)

        names = self._get_param_names()
        kwargs = {k: v for k, v in locals().items() if k in names}
        kwargs.update(params)
        super().__init__(**kwargs)

    def fit_weights(self, Xs, Xt, **kwargs):
        """
        Fit importance weighting.

        Parameters
        ----------
        Xs : array
            Input source data.

        Xt : array
            Input target data.

        kwargs : key, value argument
            Not used, present here for adapt consistency.

        Returns
        -------
        weights_ : sample weights
        """
        Xs = check_array(Xs)
        Xt = check_array(Xt)
        set_random_seed(self.random_state)

        self.j_scores_ = {}

        # LCV GridSearch
        kernel_params = {k: v for k, v in self.__dict__.items()
                         if k in KERNEL_PARAMS[self.kernel]}

        # Handle deprecated sigmas (will be removed)
        if (self.sigmas is not None) and ("gamma" not in kernel_params):
            kernel_params["gamma"] = self.sigmas

        kernel_params_dict = {k: (v if hasattr(v, "__iter__") else [v])
                              for k, v in kernel_params.items()}
        # Assumption: fall back to lambda = 1 when no candidate is given.
        lambdas = self.lambdas if self.lambdas is not None else 1.0
        lambdas_params_dict = {"lamb": (lambdas if hasattr(lambdas, "__iter__")
                                        else [lambdas])}
        options = kernel_params_dict
        keys = options.keys()
        values = (options[key] for key in keys)
        params_comb_kernel = [dict(zip(keys, combination))
                              for combination in itertools.product(*values)]

        if len(params_comb_kernel) * len(lambdas_params_dict["lamb"]) > 1:
            if self.verbose:
                print("Cross Validation process...")
            # Cross-validation process
            max_ = -np.inf
            N_s = len(Xs)
            N_t = len(Xt)
            N_min = min(N_s, N_t)
            index_centers = np.random.choice(
                len(Xt),
                min(len(Xt), self.max_centers),
                replace=False)
            centers = Xt[index_centers]
            n_centers = min(len(Xt), self.max_centers)

            # Subsample the larger domain so that both kernel matrices
            # have N_min columns for the leave-one-out formulas.
            if N_s < N_t:
                index_data = np.random.choice(N_t, N_s, replace=False)
            elif N_t < N_s:
                index_data = np.random.choice(N_s, N_t, replace=False)

            for params in params_comb_kernel:

                if N_s < N_t:
                    phi_t = pairwise.pairwise_kernels(
                        centers, Xt[index_data], metric=self.kernel, **params)
                    phi_s = pairwise.pairwise_kernels(
                        centers, Xs, metric=self.kernel, **params)
                elif N_t < N_s:
                    phi_t = pairwise.pairwise_kernels(
                        centers, Xt, metric=self.kernel, **params)
                    phi_s = pairwise.pairwise_kernels(
                        centers, Xs[index_data], metric=self.kernel, **params)
                else:
                    phi_t = pairwise.pairwise_kernels(
                        centers, Xt, metric=self.kernel, **params)
                    phi_s = pairwise.pairwise_kernels(
                        centers, Xs, metric=self.kernel, **params)

                # Effective sample sizes after subsampling
                # (both equal N_min by construction).
                n_t = phi_t.shape[1]
                n_s = phi_s.shape[1]
                H = (self.alpha * np.dot(phi_t, phi_t.T) / n_t
                     + (1 - self.alpha) * np.dot(phi_s, phi_s.T) / n_s)
                h = np.mean(phi_t, axis=1)
                h = h.reshape(-1, 1)

                for lamb in lambdas_params_dict["lamb"]:
                    # Closed-form leave-one-out CV score; the source
                    # (denominator) matrix phi_s plays the role of K_de
                    # in Yamada et al. (2011).
                    B = H + np.identity(n_centers) * (lamb * (n_s - 1) / n_s)
                    BinvX = np.linalg.solve(B, phi_s)
                    XBinvX = phi_s * BinvX
                    D0 = np.ones(N_min) * n_s - np.dot(np.ones(n_centers), XBinvX)
                    diag_D0 = np.diag((np.dot(h.T, BinvX) / D0).ravel())
                    B0 = np.linalg.solve(B, h * np.ones(N_min)) + np.dot(BinvX, diag_D0)
                    diag_D1 = np.diag(np.dot(np.ones(n_centers), phi_t * BinvX).ravel())
                    B1 = np.linalg.solve(B, phi_t) + np.dot(BinvX, diag_D1)
                    B2 = (n_s - 1) * (n_t * B0 - B1) / (n_s * (n_t - 1))
                    B2[B2 < 0] = 0
                    r_s = (phi_s * B2).sum(axis=0).T
                    r_t = (phi_t * B2).sum(axis=0).T
                    # LOOCV objective; the J score is its opposite.
                    score = (((1 - self.alpha) * np.dot(r_s.T, r_s).ravel() / 2.
                              + self.alpha * np.dot(r_t.T, r_t).ravel() / 2.
                              - r_t.sum(axis=0)) / N_min).item()
                    aux_params = {"k": params, "lamb": lamb}
                    self.j_scores_[str(aux_params)] = -score

                    if self.verbose:
                        print("Parameters %s -- J-score = %.3f" %
                              (str(aux_params), -score))
                    if self.j_scores_[str(aux_params)] > max_:
                        self.best_params_ = aux_params
                        max_ = self.j_scores_[str(aux_params)]
        else:
            self.best_params_ = {"k": params_comb_kernel[0],
                                 "lamb": lambdas_params_dict["lamb"][0]}

        self.thetas_, self.centers_ = self._fit(
            Xs, Xt, self.best_params_["k"], self.best_params_["lamb"])

        self.weights_ = np.dot(
            pairwise.pairwise_kernels(Xs, self.centers_,
                                      metric=self.kernel,
                                      **self.best_params_["k"]),
            self.thetas_
        ).ravel()
        return self.weights_

    def predict_weights(self, X=None):
        """
        Return fitted source weights.

        If ``X`` is ``None``, the fitted source weights are returned.
        Else, sample weights are computed using the fitted
        ``thetas_`` and the chosen ``centers_``.

        Parameters
        ----------
        X : array (default=None)
            Input data.

        Returns
        -------
        weights_ : sample weights
        """
        if hasattr(self, "weights_"):
            if X is None or not hasattr(self, "thetas_"):
                return self.weights_
            else:
                X = check_array(X)
                weights = np.dot(
                    pairwise.pairwise_kernels(X, self.centers_,
                                              metric=self.kernel,
                                              **self.best_params_["k"]),
                    self.thetas_
                ).ravel()
                return weights
        else:
            raise NotFittedError("Weights are not fitted yet, please "
                                 "call 'fit_weights' or 'fit' first.")

    def _fit(self, Xs, Xt, kernel_params, lamb):
        # Draw kernel centers from the target data.
        index_centers = np.random.choice(
            len(Xt),
            min(len(Xt), self.max_centers),
            replace=False)
        centers = Xt[index_centers]
        n_centers = min(len(Xt), self.max_centers)

        phi_t = pairwise.pairwise_kernels(centers, Xt, metric=self.kernel,
                                          **kernel_params)
        phi_s = pairwise.pairwise_kernels(centers, Xs, metric=self.kernel,
                                          **kernel_params)

        N_t = len(Xt)
        N_s = len(Xs)

        H = (self.alpha * np.dot(phi_t, phi_t.T) / N_t
             + (1 - self.alpha) * np.dot(phi_s, phi_s.T) / N_s)
        h = np.mean(phi_t, axis=1)
        h = h.reshape(-1, 1)
        # Closed-form solution: theta = (H + lambda I)^{-1} h,
        # with negative coefficients clipped to zero.
        theta = np.linalg.solve(H + lamb * np.eye(n_centers), h)
        theta[theta < 0] = 0
        return theta, centers
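

# A minimal usage sketch (illustrative only, not part of the estimator):
# fit importance weights on toy shifted data. It assumes adapt's
# BaseAdaptEstimator stores extra kernel params (e.g. ``gamma``) passed
# through ``**params`` as attributes, as elsewhere in the adapt library.
if __name__ == "__main__":
    np.random.seed(0)
    Xs = np.random.randn(100, 1) + 1.   # source inputs, shifted
    Xt = np.random.randn(100, 1)        # target inputs, centered

    rulsif = RULSIF(alpha=0.1, kernel="rbf",
                    lambdas=[0.1, 1., 10.], gamma=[0.1, 1., 10.])
    weights = rulsif.fit_weights(Xs, Xt)
    # Source points closest to the target cloud should receive
    # the largest weights.
    print(weights[:5])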
