
Commit 8520b9b

stes and MMathisLab authored

Fix broadcasting in InfoNCE loss (#86)

* Fix broadcasting in numerically stabilized InfoNCE
* Update version
* Update criterions.py - seemed redundant.
* Fix typo
* Revert version change
* Seed the criterion tests

Co-authored-by: Mackenzie Mathis <[email protected]>
1 parent 5537753 commit 8520b9b
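For readers skimming the diff below: the change replaces `neg_dist.max(dim=1)` (shape `(n,)`) with the `keepdim=True` variant (shape `(n, 1)`), so the stabilizer broadcasts over rows rather than columns of the `(n, m)` negative-distance matrix. A minimal sketch (not part of the commit) illustrating the failure mode:

import torch

neg_dist = torch.randn(4, 4)

c_wrong, _ = neg_dist.max(dim=1)                # shape (4,)
c_right, _ = neg_dist.max(dim=1, keepdim=True)  # shape (4, 1)

# (4, 4) - (4,) subtracts c_wrong[j] from column j (wrong axis);
# (4, 4) - (4, 1) subtracts c_right[i] from row i (intended).
print(torch.allclose(neg_dist - c_wrong, neg_dist - c_right))  # False in general

Note that with a non-square matrix (different numbers of samples and negatives), the un-fixed subtraction raises a shape error instead of failing silently.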

File tree

2 files changed: +134, -5 lines
- cebra/models/criterions.py
- tests/test_criterions.py

cebra/models/criterions.py

Lines changed: 13 additions & 3 deletions
@@ -81,15 +81,25 @@ def infonce(
     """InfoNCE implementation
 
     See :py:class:`BaseInfoNCE` for reference.
+
+    Note:
+        - The behavior of this function changed beginning in CEBRA 0.3.0.
+          The InfoNCE implementation is numerically stabilized.
     """
     with torch.no_grad():
-        c, _ = neg_dist.max(dim=1)
+        c, _ = neg_dist.max(dim=1, keepdim=True)
         c = c.detach()
-    pos_dist = pos_dist - c
+
+    pos_dist = pos_dist - c.squeeze(1)
     neg_dist = neg_dist - c
     align = (-pos_dist).mean()
     uniform = torch.logsumexp(neg_dist, dim=1).mean()
-    return align + uniform, align, uniform
+
+    c_mean = c.mean()
+    align_corrected = align - c_mean
+    uniform_corrected = uniform + c_mean
+
+    return align + uniform, align_corrected, uniform_corrected
 
 
 class ContrastiveLoss(nn.Module):
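A side note on the corrected terms above: subtracting the row-wise maximum c is the standard log-sum-exp stabilization and cancels out of the total loss, since logsumexp(x - c) = logsumexp(x) - c; the `c_mean` correction shifts the reported align/uniform terms back to their unstabilized values. A short numerical check of this identity, as an illustrative sketch (none of this code is part of the commit, only `torch` is assumed):

import torch

pos_dist = torch.randn(8)
neg_dist = torch.randn(8, 12)

# Stabilized terms, as in the patched infonce above.
c = neg_dist.max(dim=1, keepdim=True).values
align = (-(pos_dist - c.squeeze(1))).mean()
uniform = torch.logsumexp(neg_dist - c, dim=1).mean()

# Unstabilized reference values.
align_raw = (-pos_dist).mean()
uniform_raw = torch.logsumexp(neg_dist, dim=1).mean()

print(torch.allclose(align - c.mean(), align_raw))      # True
print(torch.allclose(uniform + c.mean(), uniform_raw))  # True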

tests/test_criterions.py

Lines changed: 121 additions & 2 deletions
@@ -42,16 +42,26 @@ def ref_euclidean_similarity(ref: torch.Tensor, pos: torch.Tensor,
 @torch.jit.script
 def ref_infonce(pos_dist: torch.Tensor, neg_dist: torch.Tensor):
     with torch.no_grad():
-        c, _ = neg_dist.max(dim=1)
+        c, _ = neg_dist.max(dim=1, keepdim=True)
         c = c.detach()
-    pos_dist = pos_dist - c
+    pos_dist = pos_dist - c.squeeze(1)
     neg_dist = neg_dist - c
 
     align = (-pos_dist).mean()
     uniform = torch.logsumexp(neg_dist, dim=1).mean()
     return align + uniform, align, uniform
 
 
+@torch.jit.script
+def ref_infonce_not_stable(pos_dist: torch.Tensor, neg_dist: torch.Tensor):
+    pos_dist = pos_dist
+    neg_dist = neg_dist
+
+    align = (-pos_dist).mean()
+    uniform = torch.logsumexp(neg_dist, dim=1).mean()
+    return align + uniform, align, uniform
+
+
 class ReferenceInfoNCE(nn.Module):
     """The InfoNCE loss.
     Attributes:

@@ -208,3 +218,112 @@ def test_infonce_reference_new_equivalence(temperature):
 def test_alias():
     assert cebra_criterions.InfoNCE == cebra_criterions.FixedCosineInfoNCE
     assert cebra_criterions.InfoMSE == cebra_criterions.FixedEuclideanInfoNCE
+
+
+def _reference_dot_similarity(ref, pos, neg):
+    pos_dist = torch.zeros(ref.shape[0])
+    neg_dist = torch.zeros(ref.shape[0], neg.shape[0])
+    for d in range(ref.shape[1]):
+        for i in range(len(ref)):
+            pos_dist[i] += ref[i, d] * pos[i, d]
+            for j in range(len(neg)):
+                neg_dist[i, j] += ref[i, d] * neg[j, d]
+    return pos_dist, neg_dist
+
+
+def _reference_euclidean_similarity(ref, pos, neg):
+    pos_dist = torch.zeros(ref.shape[0])
+    neg_dist = torch.zeros(ref.shape[0], neg.shape[0])
+    for d in range(ref.shape[1]):
+        for i in range(len(ref)):
+            pos_dist[i] += -(ref[i, d] - pos[i, d])**2
+            for j in range(len(neg)):
+                neg_dist[i, j] += -(ref[i, d] - neg[j, d])**2
+    return pos_dist, neg_dist
+
+
+def _reference_infonce(pos_dist, neg_dist):
+    align = -pos_dist.mean()
+    uniform = torch.logsumexp(neg_dist, dim=1).mean()
+    return align + uniform, align, uniform
+
+
+def test_similiarities():
+
+    ref = torch.randn(10, 3)
+    pos = torch.randn(10, 3)
+    neg = torch.randn(12, 3)
+
+    pos_dist, neg_dist = _reference_dot_similarity(ref, pos, neg)
+    pos_dist_2, neg_dist_2 = cebra_criterions.dot_similarity(ref, pos, neg)
+
+    assert torch.allclose(pos_dist, pos_dist_2)
+    assert torch.allclose(neg_dist, neg_dist_2)
+
+    pos_dist, neg_dist = _reference_euclidean_similarity(ref, pos, neg)
+    pos_dist_2, neg_dist_2 = cebra_criterions.euclidean_similarity(
+        ref, pos, neg)
+
+    assert torch.allclose(pos_dist, pos_dist_2)
+    assert torch.allclose(neg_dist, neg_dist_2)
+
+
+def _compute_grads(output, inputs):
+    for input_ in inputs:
+        input_.grad = None
+        assert input_.requires_grad
+    output.backward()
+    return [input_.grad for input_ in inputs]
+
+
+def test_infonce():
+
+    pos_dist = torch.randn(100,)
+    neg_dist = torch.randn(100, 100)
+
+    ref_loss, ref_align, ref_uniform = _reference_infonce(pos_dist, neg_dist)
+    loss, align, uniform = cebra_criterions.infonce(pos_dist, neg_dist)
+
+    assert torch.allclose(ref_loss, loss)
+    assert torch.allclose(ref_align, align, atol=0.0001)
+    assert torch.allclose(ref_uniform, uniform)
+    assert torch.allclose(align + uniform, loss)
+
+
+def test_infonce_gradients():
+
+    rng = torch.Generator().manual_seed(42)
+    pos_dist = torch.randn(100, generator=rng)
+    neg_dist = torch.randn(100, 100, generator=rng)
+
+    for i in range(3):
+        pos_dist_ = pos_dist.clone()
+        neg_dist_ = neg_dist.clone()
+        pos_dist_.requires_grad_(True)
+        neg_dist_.requires_grad_(True)
+        loss_ref = _reference_infonce(pos_dist_, neg_dist_)[i]
+        grad_ref = _compute_grads(loss_ref, [pos_dist_, neg_dist_])
+
+        pos_dist_ = pos_dist.clone()
+        neg_dist_ = neg_dist.clone()
+        pos_dist_.requires_grad_(True)
+        neg_dist_.requires_grad_(True)
+        loss = cebra_criterions.infonce(pos_dist_, neg_dist_)[i]
+        grad = _compute_grads(loss, [pos_dist_, neg_dist_])
+
+        # NOTE(stes) default relative tolerance is 1e-5
+        assert torch.allclose(loss_ref, loss, rtol=1e-4)
+
+        if i == 0:
+            assert grad[0] is not None
+            assert grad[1] is not None
+            assert torch.allclose(grad_ref[0], grad[0])
+            assert torch.allclose(grad_ref[1], grad[1])
+        if i == 1:
+            assert grad[0] is not None
+            assert grad[1] is None
+            assert torch.allclose(grad_ref[0], grad[0])
+        if i == 2:
+            assert grad[0] is None
+            assert grad[1] is not None
+            assert torch.allclose(grad_ref[1], grad[1])
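A brief sketch (not part of the commit) of the gradient behavior `test_infonce_gradients` asserts: because c is computed under `torch.no_grad()` and detached, the align term backpropagates only into `pos_dist`, and symmetrically the uniform term only into `neg_dist`:

import torch

pos_dist = torch.randn(5, requires_grad=True)
neg_dist = torch.randn(5, 7, requires_grad=True)

with torch.no_grad():
    c = neg_dist.max(dim=1, keepdim=True).values  # no gradient flows through c

# The corrected align term, as in the patched infonce.
align = (-(pos_dist - c.squeeze(1))).mean() - c.mean()
align.backward()
print(pos_dist.grad is not None)  # True: align depends on pos_dist
print(neg_dist.grad)              # None: c blocks gradient flow to neg_dist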
