"""
Weighting Adversarial Neural Network (WANN)
"""
from copy import deepcopy

import numpy as np
import tensorflow as tf
from tensorflow.keras import Sequential, Model
from tensorflow.keras.layers import Layer, multiply
from tensorflow.keras.callbacks import Callback
from tensorflow.keras.constraints import MaxNorm

from adapt.utils import (GradientHandler,
                         check_arrays,
                         check_one_array,
                         check_network,
                         get_default_task)
from adapt.feature_based import BaseDeepFeature


class StopTraining(Callback):
    """Stop training as soon as the loss falls below 0.01.

    Used to end the pretraining of the weighting network early.
    """

    def on_train_batch_end(self, batch, logs=None):
        logs = logs or {}
        loss = logs.get("loss")
        if loss is not None and loss < 0.01:
            print("Weights initialization succeeded!")
            self.model.stop_training = True


class WANN(BaseDeepFeature):
    """
    WANN: Weighting Adversarial Neural Network is an instance-based domain
    adaptation method suited for regression tasks. It assumes a supervised
    setting in which some labeled target data are available.

    The goal of WANN is to compute a reweighting of the source instances
    that corrects the shift between the source and target domains. This is
    done by minimizing the Y-discrepancy distance between the reweighted
    source distribution and the target distribution.

    WANN involves three networks:

    - the weighting network, which learns the source weights;
    - the task network, which learns the task;
    - the discrepancy network, which is used to estimate a distance
      between the reweighted source and target distributions:
      the Y-discrepancy.
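
    Schematically, this corresponds to the min-max objective below (a
    simplified sketch of what the ``get_loss`` method implements; see the
    WANN paper for the exact formulation):

    .. math::

        \min_{w, h} \max_{h'} \;
        \mathbb{E}_S\left[w(x) \, L(h(x), y)\right]
        + \mathbb{E}_T\left[L(h'(x), y)\right]
        - \mathbb{E}_S\left[w(x) \, L(h'(x), y)\right]

    where :math:`w` is the weighting network, :math:`h` the task network
    and :math:`h'` the discrepancy network. The maximization over
    :math:`h'` is implemented with a gradient reversal layer on the
    discrepancy outputs.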

    Parameters
    ----------
    task : tensorflow Model (default=None)
        Task network. If ``None``, a two-layer network with 10
        neurons per layer and ReLU activation is used as task network.

    weighter : tensorflow Model (default=None)
        Weighter network. If ``None``, a two-layer network with 10
        neurons per layer and ReLU activation is used as
        weighter network.

    C : float (default=1.)
        Clipping constant for the regularization of the
        weighting network. Low values of ``C`` produce a smoother
        weighting map. If ``C <= 0``, no regularization is added.

    init_weights : bool (default=True)
        If ``True``, ``weighter`` is pretrained so that
        all predicted weights start close to one.

    loss : string or tensorflow loss (default="mse")
        Loss function used for the task.

    metrics : dict or list of string or tensorflow metrics (default=None)
        Metrics given to the model. If a list is provided,
        metrics are used on both ``task`` and ``discriminator``
        outputs. To give separate metrics, please provide a
        dict of metrics lists with ``"task"`` and ``"disc"`` as keys.

    optimizer : string or tensorflow optimizer (default=None)
        Optimizer of the model. If ``None``, the
        optimizer is set to ``tf.keras.optimizers.Adam(0.001)``.

    copy : boolean (default=True)
        Whether to make a copy of ``weighter``, ``task`` and
        ``discriminator`` or not.

    random_state : int (default=None)
        Seed of random generator.
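
    Examples
    --------
    A minimal usage sketch on toy synthetic data (the arrays, epoch count
    and shift below are illustrative choices, not values prescribed by
    the method):

    >>> import numpy as np
    >>> np.random.seed(0)
    >>> Xs = np.random.randn(100, 2)
    >>> ys = Xs.sum(1)
    >>> Xt = np.random.randn(20, 2) + 1.
    >>> yt = Xt.sum(1)
    >>> model = WANN(random_state=0)
    >>> model.fit(Xs, ys, Xt, yt, epochs=10, verbose=0)  # doctest: +SKIP
    >>> weights = model.predict_weights(Xs)  # doctest: +SKIP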
    """

    def __init__(self,
                 task=None,
                 weighter=None,
                 C=1.,
                 init_weights=True,
                 loss="mse",
                 metrics=None,
                 optimizer=None,
                 copy=True,
                 random_state=None):

        super().__init__(weighter, task, None,
                         loss, metrics, optimizer, copy,
                         random_state)

        self.init_weights = init_weights
        self.init_weights_ = init_weights
        self.C = C

        if weighter is None:
            self.weighter_ = get_default_task()  # activation="relu"
        else:
            self.weighter_ = self.encoder_

        if self.C > 0.:
            self._add_regularization()

        # The discrepancy network is a fresh copy of the task
        # network, renamed to avoid a name collision.
        self.discriminator_ = check_network(self.task_,
                                            copy=True,
                                            display_name="task",
                                            force_copy=True)
        self.discriminator_._name = self.discriminator_._name + "_2"


    def _add_regularization(self):
        # Clip the parameters of each layer of the weighter
        # with a MaxNorm constraint on its kernel and bias.
        for layer in self.weighter_.layers:
            if hasattr(layer, "kernel_constraint"):
                layer.kernel_constraint = MaxNorm(self.C)
            if hasattr(layer, "bias_constraint"):
                layer.bias_constraint = MaxNorm(self.C)


    def fit(self, Xs, ys, Xt, yt, **fit_params):
        """
        Fit WANN on source data (Xs, ys) and target data (Xt, yt).

        ``fit_params`` are passed to the fit method of the model
        (e.g. ``epochs``, ``batch_size``, ``verbose``...).
        """
        Xs, ys, Xt, yt = check_arrays(Xs, ys, Xt, yt)

        # Pretrain the weighter once, so that the predicted
        # weights start close to one.
        if self.init_weights_:
            self._init_weighter(Xs)
            self.init_weights_ = False
        self._fit(Xs, ys, Xt, yt, **fit_params)
        return self


    def _init_weighter(self, Xs):
        # Pretrain the weighter to predict 1 for every source
        # instance, so that the initial weighting is uniform.
        self.weighter_.compile(optimizer=deepcopy(self.optimizer), loss="mse")
        batch_size = 64
        # Roughly 1000 gradient steps in total, whatever the size of Xs.
        epochs = max(1, int(64 * 1000 / len(Xs)))
        callback = StopTraining()
        self.weighter_.fit(Xs, np.ones(len(Xs)),
                           epochs=epochs, batch_size=batch_size,
                           callbacks=[callback], verbose=0)


    def _initialize_networks(self, shape_Xt):
        # Call predict once on dummy inputs to build the networks.
        self.weighter_.predict(np.zeros((1,) + shape_Xt))
        self.task_.predict(np.zeros((1,) + shape_Xt))
        self.discriminator_.predict(np.zeros((1,) + shape_Xt))


    def create_model(self, inputs_Xs, inputs_Xt):

        # Gradient reversal layer used to train the
        # discrepancy network adversarially.
        flip = GradientHandler(-1.)

        # Get the networks outputs for both source and target
        weights_s = self.weighter_(inputs_Xs)
        # The absolute value ensures nonnegative source weights.
        weights_s = tf.math.abs(weights_s)
        task_s = self.task_(inputs_Xs)
        task_t = self.task_(inputs_Xt)
        disc_s = self.discriminator_(inputs_Xs)
        disc_t = self.discriminator_(inputs_Xt)

        # Reversal layer at the end of the discriminator
        disc_s = flip(disc_s)
        disc_t = flip(disc_t)

        return dict(task_s=task_s, task_t=task_t,
                    disc_s=disc_s, disc_t=disc_t,
                    weights_s=weights_s)


    def get_loss(self, inputs_ys, inputs_yt, task_s,
                 task_t, disc_s, disc_t, weights_s):

        # Weighted source task loss.
        loss_task_s = self.loss_(inputs_ys, task_s)
        loss_task_s = multiply([weights_s, loss_task_s])

        # Weighted source and unweighted target losses
        # of the discrepancy network.
        loss_disc_s = self.loss_(inputs_ys, disc_s)
        loss_disc_s = multiply([weights_s, loss_disc_s])

        loss_disc_t = self.loss_(inputs_yt, disc_t)

        # Y-discrepancy term between the reweighted source
        # and the target distributions.
        loss_disc = (tf.reduce_mean(loss_disc_t) -
                     tf.reduce_mean(loss_disc_s))

        loss = tf.reduce_mean(loss_task_s) + loss_disc
        return loss


    def get_metrics(self, inputs_ys, inputs_yt, task_s,
                    task_t, disc_s, disc_t, weights_s):

        metrics = {}

        loss_s = self.loss_(inputs_ys, task_s)
        loss_t = self.loss_(inputs_yt, task_t)

        metrics["task_s"] = tf.reduce_mean(loss_s)
        metrics["task_t"] = tf.reduce_mean(loss_t)

        names_task, names_disc = self._get_metric_names()

        for metric, name in zip(self.metrics_task_, names_task):
            metrics[name + "_s"] = metric(inputs_ys, task_s)
            metrics[name + "_t"] = metric(inputs_yt, task_t)
        return metrics


    def predict(self, X):
        """
        Return the predictions of the task network.

        Parameters
        ----------
        X : array
            Input data.

        Returns
        -------
        y_pred : array
            Predictions of the task network.
        """
        X = check_one_array(X)
        return self.task_.predict(X)


    def predict_weights(self, X):
        """
        Return the predictions of the weighting network.

        Parameters
        ----------
        X : array
            Input data.

        Returns
        -------
        weights : array
            Source weights (absolute values of the
            weighter outputs).
        """
        X = check_one_array(X)
        return np.abs(self.weighter_.predict(X))


    def predict_disc(self, X):
        """
        Return the predictions of the discriminator network.

        Parameters
        ----------
        X : array
            Input data.

        Returns
        -------
        y_disc : array
            Predictions of the discriminator network.
        """
        X = check_one_array(X)
        return self.discriminator_.predict(X)