@@ -28,6 +28,7 @@ class DenseCL(nn.Layer):
2828 Build a DenseCL model with: a query encoder, a key encoder, and a queue.
2929 https://arxiv.org/abs/2011.09157.
3030 """
31+
3132 def __init__ (self ,
3233 backbone ,
3334 neck = None ,
@@ -57,10 +58,10 @@ def __init__(self,
5758
5859 # create the encoders
5960 # num_classes is the output fc dimension
60- self .encoder_q = nn .Sequential (build_backbone ( backbone ),
61- build_neck (neck ))
62- self .encoder_k = nn .Sequential (build_backbone ( backbone ),
63- build_neck (neck ))
61+ self .encoder_q = nn .Sequential (
62+ build_backbone ( backbone ), build_neck (neck ))
63+ self .encoder_k = nn .Sequential (
64+ build_backbone ( backbone ), build_neck (neck ))
6465
6566 self .backbone = self .encoder_q [0 ]
6667 self .head = build_head (head )
@@ -187,6 +188,7 @@ def train_iter(self, *inputs, **kwargs):
187188 self ._momentum_update_key_encoder () # update the key encoder
188189
189190 # shuffle for making use of BN
191+ img_k = paddle .to_tensor (img_k )
190192 im_k , idx_unshuffle = self ._batch_shuffle_ddp (img_k )
191193
192194 k_b = self .encoder_k [0 ](im_k )
@@ -215,11 +217,12 @@ def train_iter(self, *inputs, **kwargs):
215217 backbone_sim_matrix = paddle .matmul (q_b .transpose ((0 , 2 , 1 )), k_b )
216218 densecl_sim_ind = backbone_sim_matrix .argmax (axis = 2 ) # NxS^2
217219
218- gather_index = densecl_sim_ind .unsqueeze (1 ).expand ((- 1 , k_grid .shape [1 ], - 1 ))
220+ gather_index = densecl_sim_ind .unsqueeze (1 ).expand (
221+ (- 1 , k_grid .shape [1 ], - 1 ))
219222 indexed_k_grid = paddle_gather (k_grid , dim = 2 , index = gather_index )
220223 densecl_sim_q = (q_grid * indexed_k_grid ).sum (1 ) # NxS^2
221224
222- l_pos_dense = densecl_sim_q .reshape ((- 1 , )).unsqueeze (- 1 ) # NS^2X1
225+ l_pos_dense = densecl_sim_q .reshape ((- 1 , )).unsqueeze (- 1 ) # NS^2X1
223226
224227 q_grid = q_grid .transpose ((0 , 2 , 1 ))
225228 q_grid = q_grid .reshape ((- 1 , q_grid .shape [2 ]))
@@ -229,7 +232,8 @@ def train_iter(self, *inputs, **kwargs):
229232 loss_dense = self .head (l_pos_dense , l_neg_dense )['loss' ]
230233
231234 outputs = dict ()
232- outputs ['loss' ] = loss_single * (1 - self .loss_lambda ) + loss_dense * self .loss_lambda
235+ outputs ['loss' ] = loss_single * (1 - self .loss_lambda
236+ ) + loss_dense * self .loss_lambda
233237
234238 self ._dequeue_and_enqueue (k )
235239 self ._dequeue_and_enqueue2 (k2 )
@@ -258,8 +262,11 @@ def paddle_gather(x, dim, index):
258262 else :
259263 reshape_shape = [1 ] * len (x .shape )
260264 reshape_shape [k ] = x .shape [k ]
261- dim_index = paddle .expand (paddle .reshape (paddle .arange (x .shape [k ], dtype = index .dtype ), reshape_shape ),
262- index_shape ).flatten ()
265+ dim_index = paddle .expand (
266+ paddle .reshape (
267+ paddle .arange (
268+ x .shape [k ], dtype = index .dtype ), reshape_shape ),
269+ index_shape ).flatten ()
263270 nd_index .append (dim_index )
264271 ind2 = paddle .transpose (paddle .stack (nd_index ), [1 , 0 ])
265272 paddle_out = paddle .gather_nd (x , ind2 ).reshape (index_shape )
0 commit comments