Skip to content

Commit 98911d1

Browse files
Merge pull request #77 from antoinedemathelin/master
feat: Add PRED
2 parents 25a2989 + ef64ada commit 98911d1

File tree

3 files changed

+298
-1
lines changed

3 files changed

+298
-1
lines changed

adapt/feature_based/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from ._fmmd import fMMD
1616
from ._ccsa import CCSA
1717
from ._tca import TCA
18+
from ._pred import PRED
1819

1920
__all__ = ["FA", "CORAL", "DeepCORAL", "ADDA", "DANN",
20-
"MCD", "MDD", "WDGRL", "CDAN", "SA", "fMMD", "CCSA", "TCA"]
21+
"MCD", "MDD", "WDGRL", "CDAN", "SA", "fMMD", "CCSA", "TCA", "PRED"]

adapt/feature_based/_pred.py

Lines changed: 273 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,273 @@
1+
"""
2+
Frustratingly Easy Domain Adaptation module.
3+
"""
4+
5+
import warnings
6+
7+
import numpy as np
8+
from sklearn.utils import check_array
9+
from sklearn.exceptions import NotFittedError
10+
11+
from adapt.base import BaseAdaptEstimator, make_insert_doc
12+
from adapt.utils import check_arrays, check_estimator
13+
14+
15+
@make_insert_doc(supervised=True)
16+
class PRED(BaseAdaptEstimator):
17+
"""
18+
PRED: Feature Augmentation with SrcOnly Prediction
19+
20+
PRED uses the output of a source pretrain model as a feature in
21+
the target model. Specifically, PRED first trains a
22+
SrcOnly model. Then it runs the SrcOnly model on the target data.
23+
It uses the predictions made by the SrcOnly model as additional features
24+
and trains a second model on the target data, augmented with this new feature.
25+
26+
Parameters
27+
----------
28+
pretrain : bool (default=True)
29+
Weither to pretrain the estimator on
30+
source or not. If False, `estimator`
31+
should be already fitted.
32+
33+
Attributes
34+
----------
35+
estimator_src_ : object
36+
Fitted source estimator.
37+
38+
estimator_ : object
39+
Fitted estimator.
40+
41+
See also
42+
--------
43+
FA
44+
45+
Examples
46+
--------
47+
>>> from sklearn.linear_model import RidgeClassifier
48+
>>> from adapt.utils import make_classification_da
49+
>>> from adapt.feature_based import PRED
50+
>>> Xs, ys, Xt, yt = make_classification_da()
51+
>>> model = PRED(RidgeClassifier(0.), Xt=Xt[[1, -1, -2]], yt=yt[[1, -1, -2]],
52+
... pretrain=True, verbose=0, random_state=0)
53+
>>> model.fit(Xs, ys)
54+
>>> model.score(Xt, yt)
55+
0.77
56+
57+
References
58+
----------
59+
.. [1] `[1] <https://arxiv.org/pdf/0907.1815\
60+
.pdf>`_ Daume III, H. "Frustratingly easy domain adaptation". In ACL, 2007.
61+
"""
62+
def __init__(self,
63+
estimator=None,
64+
Xt=None,
65+
yt=None,
66+
copy=True,
67+
pretrain=True,
68+
verbose=1,
69+
random_state=None,
70+
**params):
71+
72+
names = self._get_param_names()
73+
kwargs = {k: v for k, v in locals().items() if k in names}
74+
kwargs.update(params)
75+
super().__init__(**kwargs)
76+
77+
78+
def fit_transform(self, Xs, Xt, ys, yt, domains=None, **kwargs):
79+
"""
80+
Fit embeddings.
81+
82+
Parameters
83+
----------
84+
Xs : array
85+
Source input data.
86+
87+
Xt : array
88+
Target input data.
89+
90+
ys : array
91+
Source output data.
92+
93+
yt : array
94+
Target output data.
95+
96+
kwargs : key, value argument
97+
Not used, present here for adapt consistency.
98+
99+
Returns
100+
-------
101+
Xt_aug, yt : augmented input and output target data
102+
"""
103+
Xs, ys = check_arrays(Xs, ys)
104+
Xt, yt = check_arrays(Xt, yt)
105+
106+
self.estimators_ = []
107+
108+
if self.pretrain:
109+
estimator = self.fit_estimator(Xs, ys,
110+
warm_start=False,
111+
random_state=self.random_state)
112+
self.estimator_src_ = estimator
113+
del self.estimator_
114+
else:
115+
self.estimator_src_ = check_estimator(self.estimator,
116+
copy=self.copy,
117+
force_copy=True)
118+
119+
yt_pred = self.estimator_src_.predict(Xt)
120+
121+
if len(yt_pred.shape) < 2:
122+
yt_pred = yt_pred.reshape(-1, 1)
123+
124+
X = np.concatenate((Xt, yt_pred), axis=-1)
125+
y = yt
126+
return X, y
127+
128+
129+
def transform(self, X, domain="tgt"):
130+
"""
131+
Return augmented features for X.
132+
133+
If `domain="tgt"`, the prediction of the source model on `X`
134+
are added to `X`.
135+
136+
If `domain="src"`, `X` is returned.
137+
138+
Parameters
139+
----------
140+
X : array
141+
Input data.
142+
143+
domain : str (default="tgt")
144+
Choose between ``"source", "src"`` and
145+
``"target", "tgt"`` feature augmentation.
146+
147+
Returns
148+
-------
149+
X_emb : array
150+
Embeddings of X.
151+
"""
152+
X = check_array(X, allow_nd=True)
153+
154+
if domain in ["tgt", "target"]:
155+
y_pred = self.estimator_src_.predict(X)
156+
if len(y_pred.shape) < 2:
157+
y_pred = y_pred.reshape(-1, 1)
158+
X_emb = np.concatenate((X, y_pred), axis=-1)
159+
elif domain in ["src", "source"]:
160+
X_emb = X
161+
else:
162+
raise ValueError("`domain `argument "
163+
"should be `tgt` or `src`, "
164+
"got, %s"%domain)
165+
return X_emb
166+
167+
168+
def predict(self, X, domain=None, **predict_params):
169+
"""
170+
Return estimator predictions after
171+
adaptation.
172+
173+
If `domain="tgt"`, the input feature ``X`` are first transformed.
174+
Then the ``predict`` method of the fitted estimator
175+
``estimator_`` is applied on the transformed ``X``.
176+
177+
If `domain="src"`, ``estimator_src_`` is applied direclty
178+
on ``X``.
179+
180+
Parameters
181+
----------
182+
X : array
183+
input data
184+
185+
domain : str (default=None)
186+
For antisymetric feature-based method,
187+
different transformation of the input X
188+
are applied for different domains. The domain
189+
should then be specified between "src" and "tgt".
190+
If ``None`` the default transformation is the
191+
target one.
192+
193+
Returns
194+
-------
195+
y_pred : array
196+
prediction of the Adapt Model.
197+
"""
198+
X = check_array(X, ensure_2d=True, allow_nd=True, accept_sparse=True)
199+
if domain is None:
200+
domain = "tgt"
201+
X = self.transform(X, domain=domain)
202+
203+
if domain in ["tgt", "target"]:
204+
return self.estimator_.predict(X, **predict_params)
205+
else:
206+
return self.estimator_src_.predict(X, **predict_params)
207+
208+
209+
def score(self, X, y, sample_weight=None, domain=None):
210+
"""
211+
Return the estimator score.
212+
213+
If `domain="tgt"`, the input feature ``X`` are first transformed.
214+
Then the ``score`` method of the fitted estimator
215+
``estimator_`` is applied on the transformed ``X``.
216+
217+
If `domain="src"`, ``estimator_src_`` is applied direclty
218+
on ``X``.
219+
220+
Parameters
221+
----------
222+
X : array
223+
input data
224+
225+
y : array
226+
output data
227+
228+
sample_weight : array (default=None)
229+
Sample weights
230+
231+
domain : str (default=None)
232+
This parameter specifies for antisymetric
233+
feature-based method which transformation
234+
will be applied between "source" and "target".
235+
If ``None`` the transformation by default is
236+
the target one.
237+
238+
Returns
239+
-------
240+
score : float
241+
estimator score.
242+
"""
243+
X, y = check_arrays(X, y, accept_sparse=True)
244+
245+
if domain is None:
246+
domain = "tgt"
247+
X = self.transform(X, domain=domain)
248+
249+
if domain in ["tgt", "target"]:
250+
estimator = self.estimator_
251+
else:
252+
estimator = self.estimator_src_
253+
254+
if hasattr(estimator, "score"):
255+
score = estimator.score(X, y, sample_weight)
256+
elif hasattr(estimator, "evaluate"):
257+
if np.prod(X.shape) <= 10**8:
258+
score = estimator.evaluate(
259+
X, y,
260+
sample_weight=sample_weight,
261+
batch_size=len(X)
262+
)
263+
else:
264+
score = estimator.evaluate(
265+
X, y,
266+
sample_weight=sample_weight
267+
)
268+
if isinstance(score, (tuple, list)):
269+
score = score[0]
270+
else:
271+
raise ValueError("Estimator does not implement"
272+
" score or evaluate method")
273+
return score

tests/test_pred.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
from sklearn.linear_model import RidgeClassifier
2+
from adapt.utils import make_classification_da
3+
from adapt.feature_based import PRED
4+
5+
Xs, ys, Xt, yt = make_classification_da()
6+
7+
def test_pred():
8+
model = PRED(RidgeClassifier(), pretrain=True, Xt=Xt[:3], yt=yt[:3],
9+
verbose=0, random_state=0)
10+
model.fit(Xs, ys)
11+
model.predict(Xt)
12+
model.predict(Xt, "src")
13+
model.score(Xt, yt, domain="src")
14+
model.score(Xt, yt, domain="tgt")
15+
16+
model = PRED(RidgeClassifier().fit(Xs, ys),
17+
pretrain=False, Xt=Xt[:3], yt=yt[:3],
18+
verbose=0, random_state=0)
19+
model.fit(Xs, ys)
20+
model.predict(Xt)
21+
model.predict(Xt, "src")
22+
model.score(Xt, yt, domain="src")
23+
model.score(Xt, yt, domain="tgt")

0 commit comments

Comments
 (0)