1+ """
2+ IWC
3+ """
4+
5+ import inspect
6+
7+ import numpy as np
8+ from sklearn .utils import check_array
9+ from sklearn .linear_model import LogisticRegression
10+ from sklearn .base import BaseEstimator
11+ from sklearn .exceptions import NotFittedError
12+
13+ from adapt .base import BaseAdaptEstimator , make_insert_doc
14+ from adapt .utils import check_arrays , set_random_seed , check_estimator
15+
16+ EPS = np .finfo (float ).eps
17+
18+
@make_insert_doc()
class IWC(BaseAdaptEstimator):
    r"""
    IWC: Importance Weighting Classifier

    Importance weighting based on the output of a domain classifier
    which discriminates between source and target data.

    The source importance weights are given by the following formula:

    .. math::

        w(x) = \frac{1}{P(x \in Source)} - 1
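
    For instance, a source sample that the domain classifier assigns
    :math:`P(x \in Source) = 0.8` receives the weight
    :math:`1/0.8 - 1 = 0.25`, while a sample assigned :math:`0.2`
    receives :math:`1/0.2 - 1 = 4`: source samples lying in
    target-like regions are up-weighted.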

    Parameters
    ----------
    classifier : object (default=None)
        Binary classifier trained to discriminate
        between source and target data.

    cl_params : dict (default=None)
        Dictionary of parameters that will
        be given to the `fit` and/or `compile` methods
        of the classifier.

    Attributes
    ----------
    classifier_ : object
        Fitted classifier.

    estimator_ : object
        Fitted estimator.

    See also
    --------
    NearestNeighborsWeighting
    IWN

    Examples
    --------
    >>> from sklearn.linear_model import RidgeClassifier
    >>> from adapt.utils import make_classification_da
    >>> from adapt.instance_based import IWC
    >>> Xs, ys, Xt, yt = make_classification_da()
    >>> model = IWC(RidgeClassifier(0.), classifier=RidgeClassifier(0.),
    ...             Xt=Xt, random_state=0)
    >>> model.fit(Xs, ys);
    >>> model.score(Xt, yt)
    0.74
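
    The fitted source weights can then be retrieved with
    ``predict_weights`` (one weight per source sample):

    >>> weights = model.predict_weights()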

    References
    ----------
    .. [1] `[1] <https://icml.cc/imls/conferences/2007/proceedings/papers/303.pdf>`_ \
Steffen Bickel, Michael Bruckner, Tobias Scheffer. "Discriminative Learning for Differing \
Training and Test Distributions". In ICML 2007.
    """
    def __init__(self,
                 estimator=None,
                 Xt=None,
                 yt=None,
                 classifier=None,
                 cl_params=None,
                 copy=True,
                 verbose=1,
                 random_state=None,
                 **params):

        # Collect the constructor arguments declared in the signature
        # and forward them to BaseAdaptEstimator.
        names = self._get_param_names()
        kwargs = {k: v for k, v in locals().items() if k in names}
        kwargs.update(params)
        super().__init__(**kwargs)


    def fit_weights(self, Xs, Xt, warm_start=False, **kwargs):
        """
        Fit importance weighting.

        Parameters
        ----------
        Xs : array
            Input source data.

        Xt : array
            Input target data.

        warm_start : bool (default=False)
            Whether to resume the training of a previously
            fitted domain classifier.
            If False, the classifier is trained from scratch.

        kwargs : key, value argument
            Not used, present here for adapt consistency.

        Returns
        -------
        weights_ : sample weights
        """
        Xs = check_array(Xs)
        Xt = check_array(Xt)
        set_random_seed(self.random_state)

        if self.cl_params is None:
            self.cl_params_ = {}
        else:
            self.cl_params_ = self.cl_params

        if (not warm_start) or (not hasattr(self, "classifier_")):
            if self.classifier is None:
                # Default domain classifier: unregularized logistic regression.
                self.classifier_ = LogisticRegression(penalty="none")
            else:
                self.classifier_ = check_estimator(self.classifier,
                                                   copy=True,
                                                   force_copy=True)

        # If the classifier exposes a `compile` method (e.g. a Keras
        # model), pass it the matching entries of `cl_params`.
        if hasattr(self.classifier_, "compile"):
            args = [
                p.name
                for p in inspect.signature(self.classifier_.compile).parameters.values()
                if p.name != "self" and p.kind != p.VAR_KEYWORD
            ]
            compile_params = {}
            for key, value in self.cl_params_.items():
                if key in args:
                    compile_params[key] = value
            self.classifier_.compile(**compile_params)

        # Keep only the entries of `cl_params` accepted by the
        # classifier's `fit` method.
        args = [
            p.name
            for p in inspect.signature(self.classifier_.fit).parameters.values()
            if p.name != "self" and p.kind != p.VAR_KEYWORD
        ]
        fit_params = {}
        for key, value in self.cl_params_.items():
            if key in args:
                fit_params[key] = value

        # Build the domain-discrimination task: source samples are
        # labeled 1, target samples 0, then shuffled.
        X = np.concatenate((Xs, Xt))
        y = np.concatenate((np.ones(Xs.shape[0]), np.zeros(Xt.shape[0])))
        shuffle_index = np.random.choice(len(X), len(X), replace=False)
        X = X[shuffle_index]
        y = y[shuffle_index]

        self.classifier_.fit(X, y, **fit_params)

        # Recover P(x in Source) for the source samples, falling back on
        # hard predictions when no probability estimate is available.
        if isinstance(self.classifier_, BaseEstimator):
            if hasattr(self.classifier_, "predict_proba"):
                y_pred = self.classifier_.predict_proba(Xs)[:, 1]
            elif hasattr(self.classifier_, "_predict_proba_lr"):
                y_pred = self.classifier_._predict_proba_lr(Xs)[:, 1]
            else:
                y_pred = self.classifier_.predict(Xs).ravel()
        else:
            y_pred = self.classifier_.predict(Xs).ravel()

        # EPS guards against division by zero when P(x in Source) = 0.
        self.weights_ = 1. / (y_pred + EPS) - 1.

        return self.weights_


    def predict_weights(self, X=None):
        """
        Return fitted source weights.

        If ``X`` is ``None``, the fitted source weights are returned.
        Else, sample weights are computed using the fitted
        ``classifier_``.

        Parameters
        ----------
        X : array (default=None)
            Input data.

        Returns
        -------
        weights_ : sample weights
        """
        if hasattr(self, "weights_"):
            if X is None:
                return self.weights_
            else:
                X = check_array(X)
                if isinstance(self.classifier_, BaseEstimator):
                    if hasattr(self.classifier_, "predict_proba"):
                        y_pred = self.classifier_.predict_proba(X)[:, 1]
                    elif hasattr(self.classifier_, "_predict_proba_lr"):
                        y_pred = self.classifier_._predict_proba_lr(X)[:, 1]
                    else:
                        y_pred = self.classifier_.predict(X).ravel()
                else:
                    y_pred = self.classifier_.predict(X).ravel()
                weights = 1. / (y_pred + EPS) - 1.
                return weights
        else:
            raise NotFittedError("Weights are not fitted yet, please "
                                 "call 'fit_weights' or 'fit' first.")