1+ """
2+ TCA
3+ """
4+
5+ import numpy as np
6+ from sklearn .utils import check_array
7+ from sklearn .metrics import pairwise
8+ from sklearn .metrics .pairwise import KERNEL_PARAMS
9+ from scipy import linalg
10+
11+ from adapt .base import BaseAdaptEstimator , make_insert_doc
12+ from adapt .utils import set_random_seed
13+
14+
@make_insert_doc()
class TCA(BaseAdaptEstimator):
    """
    TCA : Transfer Component Analysis

    Feature-based domain adaptation method which embeds source and
    target data with the leading eigenvectors of a kernel-based
    linear system built on the joint source/target kernel matrix.

    Parameters
    ----------
    n_components : int (default=20)
        Number of components (leading eigenvectors) to keep.
        Used as a slice bound, so ``None`` keeps every component.

    mu : float (default=0.1)
        Trade-off parameter. It multiplies the ``K L K``
        source/target discrepancy term of the system
        ``(I + mu K L K)^-1 K H K`` whose eigenvectors are
        computed in ``fit_transform``.

    kernel : str or callable (default="rbf")
        Kernel forwarded as ``metric`` to
        ``sklearn.metrics.pairwise.pairwise_kernels``.
        For a string kernel, the matching kernel arguments
        (e.g. ``gamma``) found among the instance attributes
        are forwarded too.

    Attributes
    ----------
    estimator_ : object
        Estimator.

    vectors_ : array
        Selected eigenvectors, one column per component.

    Xs_ : array
        Source data seen in ``fit_transform``.

    Xt_ : array
        Target data seen in ``fit_transform``.

    Examples
    --------
    >>> from sklearn.linear_model import RidgeClassifier
    >>> from adapt.utils import make_classification_da
    >>> from adapt.feature_based import TCA
    >>> Xs, ys, Xt, yt = make_classification_da()
    >>> model = TCA(RidgeClassifier(), Xt=Xt, n_components=1, mu=0.1,
    ...             kernel="rbf", gamma=0.1, verbose=0, random_state=0)
    >>> model.fit(Xs, ys)
    >>> model.score(Xt, yt)
    0.93

    See also
    --------
    CORAL
    FA

    References
    ----------
    .. [1] `[1] <https://www.cse.ust.hk/~qyang/Docs/2009/TCA.pdf>`_ S. J. Pan, \
I. W. Tsang, J. T. Kwok and Q. Yang. "Domain Adaptation via Transfer Component \
Analysis". In IEEE transactions on neural networks 2010
    """

    def __init__(self,
                 estimator=None,
                 Xt=None,
                 n_components=20,
                 mu=0.1,
                 kernel="rbf",
                 copy=True,
                 verbose=1,
                 random_state=None,
                 **params):

        # Forward every named constructor argument (plus the extra
        # keyword arguments) to the base class in a single dict.
        names = self._get_param_names()
        kwargs = {k: v for k, v in locals().items() if k in names}
        kwargs.update(params)
        super().__init__(**kwargs)

    def _get_kernel_params(self):
        """
        Collect the kernel arguments stored on the instance.

        Returns
        -------
        kernel_params : dict
            Instance attributes whose names are listed in
            ``KERNEL_PARAMS`` for the chosen string kernel.
            Empty for a callable kernel, which has no entry
            in ``KERNEL_PARAMS``.
        """
        if callable(self.kernel):
            return {}
        accepted = KERNEL_PARAMS[self.kernel]
        return {k: v for k, v in self.__dict__.items() if k in accepted}

    def fit_transform(self, Xs, Xt, **kwargs):
        """
        Fit embeddings.

        Parameters
        ----------
        Xs : array
            Input source data.

        Xt : array
            Input target data.

        kwargs : key, value argument
            Not used, present here for adapt consistency.

        Returns
        -------
        Xs_emb : embedded source data
        """
        Xs = check_array(Xs)
        Xt = check_array(Xt)
        set_random_seed(self.random_state)

        # Kept for ``transform``, which needs the kernel between
        # new samples and the training samples of both domains.
        self.Xs_ = Xs
        self.Xt_ = Xt

        n = len(Xs)
        m = len(Xt)

        # Compute Kernel Matrix K (joint source/target kernel)
        kernel_params = self._get_kernel_params()

        Kss = pairwise.pairwise_kernels(Xs, Xs, metric=self.kernel, **kernel_params)
        Ktt = pairwise.pairwise_kernels(Xt, Xt, metric=self.kernel, **kernel_params)
        Kst = pairwise.pairwise_kernels(Xs, Xt, metric=self.kernel, **kernel_params)

        K = np.block([[Kss, Kst],
                      [Kst.transpose(), Ktt]])

        # Compute L (coefficients 1/n^2, 1/m^2 and -1/(n*m))
        Lss = np.ones((n, n)) * (1. / (n ** 2))
        Ltt = np.ones((m, m)) * (1. / (m ** 2))
        Lst = np.ones((n, m)) * (-1. / (n * m))

        L = np.block([[Lss, Lst],
                      [Lst.transpose(), Ltt]])

        # Compute H (centering matrix)
        H = np.eye(n + m) - 1 / (n + m) * np.ones((n + m, n + m))

        # Compute solution
        # NOTE(review): ``mu`` scales the K L K term here; the reference
        # paper appears to put its trade-off on the identity term instead
        # (K L K + mu I), i.e. the inverse convention. Behaviour is kept
        # unchanged for backward compatibility -- confirm intent.
        a = np.eye(n + m) + self.mu * K.dot(L.dot(K))
        b = K.dot(H.dot(K))
        sol = linalg.lstsq(a, b)[0]

        # ``sol`` is not symmetric in general, so the general solver is
        # required: ``eigh`` would only read one triangle of ``sol`` and
        # return the eigenvectors of a different matrix.
        values, vectors = linalg.eig(sol)

        # Keep the eigenvectors with the largest eigenvalue magnitude.
        args = np.argsort(np.abs(values))[::-1][:self.n_components]

        # ``eig`` returns complex arrays; only the real part is kept.
        self.vectors_ = np.real(vectors[:, args])

        # The first n rows of K correspond to the source samples.
        Xs_enc = K.dot(self.vectors_)[:n]

        return Xs_enc

    def transform(self, X, domain="tgt"):
        """
        Return aligned features for X.

        Parameters
        ----------
        X : array
            Input data.

        domain : str (default="tgt")
            Choose between ``"source", "src"`` or
            ``"target", "tgt"`` feature embedding.
            Not read by this method: the same projection
            is applied whatever the value.

        Returns
        -------
        X_emb : array
            Embeddings of X.
        """
        X = check_array(X)

        kernel_params = self._get_kernel_params()

        # Kernel between X and the stored source / target training data,
        # laid out in the same column order as the fitted kernel matrix.
        Kxs = pairwise.pairwise_kernels(X, self.Xs_, metric=self.kernel, **kernel_params)
        Kxt = pairwise.pairwise_kernels(X, self.Xt_, metric=self.kernel, **kernel_params)

        K = np.concatenate((Kxs, Kxt), axis=1)

        return K.dot(self.vectors_)
0 commit comments