Skip to content

Commit cbef352

Browse files
authored
All gather (#128)
* fix all gather bug
1 parent 83c49e6 commit cbef352

File tree

2 files changed

+23
-14
lines changed

2 files changed

+23
-14
lines changed

passl/modeling/architectures/densecl.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ class DenseCL(nn.Layer):
2828
Build a DenseCL model with: a query encoder, a key encoder, and a queue.
2929
https://arxiv.org/abs/2011.09157.
3030
"""
31+
3132
def __init__(self,
3233
backbone,
3334
neck=None,
@@ -57,10 +58,10 @@ def __init__(self,
5758

5859
# create the encoders
5960
# num_classes is the output fc dimension
60-
self.encoder_q = nn.Sequential(build_backbone(backbone),
61-
build_neck(neck))
62-
self.encoder_k = nn.Sequential(build_backbone(backbone),
63-
build_neck(neck))
61+
self.encoder_q = nn.Sequential(
62+
build_backbone(backbone), build_neck(neck))
63+
self.encoder_k = nn.Sequential(
64+
build_backbone(backbone), build_neck(neck))
6465

6566
self.backbone = self.encoder_q[0]
6667
self.head = build_head(head)
@@ -187,6 +188,7 @@ def train_iter(self, *inputs, **kwargs):
187188
self._momentum_update_key_encoder() # update the key encoder
188189

189190
# shuffle for making use of BN
191+
img_k = paddle.to_tensor(img_k)
190192
im_k, idx_unshuffle = self._batch_shuffle_ddp(img_k)
191193

192194
k_b = self.encoder_k[0](im_k)
@@ -215,11 +217,12 @@ def train_iter(self, *inputs, **kwargs):
215217
backbone_sim_matrix = paddle.matmul(q_b.transpose((0, 2, 1)), k_b)
216218
densecl_sim_ind = backbone_sim_matrix.argmax(axis=2) # NxS^2
217219

218-
gather_index = densecl_sim_ind.unsqueeze(1).expand((-1, k_grid.shape[1], -1))
220+
gather_index = densecl_sim_ind.unsqueeze(1).expand(
221+
(-1, k_grid.shape[1], -1))
219222
indexed_k_grid = paddle_gather(k_grid, dim=2, index=gather_index)
220223
densecl_sim_q = (q_grid * indexed_k_grid).sum(1) # NxS^2
221224

222-
l_pos_dense = densecl_sim_q.reshape((-1, )).unsqueeze(-1) # NS^2X1
225+
l_pos_dense = densecl_sim_q.reshape((-1, )).unsqueeze(-1) # NS^2X1
223226

224227
q_grid = q_grid.transpose((0, 2, 1))
225228
q_grid = q_grid.reshape((-1, q_grid.shape[2]))
@@ -229,7 +232,8 @@ def train_iter(self, *inputs, **kwargs):
229232
loss_dense = self.head(l_pos_dense, l_neg_dense)['loss']
230233

231234
outputs = dict()
232-
outputs['loss'] = loss_single * (1 - self.loss_lambda) + loss_dense * self.loss_lambda
235+
outputs['loss'] = loss_single * (1 - self.loss_lambda
236+
) + loss_dense * self.loss_lambda
233237

234238
self._dequeue_and_enqueue(k)
235239
self._dequeue_and_enqueue2(k2)
@@ -258,8 +262,11 @@ def paddle_gather(x, dim, index):
258262
else:
259263
reshape_shape = [1] * len(x.shape)
260264
reshape_shape[k] = x.shape[k]
261-
dim_index = paddle.expand(paddle.reshape(paddle.arange(x.shape[k], dtype=index.dtype), reshape_shape),
262-
index_shape).flatten()
265+
dim_index = paddle.expand(
266+
paddle.reshape(
267+
paddle.arange(
268+
x.shape[k], dtype=index.dtype), reshape_shape),
269+
index_shape).flatten()
263270
nd_index.append(dim_index)
264271
ind2 = paddle.transpose(paddle.stack(nd_index), [1, 0])
265272
paddle_out = paddle.gather_nd(x, ind2).reshape(index_shape)

passl/modeling/architectures/moco.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ class MoCo(nn.Layer):
2929
Build a MoCo model with: a query encoder, a key encoder, and a queue
3030
https://arxiv.org/abs/1911.05722
3131
"""
32+
3233
def __init__(self,
3334
backbone,
3435
neck=None,
@@ -56,10 +57,10 @@ def __init__(self,
5657

5758
# create the encoders
5859
# num_classes is the output fc dimension
59-
self.encoder_q = nn.Sequential(build_backbone(backbone),
60-
build_neck(neck))
61-
self.encoder_k = nn.Sequential(build_backbone(backbone),
62-
build_neck(neck))
60+
self.encoder_q = nn.Sequential(
61+
build_backbone(backbone), build_neck(neck))
62+
self.encoder_k = nn.Sequential(
63+
build_backbone(backbone), build_neck(neck))
6364

6465
self.backbone = self.encoder_q[0]
6566

@@ -162,6 +163,7 @@ def train_iter(self, *inputs, **kwargs):
162163
self._momentum_update_key_encoder() # update the key encoder
163164

164165
# shuffle for making use of BN
166+
img_k = paddle.to_tensor(img_k)
165167
im_k, idx_unshuffle = self._batch_shuffle_ddp(img_k)
166168

167169
k = self.encoder_k(im_k) # keys: NxC
@@ -205,4 +207,4 @@ def concat_all_gather(tensor):
205207
paddle.distributed.all_gather(tensors_gather, tensor)
206208

207209
output = paddle.concat(tensors_gather, axis=0)
208-
return output
210+
return output

0 commit comments

Comments
 (0)