@@ -15,7 +15,7 @@ def pairwise_y(X, Y):
1515 Y = tf .reshape (Y , (batch_size_y , dim ))
1616 X = tf .tile (tf .expand_dims (X , - 1 ), [1 , 1 , batch_size_y ])
1717 Y = tf .tile (tf .expand_dims (Y , - 1 ), [1 , 1 , batch_size_x ])
18- return tf .reduce_sum (tf .abs (X - tf .transpose (Y )), 1 )/ 2
18+ return tf .reduce_sum (tf .abs (X - tf .transpose (Y )), 1 )/ 2.
1919
2020
2121def pairwise_X (X , Y ):
@@ -88,6 +88,22 @@ class CCSA(BaseAdaptDeep):
8888 If ``yt`` is given in ``fit`` method, target metrics
8989 and losses are recorded too.
9090
91+ See also
92+ --------
93+ CDAN
94+
95+ Examples
96+ --------
97+ >>> import numpy as np
98+ >>> from adapt.utils import make_classification_da
99+ >>> from adapt.feature_based import CCSA
100+ >>> Xs, ys, Xt, yt = make_classification_da()
101+ >>> model = CCSA(margin=1., gamma=0.5, Xt=Xt, metrics=["acc"], random_state=0)
102+ >>> model.fit(Xs, ys, epochs=100, verbose=0)
103+ >>> model.score(Xt, yt)
104+ 1/1 [==============================] - 0s 180ms/step - loss: 0.1550 - acc: 0.8900
105+ 0.15503168106079102
106+
91107 References
92108 ----------
93109 .. [1] `[1] <https://arxiv.org/abs/1709.10190>`_ S. Motiian, M. Piccirilli, \
@@ -116,18 +132,36 @@ def __init__(self,
116132 def train_step (self , data ):
117133 # Unpack the data.
118134 Xs , Xt , ys , yt = self ._unpack_data (data )
119-
135+
136+ # Check that yt is not None
137+ if yt is None :
138+ raise ValueError ("The target labels `yt` is `None`, CCSA is a supervised"
139+ " domain adaptation method and need `yt` to be specified." )
140+
141+ # Check shape of ys
142+ if len (ys .get_shape ()) <= 1 or ys .get_shape ()[1 ] == 1 :
143+ self ._ys_is_1d = True
144+ else :
145+ self ._ys_is_1d = False
146+
120147 # loss
121- with tf .GradientTape () as task_tape , tf .GradientTape () as enc_tape :
148+ with tf .GradientTape () as task_tape , tf .GradientTape () as enc_tape :
122149 # Forward pass
123150 Xs_enc = self .encoder_ (Xs , training = True )
124151 ys_pred = self .task_ (Xs_enc , training = True )
125152
153+ # Change type
154+ ys = tf .cast (ys , ys_pred .dtype )
155+ yt = tf .cast (yt , ys_pred .dtype )
156+
126157 Xt_enc = self .encoder_ (Xt , training = True )
127158
128- dist_y = pairwise_y (ys , yt )
159+ dist_y = pairwise_y (ys , yt )
129160 dist_X = pairwise_X (Xs_enc , Xt_enc )
130161
162+ if self ._ys_is_1d :
163+ dist_y *= 2.
164+
131165 contrastive_loss = tf .reduce_sum (dist_y * tf .maximum (0. , self .margin - dist_X ), 1 ) / (tf .reduce_sum (dist_y , 1 ) + EPS )
132166 contrastive_loss += tf .reduce_sum ((1 - dist_y ) * dist_X , 1 ) / (tf .reduce_sum (1 - dist_y , 1 ) + EPS )
133167 contrastive_loss = tf .reduce_mean (contrastive_loss )
0 commit comments