import numpy as np
from sklearn.utils import check_array
from sklearn.exceptions import NotFittedError

from adapt.base import BaseAdaptEstimator, make_insert_doc
from adapt.utils import set_random_seed


@make_insert_doc(supervised=True)
class BalancedWeighting(BaseAdaptEstimator):
10+ """
11+ BW : Balanced Weighting
12+
13+ Fit the estimator on source and target labeled data
14+ according to the modified loss:
15+
16+ .. math::
17+
18+ \min_{h} (1-\gamma) \mathcal{L}(h(X_S), y_S) + \gamma \mathcal{L}(h(X_T), y_T)
19+
20+ Where:
21+
22+ - :math:`(X_S, y_S), (X_T, y_T)` are respectively the labeled source
23+ and target data.
24+ - :math:`\mathcal{L}` is the estimator loss
25+ - :math:`\gamma` is the ratio parameter
26+
27+ Parameters
28+ ----------
29+ gamma : float
30+ ratio between 0 and 1 correspond to the importance
31+ given to the target labeled data. When `ratio=1`, the
32+ estimator is only fitted on target data. `ratio=0.5`
33+ corresponds to a balanced training.
34+
35+ Attributes
36+ ----------
37+ weights_ : numpy array
38+ Training instance weights.
39+
40+ estimator_ : object
41+ Estimator.
42+
43+ Examples
44+ --------
45+ >>> from sklearn.linear_model import RidgeClassifier
46+ >>> from adapt.utils import make_classification_da
47+ >>> from adapt.instance_based import BalancedWeighting
48+ >>> Xs, ys, Xt, yt = make_classification_da()
49+ >>> model = BalancedWeighting(RidgeClassifier(), gamma=0.5, Xt=Xt[:3], yt=yt[:3],
50+ ... verbose=0, random_state=0)
51+ >>> model.fit(Xs, ys)
52+ >>> model.score(Xt, yt)
53+ 0.93
54+
55+ See also
56+ --------
57+ TrAdaBoost
58+ TrAdaBoostR2
59+ WANN
60+
61+ References
62+ ----------
63+ .. [1] `[1] <https://openreview.net/forum?id=SybwYsbdWH>`_ P. Wu, T. G. Dietterich. \
64+ "Improving SVM accuracy by training on auxiliary data sources". In ICML 2004
65+ """
    def __init__(self,
                 estimator=None,
                 Xt=None,
                 yt=None,
                 gamma=0.5,
                 copy=True,
                 verbose=1,
                 random_state=None,
                 **params):

        # Collect the constructor arguments from locals() and forward
        # them to BaseAdaptEstimator (standard adapt constructor pattern).
        names = self._get_param_names()
        kwargs = {k: v for k, v in locals().items() if k in names}
        kwargs.update(params)
        super().__init__(**kwargs)

    def fit_weights(self, Xs, Xt, ys, yt, **kwargs):
83+ """
84+ Fit importance weighting.
85+
86+ Parameters
87+ ----------
88+ Xs : array
89+ Input source data.
90+
91+ Xt : array
92+ Input target data.
93+
94+ ys : array
95+ Source labels.
96+
97+ yt : array
98+ Target labels.
99+
100+ kwargs : key, value argument
101+ Not used, present here for adapt consistency.
102+
103+ Returns
104+ -------
105+ weights_ : sample weights
106+
107+ X : concatenation of Xs and Xt
108+
109+ y : concatenation of ys and yt
110+ """
111+ Xs = check_array (Xs )
112+ Xt = check_array (Xt )
113+ set_random_seed (self .random_state )
114+
115+ X = np .concatenate ((Xs , Xt ))
116+ y = np .concatenate ((ys , yt ))
117+
118+ src_weights = np .ones (Xs .shape [0 ]) * Xt .shape [0 ] * (1 - self .gamma )
119+ tgt_weights = np .ones (Xt .shape [0 ]) * Xs .shape [0 ] * self .gamma
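        # Worked example with hypothetical sizes: for 100 source samples,
        # 10 target samples and gamma = 0.5, each source weight is
        # 10 * 0.5 = 5 and each target weight is 100 * 0.5 = 50, so the
        # total target mass over the total source mass equals
        # gamma / (1 - gamma) = 1, matching the balanced loss above.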
        self.weights_ = np.concatenate((src_weights, tgt_weights))
        # Normalize the weights to have mean one.
        self.weights_ /= np.mean(self.weights_)

        return self.weights_, X, y

    def predict_weights(self):
128+ """
129+ Return fitted source weights
130+
131+ Returns
132+ -------
133+ weights_ : sample weights
134+ """
135+ if hasattr (self , "weights_" ):
136+ return self .weights_
137+ else :
138+ raise NotFittedError ("Weights are not fitted yet, please "
139+ "call 'fit_weights' or 'fit' first." )