Skip to content

Commit b917d0e

Browse files
Merge pull request #75 from antoinedemathelin/master
Add TCA
2 parents ed3e55e + 529087b commit b917d0e

File tree

4 files changed

+201
-4
lines changed

4 files changed

+201
-4
lines changed

adapt/base.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -476,9 +476,15 @@ def fit(self, X, y, Xt=None, yt=None, domains=None, **fit_params):
476476
if hasattr(self, "fit_weights"):
477477
if self.verbose:
478478
print("Fit weights...")
479-
self.weights_ = self.fit_weights(Xs=X, Xt=Xt,
480-
ys=y, yt=yt,
481-
domains=domains)
479+
out = self.fit_weights(Xs=X, Xt=Xt,
480+
ys=y, yt=yt,
481+
domains=domains)
482+
if isinstance(out, tuple):
483+
self.weights_ = out[0]
484+
X = out[1]
485+
y = out[2]
486+
else:
487+
self.weights_ = out
482488
if "sample_weight" in fit_params:
483489
fit_params["sample_weight"] *= self.weights_
484490
else:

adapt/feature_based/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from ._sa import SA
1515
from ._fmmd import fMMD
1616
from ._ccsa import CCSA
17+
from ._tca import TCA
1718

1819
__all__ = ["FA", "CORAL", "DeepCORAL", "ADDA", "DANN",
19-
"MCD", "MDD", "WDGRL", "CDAN", "SA", "fMMD", "CCSA"]
20+
"MCD", "MDD", "WDGRL", "CDAN", "SA", "fMMD", "CCSA", "TCA"]

adapt/feature_based/_tca.py

Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
"""
2+
TCA
3+
"""
4+
5+
import numpy as np
6+
from sklearn.utils import check_array
7+
from sklearn.metrics import pairwise
8+
from sklearn.metrics.pairwise import KERNEL_PARAMS
9+
from scipy import linalg
10+
11+
from adapt.base import BaseAdaptEstimator, make_insert_doc
12+
from adapt.utils import set_random_seed
13+
14+
15+
@make_insert_doc()
class TCA(BaseAdaptEstimator):
    """
    TCA : Transfer Component Analysis

    Source and target data are mapped through a kernel and projected
    on the leading eigenvectors of ``(I + mu * K L K)^-1 K H K``, where
    ``K`` is the kernel matrix of the stacked source and target samples,
    ``L`` the matrix of discrepancy coefficients between the two domains
    and ``H`` the centering matrix.

    Parameters
    ----------
    n_components : int (default=20)
        Number of components to keep. ``None`` keeps all of them.

    mu : float (default=0.1)
        Trade-off parameter multiplying the ``K L K`` discrepancy
        term in ``I + mu * K L K``. With ``mu=0`` this term vanishes,
        so the larger ``mu`` is, the more weight is put on aligning
        the source and target domains.

    kernel : str or callable (default="rbf")
        Kernel given as ``metric`` to
        ``sklearn.metrics.pairwise.pairwise_kernels``. For a string
        kernel, every attribute of the estimator whose name is listed
        in ``KERNEL_PARAMS[kernel]`` (e.g. ``gamma`` for ``"rbf"``) is
        forwarded as kernel parameter. No extra parameter is forwarded
        to a callable kernel.

    Attributes
    ----------
    estimator_ : object
        Estimator.

    vectors_ : array of shape (n_source + n_target, k)
        Selected eigenvectors, with ``k`` at most ``n_components``.
        Set by ``fit_transform``.

    Xs_ : array
        Source data seen by ``fit_transform``.

    Xt_ : array
        Target data seen by ``fit_transform``.

    Examples
    --------
    >>> from sklearn.linear_model import RidgeClassifier
    >>> from adapt.utils import make_classification_da
    >>> from adapt.feature_based import TCA
    >>> Xs, ys, Xt, yt = make_classification_da()
    >>> model = TCA(RidgeClassifier(), Xt=Xt, n_components=1, mu=0.1,
    ...             kernel="rbf", gamma=0.1, verbose=0, random_state=0)
    >>> model.fit(Xs, ys)
    >>> model.score(Xt, yt)
    0.93

    See also
    --------
    CORAL
    FA

    References
    ----------
    .. [1] `[1] <https://www.cse.ust.hk/~qyang/Docs/2009/TCA.pdf>`_ S. J. Pan, \
I. W. Tsang, J. T. Kwok and Q. Yang. "Domain Adaptation via Transfer Component \
Analysis". In IEEE transactions on neural networks 2010
    """
    def __init__(self,
                 estimator=None,
                 Xt=None,
                 n_components=20,
                 mu=0.1,
                 kernel="rbf",
                 copy=True,
                 verbose=1,
                 random_state=None,
                 **params):

        # Forward every named constructor argument (plus the extra
        # ``params``) to the base class constructor.
        names = self._get_param_names()
        kwargs = {k: v for k, v in locals().items() if k in names}
        kwargs.update(params)
        super().__init__(**kwargs)


    def _kernel_kwargs(self):
        """
        Collect the kernel parameters set on this estimator.

        Returns
        -------
        kernel_params : dict
            Instance attributes whose names are valid parameters
            of ``self.kernel``. Empty for a callable kernel.
        """
        if callable(self.kernel):
            # KERNEL_PARAMS only describes the built-in string kernels.
            return {}
        allowed = KERNEL_PARAMS[self.kernel]
        return {k: v for k, v in self.__dict__.items() if k in allowed}


    def fit_transform(self, Xs, Xt, **kwargs):
        """
        Fit embeddings.

        Parameters
        ----------
        Xs : array
            Input source data.

        Xt : array
            Input target data.

        kwargs : key, value argument
            Not used, present here for adapt consistency.

        Returns
        -------
        Xs_emb : embedded source data
        """
        Xs = check_array(Xs)
        Xt = check_array(Xt)
        set_random_seed(self.random_state)

        # The training samples are kept: ``transform`` needs them to
        # evaluate the kernel between new data and the fitted data.
        self.Xs_ = Xs
        self.Xt_ = Xt

        n = len(Xs)
        m = len(Xt)

        # Compute Kernel Matrix K
        kernel_params = self._kernel_kwargs()

        Kss = pairwise.pairwise_kernels(Xs, Xs, metric=self.kernel,
                                        **kernel_params)
        Ktt = pairwise.pairwise_kernels(Xt, Xt, metric=self.kernel,
                                        **kernel_params)
        Kst = pairwise.pairwise_kernels(Xs, Xt, metric=self.kernel,
                                        **kernel_params)

        K = np.block([[Kss, Kst],
                      [Kst.transpose(), Ktt]])

        # Compute L
        Lss = np.ones((n, n)) * (1./(n**2))
        Ltt = np.ones((m, m)) * (1./(m**2))
        Lst = np.ones((n, m)) * (-1./(n*m))

        L = np.block([[Lss, Lst],
                      [Lst.transpose(), Ltt]])

        # Compute H
        H = np.eye(n+m) - 1/(n+m) * np.ones((n+m, n+m))

        # Compute solution
        a = np.eye(n+m) + self.mu * K.dot(L.dot(K))
        b = K.dot(H.dot(K))
        sol = linalg.lstsq(a, b)[0]

        # ``sol`` is a product of two symmetric matrices and is not
        # symmetric in general: the general eigensolver is required,
        # ``eigh`` would only read one triangle of the matrix.
        values, vectors = linalg.eig(sol)

        # Keep the eigenvectors with the largest eigenvalue modulus.
        order = np.argsort(np.abs(values))[::-1][:self.n_components]

        # ``eig`` returns complex arrays; only the real part is kept.
        self.vectors_ = np.real(vectors[:, order])

        # The first n rows of K correspond to the source samples.
        Xs_enc = K.dot(self.vectors_)[:n]

        return Xs_enc


    def transform(self, X, domain="tgt"):
        """
        Return aligned features for X.

        Parameters
        ----------
        X : array
            Input data.

        domain : str (default="tgt")
            Choose between ``"source", "src"`` or
            ``"target", "tgt"`` feature embedding.
            The projection computed here does not depend
            on this argument.

        Returns
        -------
        X_emb : array
            Embeddings of X.
        """
        X = check_array(X)

        kernel_params = self._kernel_kwargs()

        # Kernel between X and the stacked (source, target) training
        # samples, in the same column order as at fit time.
        Kss = pairwise.pairwise_kernels(X, self.Xs_, metric=self.kernel,
                                        **kernel_params)
        Kst = pairwise.pairwise_kernels(X, self.Xt_, metric=self.kernel,
                                        **kernel_params)

        K = np.concatenate((Kss, Kst), axis=1)

        return K.dot(self.vectors_)

tests/test_tca.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import numpy as np
2+
3+
from adapt.metrics import normalized_linear_discrepancy
4+
from adapt.feature_based import TCA
5+
6+
# Fixed seed so that the synthetic domains are reproducible.
np.random.seed(0)
n = 50  # number of source samples
m = 50  # number of target samples
p = 6   # number of features

# Source domain: small gaussian noise, last two features shifted by 2.
Xs = np.random.randn(n, p)*0.1 + np.array([0.]*(p-2) + [2., 2.])
# Target domain: same noise level, centered on the origin.
Xt = np.random.randn(m, p)*0.1
13+
14+
15+
def test_tca():
16+
tca = TCA(n_components=2, kernel="rbf", gamma=0.01, random_state=0)
17+
Xst = tca.fit_transform(Xs, Xt)
18+
assert np.abs(Xst - tca.transform(Xs, "src")).sum() < 10**-8
19+
assert Xst.shape[1] == 2
20+
assert (normalized_linear_discrepancy(Xs, Xt) >
21+
2 * normalized_linear_discrepancy(Xst, tca.transform(Xt)))

0 commit comments

Comments
 (0)