1+ """
2+ WDGRL
3+ """
4+
5+ import numpy as np
6+ import tensorflow as tf
7+ from tensorflow .keras import Model , Sequential
8+ from tensorflow .keras .layers import Layer , subtract
9+ from tensorflow .keras .optimizers import Adam
10+ import tensorflow .keras .backend as K
11+
12+ from adapt .utils import (GradientHandler ,
13+ check_arrays )
14+ from adapt .feature_based import BaseDeepFeature
15+
16+ EPS = K .epsilon ()
17+
18+
class _Interpolation(Layer):
    """
    Layer that produces random interpolation points between two
    entries, together with the distance from each interpolated
    point to the first entry.

    The layer expects ``inputs = [Xs, Xt]`` where both tensors have the
    same shape ``(batch, ...)``. One coefficient ``alpha`` is drawn per
    sample, uniformly in ``[0, 1)``, and the layer returns:

    - ``interpolates = Xs + alpha * (Xt - Xs)``, same shape as ``Xs``;
    - ``distances`` of shape ``(batch, 1)``: the root mean square of
      ``alpha * (Xt - Xs)`` computed over every non-batch axis.
      ``EPS`` is added under the square root so that the value is
      never exactly zero (it is used as a divisor by the caller).
    """

    def call(self, inputs):
        Xs = inputs[0]
        Xt = inputs[1]

        # The list of axes to reduce must be built in Python, so the
        # static rank of the input is used (it has to be known).
        # Every axis except the batch axis is a feature axis.
        feature_axes = list(range(1, len(Xs.shape)))

        # One alpha per sample, shaped (batch, 1, ..., 1) so that it
        # broadcasts against the feature axes.
        batch_size = tf.shape(Xs)[0]
        alphas = tf.random.uniform(
            [batch_size] + [1] * len(feature_axes),
            dtype=Xs.dtype,
        )

        steps = alphas * (Xt - Xs)
        interpolates = Xs + steps

        # Reduce over ALL feature axes (the last one included) to get
        # a single distance per sample, then restore a trailing axis
        # so the result lines up with a ``(batch, 1)`` output.
        distances = K.sqrt(K.mean(K.square(steps), axis=feature_axes) + EPS)
        distances = tf.expand_dims(distances, axis=-1)
        return interpolates, distances
39+
40+
class WDGRL(BaseDeepFeature):
    r"""
    WDGRL (Wasserstein Distance Guided Representation Learning) is an
    unsupervised domain adaptation method on the model of the
    :ref:`DANN <adapt.feature_based.DANN>`. In WDGRL the discriminator
    is used to approximate the Wasserstein distance between the
    source and target encoded distributions in the spirit of WGAN.

    The optimization formulation is the following:

    .. math::

        \min_{\phi, F} & \; \mathcal{L}_{task}(F(\phi(X_S)), y_S) +
        \lambda \left(D(\phi(X_S)) - D(\phi(X_T)) \right) \\
        \max_{D} & \; \left(D(\phi(X_S)) - D(\phi(X_T)) \right) -
        \gamma (||\nabla D(\alpha \phi(X_S) + (1- \alpha) \phi(X_T))||_2 - 1)^2

    Where:

    - :math:`(X_S, y_S), (X_T)` are respectively the labeled source data
      and the unlabeled target data.
    - :math:`\phi, F, D` are respectively the **encoder**, the **task**
      and the **discriminator** networks
    - :math:`\lambda` is the trade-off parameter.
    - :math:`\gamma` is the gradient penalty parameter.

    .. figure:: ../_static/images/wdgrl.png
        :align: center

        WDGRL architecture (source: [1])

    Parameters
    ----------
    encoder : tensorflow Model (default=None)
        Encoder network. If ``None``, a shallow network with 10
        neurons and ReLU activation is used as encoder network.

    task : tensorflow Model (default=None)
        Task network. If ``None``, a two layers network with 10
        neurons per layer and ReLU activation is used as task network.

    discriminator : tensorflow Model (default=None)
        Discriminator network. If ``None``, a two layers network with 10
        neurons per layer and ReLU activation is used as discriminator
        network. Note that the output shape of the discriminator should
        be ``(None, 1)``.

    lambda_ : float (default=1.)
        Trade-off parameter. This parameter gives the trade-off
        for the encoder between learning the task and matching
        the source and target distribution. If `lambda_` is small
        the encoder will focus on the task. If `lambda_=0`, WDGRL
        is equivalent to a "source only" method.

    gamma : float (default=1.)
        Gradient penalization parameter. To approximate the
        Wasserstein distance well, the `discriminator` should be
        1-Lipschitz. This constraint is imposed by the gradient
        penalty term of the optimization. A good value of `gamma`
        is not easy to find. One can check through the metrics that
        the gradient penalty term is of the same order as the
        "disc loss". If `gamma=0`, no penalty is given on the
        discriminator gradient.

    loss : string or tensorflow loss (default="mse")
        Loss function used for the task.

    metrics : dict or list of string or tensorflow metrics (default=None)
        Metrics given to the model. If a list is provided,
        metrics are used on both ``task`` and ``discriminator``
        outputs. To give separated metrics, please provide a
        dict of metrics list with ``"task"`` and ``"disc"`` as keys.

    optimizer : string or tensorflow optimizer (default=None)
        Optimizer of the model. If ``None``, the
        optimizer is set to tf.keras.optimizers.Adam(0.001)

    copy : boolean (default=True)
        Whether to make a copy of ``encoder``, ``task`` and
        ``discriminator`` or not.

    random_state : int (default=None)
        Seed of random generator.

    Attributes
    ----------
    encoder_ : tensorflow Model
        encoder network.

    task_ : tensorflow Model
        task network.

    discriminator_ : tensorflow Model
        discriminator network.

    model_ : tensorflow Model
        Fitted model: the union of ``encoder_``,
        ``task_`` and ``discriminator_`` networks.

    history_ : dict
        history of the losses and metrics across the epochs.
        If ``yt`` is given in ``fit`` method, target metrics
        and losses are recorded too.

    Examples
    --------
    >>> import numpy as np
    >>> from adapt.feature_based import WDGRL
    >>> np.random.seed(0)
    >>> Xs = np.concatenate((np.random.random((100, 1)),
    ...                      np.zeros((100, 1))), 1)
    >>> Xt = np.concatenate((np.random.random((100, 1)),
    ...                      np.ones((100, 1))), 1)
    >>> ys = 0.2 * Xs[:, 0]
    >>> yt = 0.2 * Xt[:, 0]
    >>> model = WDGRL(lambda_=0., random_state=0)
    >>> model.fit(Xs, ys, Xt, yt, epochs=100, verbose=0)
    >>> model.history_["task_t"][-1]
    0.0223...
    >>> model = WDGRL(lambda_=1, random_state=0)
    >>> model.fit(Xs, ys, Xt, yt, epochs=100, verbose=0)
    >>> model.history_["task_t"][-1]
    0.0044...

    See also
    --------
    DANN
    ADDA
    DeepCORAL

    References
    ----------
    .. [1] `[1] <https://arxiv.org/pdf/1707.01217.pdf>`_ Shen, J., Qu, Y.,
           Zhang, W., and Yu, Y. Wasserstein distance guided representation
           learning for domain adaptation. In AAAI, 2018.
    """

    def __init__(self,
                 encoder=None,
                 task=None,
                 discriminator=None,
                 lambda_=1.,
                 gamma=1.,
                 loss="mse",
                 metrics=None,
                 optimizer=None,
                 copy=True,
                 random_state=None):
        # WDGRL-specific hyper-parameters are stored before delegating
        # the remaining arguments to the parent constructor.
        self.lambda_ = lambda_
        self.gamma = gamma
        super().__init__(encoder, task, discriminator,
                         loss, metrics, optimizer, copy,
                         random_state)

    def create_model(self, inputs_Xs, inputs_Xt):
        """
        Build the symbolic outputs of the WDGRL model.

        Parameters
        ----------
        inputs_Xs : tensor
            Source input.

        inputs_Xt : tensor
            Target input.

        Returns
        -------
        outputs : dict
            Dictionary with keys ``"task_src"``, ``"task_tgt"``,
            ``"disc_src"``, ``"disc_tgt"`` and ``"disc_grad"``.
        """
        encoded_src = self.encoder_(inputs_Xs)
        encoded_tgt = self.encoder_(inputs_Xt)
        task_src = self.task_(encoded_src)
        task_tgt = self.task_(encoded_tgt)

        # NOTE(review): GradientHandler is imported from adapt.utils;
        # presumably it is the identity in the forward pass and scales
        # the back-propagated gradient by the given factor (so 0 would
        # block the gradient) -- confirm against its definition.
        flip = GradientHandler(-self.lambda_, name="flip")
        no_grad = GradientHandler(0, name="no_grad")

        # The same ``flip`` layer is applied to both encodings before
        # they are scored by the discriminator.
        disc_src = self.discriminator_(flip(encoded_src))
        disc_tgt = self.discriminator_(flip(encoded_tgt))

        encoded_src_no_grad = no_grad(encoded_src)
        encoded_tgt_no_grad = no_grad(encoded_tgt)

        # Finite-difference ratio |D(interp) - D(src)| / distance(interp, src),
        # which ``get_loss`` pushes towards 1 through the penalty term.
        interpolates, distances = _Interpolation()(
            [encoded_src_no_grad, encoded_tgt_no_grad]
        )
        disc_grad = K.abs(
            subtract([self.discriminator_(interpolates),
                      self.discriminator_(encoded_src_no_grad)])
        )
        disc_grad /= distances

        outputs = dict(task_src=task_src,
                       task_tgt=task_tgt,
                       disc_src=disc_src,
                       disc_tgt=disc_tgt,
                       disc_grad=disc_grad)
        return outputs

    def get_loss(self, inputs_ys,
                 task_src, task_tgt,
                 disc_src, disc_tgt,
                 disc_grad):
        """
        Compute the scalar training loss.

        The returned value is
        ``mean(task loss) - (mean(disc_src) - mean(disc_tgt))
        + gamma * mean((disc_grad - 1)^2)``.

        ``task_tgt`` is accepted but not used here.
        """
        loss_task = self.loss_(inputs_ys, task_src)
        loss_disc = K.mean(disc_src) - K.mean(disc_tgt)
        gradient_penalty = K.mean(K.square(disc_grad - 1.))

        # ``loss_disc`` and ``gradient_penalty`` are already scalars.
        loss = K.mean(loss_task) - loss_disc + self.gamma * gradient_penalty
        return loss

    def get_metrics(self, inputs_ys, inputs_yt,
                    task_src, task_tgt,
                    disc_src, disc_tgt, disc_grad):
        """
        Compute the metrics reported during training.

        Returns
        -------
        metrics : dict
            Always contains ``"task_s"`` (mean source task loss),
            ``"disc"`` (``mean(disc_src) - mean(disc_tgt)``) and
            ``"grad_pen"`` (``gamma * mean((disc_grad - 1)^2)``).
            If ``inputs_yt`` is not ``None``, ``"task_t"`` is added.
            Each task metric is reported under ``name + "_s"`` and,
            when ``inputs_yt`` is given, ``name + "_t"``.
        """
        metrics = {}

        task_s = self.loss_(inputs_ys, task_src)
        disc = K.mean(disc_src) - K.mean(disc_tgt)
        grad_pen = K.square(disc_grad - 1.)

        metrics["task_s"] = K.mean(task_s)
        metrics["disc"] = disc
        metrics["grad_pen"] = self.gamma * K.mean(grad_pen)

        if inputs_yt is not None:
            task_t = self.loss_(inputs_yt, task_tgt)
            metrics["task_t"] = K.mean(task_t)

        # Only the task metric names are needed in this method.
        names_task, _ = self._get_metric_names()

        for metric, name in zip(self.metrics_task_, names_task):
            metrics[name + "_s"] = metric(inputs_ys, task_src)
            if inputs_yt is not None:
                metrics[name + "_t"] = metric(inputs_yt, task_tgt)
        return metrics
0 commit comments