@@ -117,6 +117,7 @@ def __init__(self, device, batch_size, pos_collected_numerator):
117117 def _calculate_similarity_matrix (self ):
118118 return self ._cosine_simililarity_matrix
119119
120+
120121 def remove_diag (self , M ):
121122 h , w = M .shape
122123 assert h == w , "h and w should be same"
@@ -125,10 +126,12 @@ def remove_diag(self, M):
125126 mask = (mask ).type (torch .bool ).to (self .device )
126127 return M [mask ].view (h , - 1 )
127128
129+
128130 def _cosine_simililarity_matrix (self , x , y ):
129131 v = self .cosine_similarity (x .unsqueeze (1 ), y .unsqueeze (0 ))
130132 return v
131133
134+
132135 def forward (self , inst_embed , proxy , negative_mask , labels , temperature , margin ):
133136 similarity_matrix = self .calculate_similarity_matrix (inst_embed , inst_embed )
134137 instance_zone = torch .exp ((self .remove_diag (similarity_matrix ) - margin )/ temperature )
@@ -161,6 +164,7 @@ def __init__(self, device, batch_size, pos_collected_numerator):
161164 def _calculate_similarity_matrix (self ):
162165 return self ._cosine_simililarity_matrix
163166
167+
164168 def remove_diag (self , M ):
165169 h , w = M .shape
166170 assert h == w , "h and w should be same"
@@ -169,10 +173,12 @@ def remove_diag(self, M):
169173 mask = (mask ).type (torch .bool ).to (self .device )
170174 return M [mask ].view (h , - 1 )
171175
176+
172177 def _cosine_simililarity_matrix (self , x , y ):
173178 v = self .cosine_similarity (x .unsqueeze (1 ), y .unsqueeze (0 ))
174179 return v
175180
181+
176182 def forward (self , inst_embed , proxy , negative_mask , labels , temperature , margin ):
177183 p2i_similarity_matrix = self .calculate_similarity_matrix (proxy , inst_embed )
178184 i2i_similarity_matrix = self .calculate_similarity_matrix (inst_embed , inst_embed )
@@ -202,13 +208,15 @@ def __init__(self, device, embedding_layer, num_classes, batch_size):
202208 self .batch_size = batch_size
203209 self .cosine_similarity = torch .nn .CosineSimilarity (dim = - 1 )
204210
211+
205212 def _get_positive_proxy_mask (self , labels ):
206213 labels = labels .detach ().cpu ().numpy ()
207214 rvs_one_hot_target = np .ones ([self .num_classes , self .num_classes ]) - np .eye (self .num_classes )
208215 rvs_one_hot_target = rvs_one_hot_target [labels ]
209216 mask = torch .from_numpy ((rvs_one_hot_target )).type (torch .bool )
210217 return mask .to (self .device )
211218
219+
212220 def forward (self , inst_embed , proxy , labels ):
213221 all_labels = torch .tensor ([c for c in range (self .num_classes )]).type (torch .long ).to (self .device )
214222 positive_proxy_mask = self ._get_positive_proxy_mask (labels )
@@ -231,13 +239,15 @@ def __init__(self, device, batch_size, use_cosine_similarity=True):
231239 self .similarity_function = self ._get_similarity_function (use_cosine_similarity )
232240 self .criterion = torch .nn .CrossEntropyLoss (reduction = "sum" )
233241
242+
234243 def _get_similarity_function (self , use_cosine_similarity ):
235244 if use_cosine_similarity :
236245 self ._cosine_similarity = torch .nn .CosineSimilarity (dim = - 1 )
237246 return self ._cosine_simililarity
238247 else :
239248 return self ._dot_simililarity
240249
250+
241251 def _get_correlated_mask (self ):
242252 diag = np .eye (2 * self .batch_size )
243253 l1 = np .eye ((2 * self .batch_size ), 2 * self .batch_size , k = - self .batch_size )
@@ -246,6 +256,7 @@ def _get_correlated_mask(self):
246256 mask = (1 - mask ).type (torch .bool )
247257 return mask .to (self .device )
248258
259+
249260 @staticmethod
250261 def _dot_simililarity (x , y ):
251262 v = torch .tensordot (x .unsqueeze (1 ), y .T .unsqueeze (0 ), dims = 2 )
@@ -254,13 +265,15 @@ def _dot_simililarity(x, y):
254265 # v shape: (N, 2N)
255266 return v
256267
268+
257269 def _cosine_simililarity (self , x , y ):
258270 # x shape: (N, 1, C)
259271 # y shape: (1, 2N, C)
260272 # v shape: (N, 2N)
261273 v = self ._cosine_similarity (x .unsqueeze (1 ), y .unsqueeze (0 ))
262274 return v
263275
276+
264277 def forward (self , zis , zjs , temperature ):
265278 representations = torch .cat ([zjs , zis ], dim = 0 )
266279
0 commit comments