Skip to content

Commit 770f31a

Browse files
Merge pull request #78 from antoinedemathelin/master
feat: Add LinInt
2 parents 98911d1 + 8b5deb4 commit 770f31a

File tree

4 files changed

+207
-2
lines changed

4 files changed

+207
-2
lines changed

adapt/feature_based/_pred.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def __init__(self,
7575
super().__init__(**kwargs)
7676

7777

78-
def fit_transform(self, Xs, Xt, ys, yt, domains=None, **kwargs):
78+
def fit_transform(self, Xs, Xt, ys, yt, **kwargs):
7979
"""
8080
Fit embeddings.
8181

adapt/parameter_based/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,12 @@
66
from ._finetuning import FineTuning
77
from ._transfer_tree import TransferTreeClassifier
88
from ._transfer_tree import TransferForestClassifier
9+
from ._linint import LinInt
910

1011
__all__ = ["RegularTransferLR",
1112
"RegularTransferLC",
1213
"RegularTransferNN",
1314
"FineTuning",
1415
"TransferTreeClassifier",
15-
"TransferForestClassifier"]
16+
"TransferForestClassifier",
17+
"LinInt"]

adapt/parameter_based/_linint.py

Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,190 @@
1+
"""
2+
Linear Interpolation (LinInt) domain adaptation module.
3+
"""
4+
5+
import warnings
6+
7+
import numpy as np
8+
from sklearn.utils import check_array
9+
from sklearn.exceptions import NotFittedError
10+
from sklearn.linear_model import LinearRegression
11+
from sklearn.metrics import r2_score
12+
13+
from adapt.base import BaseAdaptEstimator, make_insert_doc
14+
from adapt.utils import check_arrays, set_random_seed
15+
16+
17+
@make_insert_doc(supervised=True)
class LinInt(BaseAdaptEstimator):
    """
    LinInt: Linear Interpolation between SrcOnly and TgtOnly.
    
    LinInt linearly interpolates the predictions of the SrcOnly and
    TgtOnly models. The interpolation parameter is adjusted based on
    a small amount of target data removed from the training set
    of TgtOnly.
    
    Parameters
    ----------
    prop : float (default=0.5)
        Proportion between 0 and 1 of the data used
        to fit the TgtOnly model. The rest of the
        target data are used to estimate the interpolation
        parameter.
    
    Attributes
    ----------
    estimator_src_ : object
        Fitted source estimator.
        
    estimator_ : object
        Fitted estimator.
        
    interpolator_ : object
        Fitted linear model (no intercept) combining the
        predictions of ``estimator_src_`` and ``estimator_``.
        
    See also
    --------
    FA
    PRED
    
    Examples
    --------
    >>> from sklearn.linear_model import Ridge
    >>> from adapt.utils import make_regression_da
    >>> from adapt.parameter_based import LinInt
    >>> Xs, ys, Xt, yt = make_regression_da()
    >>> model = LinInt(Ridge(), Xt=Xt[:6], yt=yt[:6], prop=0.5,
    ...                verbose=0, random_state=0)
    >>> model.fit(Xs, ys)
    >>> model.score(Xt, yt)
    0.68...

    References
    ----------
    .. [1] `[1] <https://arxiv.org/pdf/0907.1815\
.pdf>`_ Daume III, H. "Frustratingly easy domain adaptation". In ACL, 2007.
    """
    def __init__(self,
                 estimator=None,
                 Xt=None,
                 yt=None,
                 prop=0.5,
                 copy=True,
                 verbose=1,
                 random_state=None,
                 **params):
        
        # Forward every constructor argument that is a declared parameter
        # name, plus any extra keyword parameters, to the base class.
        names = self._get_param_names()
        kwargs = {k: v for k, v in locals().items() if k in names}
        kwargs.update(params)
        super().__init__(**kwargs)


    def _stack_predictions(self, X):
        """
        Stack source and target predictions column-wise.
        
        Parameters
        ----------
        X : array
            Input data.
            
        Returns
        -------
        Xp : array
            Predictions of ``estimator_src_`` followed by the
            predictions of ``estimator_``, concatenated along axis 1.
        """
        yp_src = self.estimator_src_.predict(X)
        yp_tgt = self.estimator_.predict(X)
        
        # 1D predictions are turned into single-column arrays so that
        # both blocks can be concatenated along the feature axis.
        if len(yp_src.shape) < 2:
            yp_src = yp_src.reshape(-1, 1)
        if len(yp_tgt.shape) < 2:
            yp_tgt = yp_tgt.reshape(-1, 1)
        
        return np.concatenate((yp_src, yp_tgt), axis=1)


    def fit(self, Xs, ys, Xt=None, yt=None, **kwargs):
        """
        Fit LinInt.

        The target data are randomly split in two parts: the first
        ``int(n_target * prop)`` shuffled samples fit the target
        estimator, the remaining samples fit the linear interpolation
        between source and target predictions.

        Parameters
        ----------
        Xs : array
            Source input data.

        ys : array
            Source output data.
            
        Xt : array
            Target input data.
            
        yt : array
            Target output data.

        kwargs : key, value argument
            Not used, present here for adapt consistency.

        Returns
        -------
        self : returns an instance of self
        
        Raises
        ------
        ValueError
            If the split defined by ``prop`` leaves no target sample
            for the target estimator or for the interpolation.
        """
        set_random_seed(self.random_state)
        
        Xs, ys = check_arrays(Xs, ys, accept_sparse=True)
        Xt, yt = self._get_target_data(Xt, yt)
        Xt, yt = check_arrays(Xt, yt, accept_sparse=True)
        
        # shape[0] rather than len(): sparse inputs are accepted above
        # and sparse matrices do not support len().
        n_target = Xt.shape[0]
        
        shuffle_index = np.random.choice(n_target, n_target, replace=False)
        cut = int(n_target * self.prop)
        
        # Both splits are required: one to train the target estimator,
        # the other to fit the interpolation weights.
        if cut <= 0 or cut >= n_target:
            raise ValueError(
                "`prop`=%s with %i target samples leaves %i sample(s) to "
                "fit the target estimator and %i sample(s) to fit the "
                "interpolation; both sets must be non-empty."
                % (str(self.prop), n_target, cut, n_target - cut)
            )
        
        train_index = shuffle_index[:cut]
        test_index = shuffle_index[cut:]
        Xt_train = Xt[train_index]
        Xt_test = Xt[test_index]
        yt_train = yt[train_index]
        yt_test = yt[test_index]
        
        self.estimator_src_ = self.fit_estimator(Xs, ys,
                                                 warm_start=False,
                                                 random_state=None)
        
        self.estimator_ = self.fit_estimator(Xt_train, yt_train,
                                             warm_start=False,
                                             random_state=None)
        
        # No intercept: the output is a pure weighted combination of
        # the two models' predictions.
        self.interpolator_ = LinearRegression(fit_intercept=False)
        self.interpolator_.fit(self._stack_predictions(Xt_test), yt_test)
        
        return self


    def predict(self, X):
        """
        Return LinInt predictions.
        
        Parameters
        ----------
        X : array
            Input data.

        Returns
        -------
        y : array
            Predictions
            
        Raises
        ------
        NotFittedError
            If the model has not been fitted yet.
        """
        # NotFittedError derives from AttributeError, so callers that
        # handled the former bare AttributeError keep working.
        if not hasattr(self, "interpolator_"):
            raise NotFittedError(
                "This %s instance is not fitted yet. Call 'fit' with "
                "appropriate arguments before using this estimator."
                % type(self).__name__
            )
        
        return self.interpolator_.predict(self._stack_predictions(X))


    def score(self, X, y):
        """
        Compute R2 score
        
        Parameters
        ----------
        X : array
            input data
            
        y : array
            output data
            
        Returns
        -------
        score : float
            estimator score.
        """
        yp = self.predict(X)
        score = r2_score(y, yp)
        return score

tests/test_linint.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
from sklearn.linear_model import Ridge
2+
from adapt.utils import make_regression_da
3+
from adapt.parameter_based import LinInt
4+
5+
Xs, ys, Xt, yt = make_regression_da()
6+
7+
def test_linint():
    """Smoke-test LinInt: fit with stored and explicit target data, predict, score."""
    model = LinInt(Ridge(), Xt=Xt[:6], yt=yt[:6],
                   verbose=0, random_state=0)
    # Target data supplied at construction time; fit returns the estimator.
    assert model.fit(Xs, ys) is model
    # Target data supplied explicitly at fit time.
    model.fit(Xs, ys, Xt[:6], yt[:6])
    y_pred = model.predict(Xt)
    # One prediction per input row.
    assert len(y_pred) == len(Xt)
    model.score(Xt, yt)

0 commit comments

Comments
 (0)