Commit c336b3c

Add balanced weighting
1 parent 529087b commit c336b3c

File tree

adapt/instance_based/__init__.py
adapt/instance_based/_balancedweighting.py
adapt/instance_based/_tradaboost.py
tests/test_balancedweighting.py

4 files changed: +166 -9 lines changed

adapt/instance_based/__init__.py

Lines changed: 3 additions & 1 deletion
@@ -8,6 +8,8 @@
 from ._wann import WANN
 from ._ldm import LDM
 from ._nearestneighborsweighting import NearestNeighborsWeighting
+from ._balancedweighting import BalancedWeighting
 
 __all__ = ["LDM", "KLIEP", "KMM", "TrAdaBoost", "TrAdaBoostR2",
-           "TwoStageTrAdaBoostR2", "WANN", "NearestNeighborsWeighting"]
+           "TwoStageTrAdaBoostR2", "WANN", "NearestNeighborsWeighting",
+           "BalancedWeighting"]

adapt/instance_based/_balancedweighting.py

Lines changed: 139 additions & 0 deletions

@@ -0,0 +1,139 @@
+import numpy as np
+from sklearn.utils import check_array
+from sklearn.exceptions import NotFittedError
+
+from adapt.base import BaseAdaptEstimator, make_insert_doc
+from adapt.utils import set_random_seed
+
+
+@make_insert_doc(supervised=True)
+class BalancedWeighting(BaseAdaptEstimator):
+    r"""
+    BW : Balanced Weighting
+
+    Fit the estimator on the source and target labeled data
+    according to the modified loss:
+
+    .. math::
+
+        \min_{h} (1-\gamma) \mathcal{L}(h(X_S), y_S) + \gamma \mathcal{L}(h(X_T), y_T)
+
+    Where:
+
+    - :math:`(X_S, y_S), (X_T, y_T)` are respectively the labeled source
+      and target data.
+    - :math:`\mathcal{L}` is the estimator loss.
+    - :math:`\gamma` is the ratio parameter.
+
+    Parameters
+    ----------
+    gamma : float
+        Ratio between 0 and 1 corresponding to the importance
+        given to the target labeled data. When `gamma=1`, the
+        estimator is fitted on the target data only. `gamma=0.5`
+        corresponds to a balanced training.
+
+    Attributes
+    ----------
+    weights_ : numpy array
+        Training instance weights.
+
+    estimator_ : object
+        Fitted estimator.
+
+    Examples
+    --------
+    >>> from sklearn.linear_model import RidgeClassifier
+    >>> from adapt.utils import make_classification_da
+    >>> from adapt.instance_based import BalancedWeighting
+    >>> Xs, ys, Xt, yt = make_classification_da()
+    >>> model = BalancedWeighting(RidgeClassifier(), gamma=0.5, Xt=Xt[:3], yt=yt[:3],
+    ...                           verbose=0, random_state=0)
+    >>> model.fit(Xs, ys)
+    >>> model.score(Xt, yt)
+    0.93
+
+    See also
+    --------
+    TrAdaBoost
+    TrAdaBoostR2
+    WANN
+
+    References
+    ----------
+    .. [1] `[1] <https://openreview.net/forum?id=SybwYsbdWH>`_ P. Wu and
+       T. G. Dietterich. "Improving SVM accuracy by training on auxiliary
+       data sources". In ICML, 2004.
+    """
+
+    def __init__(self,
+                 estimator=None,
+                 Xt=None,
+                 yt=None,
+                 gamma=0.5,
+                 copy=True,
+                 verbose=1,
+                 random_state=None,
+                 **params):
+        names = self._get_param_names()
+        kwargs = {k: v for k, v in locals().items() if k in names}
+        kwargs.update(params)
+        super().__init__(**kwargs)
+
+    def fit_weights(self, Xs, Xt, ys, yt, **kwargs):
+        """
+        Fit importance weighting.
+
+        Parameters
+        ----------
+        Xs : array
+            Input source data.
+
+        Xt : array
+            Input target data.
+
+        ys : array
+            Source labels.
+
+        yt : array
+            Target labels.
+
+        kwargs : key, value argument
+            Not used, present here for adapt consistency.
+
+        Returns
+        -------
+        weights_ : sample weights
+
+        X : concatenation of Xs and Xt
+
+        y : concatenation of ys and yt
+        """
+        Xs = check_array(Xs)
+        Xt = check_array(Xt)
+        set_random_seed(self.random_state)
+
+        X = np.concatenate((Xs, Xt))
+        y = np.concatenate((ys, yt))
+
+        # Scale each domain by the size of the other one, so that the
+        # total source and target weights stand in the ratio
+        # (1-gamma) : gamma whatever the respective sample counts.
+        src_weights = np.ones(Xs.shape[0]) * Xt.shape[0] * (1-self.gamma)
+        tgt_weights = np.ones(Xt.shape[0]) * Xs.shape[0] * self.gamma
+
+        self.weights_ = np.concatenate((src_weights, tgt_weights))
+        self.weights_ /= np.mean(self.weights_)
+
+        return self.weights_, X, y
+
+    def predict_weights(self):
+        """
+        Return the fitted sample weights.
+
+        Returns
+        -------
+        weights_ : sample weights
+        """
+        if hasattr(self, "weights_"):
+            return self.weights_
+        else:
+            raise NotFittedError("Weights are not fitted yet, please "
+                                 "call 'fit_weights' or 'fit' first.")

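Aside (not part of the commit): a minimal numeric sketch of how these instance weights realize the balanced loss above. Each source sample gets weight `n_t * (1 - gamma)` and each target sample `n_s * gamma`, so the aggregate weight of each domain is proportional to `(1 - gamma)` and `gamma` respectively, however unbalanced the sample counts. The sizes below are illustrative.

import numpy as np

n_s, n_t, gamma = 100, 3, 0.5   # illustrative: many source, few target samples

src_weights = np.ones(n_s) * n_t * (1 - gamma)   # each source sample: 1.5
tgt_weights = np.ones(n_t) * n_s * gamma         # each target sample: 50.0

# With gamma=0.5 both domains carry the same total weight, so the 3 target
# samples count as much in the fit as the 100 source ones.
assert src_weights.sum() == tgt_weights.sum() == 150.0

# Normalizing by the mean, as fit_weights does, rescales the weights
# without changing the source/target balance.
weights = np.concatenate((src_weights, tgt_weights))
weights /= weights.mean()
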
adapt/instance_based/_tradaboost.py

Lines changed: 0 additions & 8 deletions
@@ -212,14 +212,6 @@ def fit(self, X, y, Xt=None, yt=None,
         fit_params : key, value arguments
             Arguments given to the fit method of the
             estimator.
-
-        Other Parameters
-        ----------------
-        Xt : array (default=self.Xt)
-            Target input data.
-
-        yt : array (default=self.yt)
-            Target output data.
 
         Returns
         -------

tests/test_balancedweighting.py

Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
+from sklearn.linear_model import RidgeClassifier
+from adapt.utils import make_classification_da
+from adapt.instance_based import BalancedWeighting
+
+Xs, ys, Xt, yt = make_classification_da()
+
+def test_good_ratio():
+    model = BalancedWeighting(RidgeClassifier(), gamma=0.5, Xt=Xt[:3], yt=yt[:3],
+                              verbose=0, random_state=0)
+    model.fit(Xs, ys)
+    model.predict(Xt)
+    assert model.score(Xt, yt) > 0.9
+
+
+def test_bad_ratio():
+    model = BalancedWeighting(RidgeClassifier(), gamma=0.99, Xt=Xt[:3], yt=yt[:3],
+                              verbose=0, random_state=0)
+    model.fit(Xs, ys)
+    assert model.score(Xt, yt) < 0.7
+
+    model = BalancedWeighting(RidgeClassifier(), gamma=0.01, Xt=Xt[:3], yt=yt[:3],
+                              verbose=0, random_state=0)
+    model.fit(Xs, ys)
+    assert model.score(Xt, yt) < 0.9
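
For reference (not part of the commit): the weights can also be obtained directly through `fit_weights`, which returns them together with the concatenated inputs. A minimal sketch on the same toy data as the tests:

from sklearn.linear_model import RidgeClassifier
from adapt.utils import make_classification_da
from adapt.instance_based import BalancedWeighting

Xs, ys, Xt, yt = make_classification_da()

model = BalancedWeighting(RidgeClassifier(), gamma=0.5, Xt=Xt[:3], yt=yt[:3],
                          verbose=0, random_state=0)
weights, X, y = model.fit_weights(Xs, Xt[:3], ys, yt[:3])

# With only 3 labeled target samples, each one receives a far larger
# individual weight than any single source sample.
print(weights[:len(Xs)].mean(), weights[len(Xs):].mean())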
