@@ -42,16 +42,26 @@ def ref_euclidean_similarity(ref: torch.Tensor, pos: torch.Tensor,
@torch.jit.script
def ref_infonce(pos_dist: torch.Tensor, neg_dist: torch.Tensor):
    """Numerically stable reference InfoNCE.

    The per-row maximum of the negative scores is subtracted (detached,
    under ``no_grad``) from both score tensors before the logsumexp, the
    standard stabilization trick; the total loss is unchanged by the shift.

    Args:
        pos_dist: positive-pair scores, shape ``(n,)``.
        neg_dist: negative-pair scores, shape ``(n, m)``.

    Returns:
        Tuple ``(loss, align, uniform)`` with ``loss = align + uniform``.
    """
    with torch.no_grad():
        shift, _ = neg_dist.max(dim=1, keepdim=True)
        shift = shift.detach()
    pos_dist = pos_dist - shift.squeeze(1)
    neg_dist = neg_dist - shift

    align = (-pos_dist).mean()
    uniform = torch.logsumexp(neg_dist, dim=1).mean()
    return align + uniform, align, uniform
5353
5454
@torch.jit.script
def ref_infonce_not_stable(pos_dist: torch.Tensor, neg_dist: torch.Tensor):
    """Reference InfoNCE without the max-subtraction stabilization.

    Unlike ``ref_infonce``, no constant is subtracted before the
    logsumexp, so large scores may overflow; this variant exists to
    contrast against the stabilized implementation.

    Args:
        pos_dist: positive-pair scores, shape ``(n,)``.
        neg_dist: negative-pair scores, shape ``(n, m)``.

    Returns:
        Tuple ``(loss, align, uniform)`` with ``loss = align + uniform``.
    """
    # The original no-op re-assignments (pos_dist = pos_dist, etc.) were
    # removed; they had no effect on the computation.
    align = (-pos_dist).mean()
    uniform = torch.logsumexp(neg_dist, dim=1).mean()
    return align + uniform, align, uniform
64+
5565class ReferenceInfoNCE (nn .Module ):
5666 """The InfoNCE loss.
5767 Attributes:
@@ -208,3 +218,112 @@ def test_infonce_reference_new_equivalence(temperature):
def test_alias():
    """The fixed-temperature criteria are exported under two names each;
    both aliases must resolve to the same class object."""
    assert cebra_criterions.InfoNCE == cebra_criterions.FixedCosineInfoNCE
    assert cebra_criterions.InfoMSE == cebra_criterions.FixedEuclideanInfoNCE
221+
222+
223+ def _reference_dot_similarity (ref , pos , neg ):
224+ pos_dist = torch .zeros (ref .shape [0 ])
225+ neg_dist = torch .zeros (ref .shape [0 ], neg .shape [0 ])
226+ for d in range (ref .shape [1 ]):
227+ for i in range (len (ref )):
228+ pos_dist [i ] += ref [i , d ] * pos [i , d ]
229+ for j in range (len (neg )):
230+ neg_dist [i , j ] += ref [i , d ] * neg [j , d ]
231+ return pos_dist , neg_dist
232+
233+
234+ def _reference_euclidean_similarity (ref , pos , neg ):
235+ pos_dist = torch .zeros (ref .shape [0 ])
236+ neg_dist = torch .zeros (ref .shape [0 ], neg .shape [0 ])
237+ for d in range (ref .shape [1 ]):
238+ for i in range (len (ref )):
239+ pos_dist [i ] += - (ref [i , d ] - pos [i , d ])** 2
240+ for j in range (len (neg )):
241+ neg_dist [i , j ] += - (ref [i , d ] - neg [j , d ])** 2
242+ return pos_dist , neg_dist
243+
244+
245+ def _reference_infonce (pos_dist , neg_dist ):
246+ align = - pos_dist .mean ()
247+ uniform = torch .logsumexp (neg_dist , dim = 1 ).mean ()
248+ return align + uniform , align , uniform
249+
250+
def test_similiarities():
    """The vectorized similarity functions must match the naive loop oracles."""
    ref = torch.randn(10, 3)
    pos = torch.randn(10, 3)
    neg = torch.randn(12, 3)

    for oracle, vectorized in (
        (_reference_dot_similarity, cebra_criterions.dot_similarity),
        (_reference_euclidean_similarity, cebra_criterions.euclidean_similarity),
    ):
        expected_pos, expected_neg = oracle(ref, pos, neg)
        actual_pos, actual_neg = vectorized(ref, pos, neg)
        assert torch.allclose(expected_pos, actual_pos)
        assert torch.allclose(expected_neg, actual_neg)
269+
270+
271+ def _compute_grads (output , inputs ):
272+ for input_ in inputs :
273+ input_ .grad = None
274+ assert input_ .requires_grad
275+ output .backward ()
276+ return [input_ .grad for input_ in inputs ]
277+
278+
def test_infonce():
    """cebra's infonce must agree with the naive reference on random scores."""
    pos_dist = torch.randn(100)
    neg_dist = torch.randn(100, 100)

    expected_loss, expected_align, expected_uniform = _reference_infonce(
        pos_dist, neg_dist)
    loss, align, uniform = cebra_criterions.infonce(pos_dist, neg_dist)

    assert torch.allclose(expected_loss, loss)
    assert torch.allclose(expected_align, align, atol=0.0001)
    assert torch.allclose(expected_uniform, uniform)
    assert torch.allclose(align + uniform, loss)
291+
292+
def test_infonce_gradients():
    """Gradients of cebra's infonce must match the naive reference.

    Iterates over the three returned values (0 = total loss, 1 = align,
    2 = uniform) and checks both the value and the gradients w.r.t. the
    positive and negative score tensors.
    """
    # Fixed seed so both implementations see identical random scores.
    rng = torch.Generator().manual_seed(42)
    pos_dist = torch.randn(100, generator=rng)
    neg_dist = torch.randn(100, 100, generator=rng)

    for i in range(3):
        # Fresh leaf tensors for the reference implementation.
        pos_dist_ = pos_dist.clone()
        neg_dist_ = neg_dist.clone()
        pos_dist_.requires_grad_(True)
        neg_dist_.requires_grad_(True)
        loss_ref = _reference_infonce(pos_dist_, neg_dist_)[i]
        grad_ref = _compute_grads(loss_ref, [pos_dist_, neg_dist_])

        # Fresh leaf tensors for the implementation under test.
        pos_dist_ = pos_dist.clone()
        neg_dist_ = neg_dist.clone()
        pos_dist_.requires_grad_(True)
        neg_dist_.requires_grad_(True)
        loss = cebra_criterions.infonce(pos_dist_, neg_dist_)[i]
        grad = _compute_grads(loss, [pos_dist_, neg_dist_])

        # NOTE(stes) default relative tolerance is 1e-5
        assert torch.allclose(loss_ref, loss, rtol=1e-4)

        if i == 0:
            # Total loss depends on both score tensors.
            assert grad[0] is not None
            assert grad[1] is not None
            assert torch.allclose(grad_ref[0], grad[0])
            assert torch.allclose(grad_ref[1], grad[1])
        if i == 1:
            # Alignment term only touches the positive scores.
            assert grad[0] is not None
            assert grad[1] is None
            assert torch.allclose(grad_ref[0], grad[0])
        if i == 2:
            # Uniformity term only touches the negative scores.
            assert grad[0] is None
            assert grad[1] is not None
            assert torch.allclose(grad_ref[1], grad[1])
0 commit comments