Skip to content

Commit e0fbac6

Browse files
committed
consistent space
1 parent cbd302c commit e0fbac6

File tree

1 file changed

+13
-0
lines changed

1 file changed

+13
-0
lines changed

src/utils/losses.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ def __init__(self, device, batch_size, pos_collected_numerator):
117117
def _calculate_similarity_matrix(self):
118118
return self._cosine_simililarity_matrix
119119

120+
120121
def remove_diag(self, M):
121122
h, w = M.shape
122123
assert h==w, "h and w should be same"
@@ -125,10 +126,12 @@ def remove_diag(self, M):
125126
mask = (mask).type(torch.bool).to(self.device)
126127
return M[mask].view(h, -1)
127128

129+
128130
def _cosine_simililarity_matrix(self, x, y):
129131
v = self.cosine_similarity(x.unsqueeze(1), y.unsqueeze(0))
130132
return v
131133

134+
132135
def forward(self, inst_embed, proxy, negative_mask, labels, temperature, margin):
133136
similarity_matrix = self.calculate_similarity_matrix(inst_embed, inst_embed)
134137
instance_zone = torch.exp((self.remove_diag(similarity_matrix) - margin)/temperature)
@@ -161,6 +164,7 @@ def __init__(self, device, batch_size, pos_collected_numerator):
161164
def _calculate_similarity_matrix(self):
162165
return self._cosine_simililarity_matrix
163166

167+
164168
def remove_diag(self, M):
165169
h, w = M.shape
166170
assert h==w, "h and w should be same"
@@ -169,10 +173,12 @@ def remove_diag(self, M):
169173
mask = (mask).type(torch.bool).to(self.device)
170174
return M[mask].view(h, -1)
171175

176+
172177
def _cosine_simililarity_matrix(self, x, y):
173178
v = self.cosine_similarity(x.unsqueeze(1), y.unsqueeze(0))
174179
return v
175180

181+
176182
def forward(self, inst_embed, proxy, negative_mask, labels, temperature, margin):
177183
p2i_similarity_matrix = self.calculate_similarity_matrix(proxy, inst_embed)
178184
i2i_similarity_matrix = self.calculate_similarity_matrix(inst_embed, inst_embed)
@@ -202,13 +208,15 @@ def __init__(self, device, embedding_layer, num_classes, batch_size):
202208
self.batch_size = batch_size
203209
self.cosine_similarity = torch.nn.CosineSimilarity(dim=-1)
204210

211+
205212
def _get_positive_proxy_mask(self, labels):
206213
labels = labels.detach().cpu().numpy()
207214
rvs_one_hot_target = np.ones([self.num_classes, self.num_classes]) - np.eye(self.num_classes)
208215
rvs_one_hot_target = rvs_one_hot_target[labels]
209216
mask = torch.from_numpy((rvs_one_hot_target)).type(torch.bool)
210217
return mask.to(self.device)
211218

219+
212220
def forward(self, inst_embed, proxy, labels):
213221
all_labels = torch.tensor([c for c in range(self.num_classes)]).type(torch.long).to(self.device)
214222
positive_proxy_mask = self._get_positive_proxy_mask(labels)
@@ -231,13 +239,15 @@ def __init__(self, device, batch_size, use_cosine_similarity=True):
231239
self.similarity_function = self._get_similarity_function(use_cosine_similarity)
232240
self.criterion = torch.nn.CrossEntropyLoss(reduction="sum")
233241

242+
234243
def _get_similarity_function(self, use_cosine_similarity):
235244
if use_cosine_similarity:
236245
self._cosine_similarity = torch.nn.CosineSimilarity(dim=-1)
237246
return self._cosine_simililarity
238247
else:
239248
return self._dot_simililarity
240249

250+
241251
def _get_correlated_mask(self):
242252
diag = np.eye(2 * self.batch_size)
243253
l1 = np.eye((2 * self.batch_size), 2 * self.batch_size, k=-self.batch_size)
@@ -246,6 +256,7 @@ def _get_correlated_mask(self):
246256
mask = (1 - mask).type(torch.bool)
247257
return mask.to(self.device)
248258

259+
249260
@staticmethod
250261
def _dot_simililarity(x, y):
251262
v = torch.tensordot(x.unsqueeze(1), y.T.unsqueeze(0), dims=2)
@@ -254,13 +265,15 @@ def _dot_simililarity(x, y):
254265
# v shape: (N, 2N)
255266
return v
256267

268+
257269
def _cosine_simililarity(self, x, y):
258270
# x shape: (N, 1, C)
259271
# y shape: (1, 2N, C)
260272
# v shape: (N, 2N)
261273
v = self._cosine_similarity(x.unsqueeze(1), y.unsqueeze(0))
262274
return v
263275

276+
264277
def forward(self, zis, zjs, temperature):
265278
representations = torch.cat([zjs, zis], dim=0)
266279

0 commit comments

Comments
 (0)