Commit c6b6c6d
add cosine gamma schedule
1 parent 3ae88f1

File tree: 1 file changed, +95 -19 lines

libauc/losses/contrastive.py (95 additions, 19 deletions)
@@ -40,6 +40,9 @@ class GCLoss_v1(nn.Module):
         tau_min (float, optional): lower bound of learnable temperature in iSogCLR (default: ``0.05``)
         tau_max (float, optional): upper bound of learnable temperature in iSogCLR (default: ``0.7``)
         beta (float, optional): the momentum parameter for updating temperature parameters in iSogCLR (default: ``0.9``)
+        gamma (float, optional): the moving-average factor for the dynamic loss, in the range (0.0, 1.0) (default: ``0.9``)
+        gamma_schedule (str, optional): the schedule for gamma. Options are 'constant' (fixed ``gamma``) and 'cosine' (decaying from 1.0 to ``gamma``) (default: ``'constant'``)
+        gamma_decay_epochs (int, optional): the number of epochs over which gamma decays from 1.0 to the value set by ``gamma``; used only when ``gamma_schedule`` is 'cosine'. We recommend total_training_epochs // 2 (default: ``-1``)
 
     Example:
         >>> loss_fn = GCLoss_v1(N=1000, tau=0.1)
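For reference, a minimal construction sketch using the new options; the N, tau, gamma, and epoch values below are illustrative, not recommendations:

    # cosine schedule: gamma decays from 1.0 to 0.9 over the first 50 of,
    # say, 100 training epochs, then stays at 0.9
    loss_fn = GCLoss_v1(N=1000, tau=0.1, gamma=0.9,
                        gamma_schedule='cosine', gamma_decay_epochs=50)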
@@ -65,7 +68,10 @@ def __init__(self,
                  device=None,
                  distributed=False,
                  enable_isogclr=False,
-                 tau_min=0.05, tau_max=0.7, rho=0.3, eta=0.01, beta=0.9):
+                 tau_min=0.05, tau_max=0.7, rho=0.3, eta=0.01, beta=0.9,
+                 gamma_schedule='constant',
+                 gamma_decay_epochs=-1,
+                 ):
         super(GCLoss_v1, self).__init__()
         if not device:
             self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
@@ -74,7 +80,8 @@ def __init__(self,
         self.N = N
         self.u = torch.zeros(N).reshape(-1, 1) #.to(self.device)
         self.tau = tau
-        self.gamma = gamma
+        self.gamma_min = gamma
+        self.gamma = 1.0
         self.distributed = distributed
         self.LARGE_NUM = 1e9
         self.eps = eps
@@ -89,6 +96,32 @@ def __init__(self,
             self.learnable_tau = torch.ones(N).reshape(-1, 1) * self.tau
             self.grad_tau = torch.zeros(N).reshape(-1, 1)
 
+        self.gamma_schedule = gamma_schedule
+        assert self.gamma_schedule in ["constant", "cosine"]
+        if self.gamma_schedule == "cosine":
+            assert gamma_decay_epochs > 0
+        self.gamma_decay_epochs = gamma_decay_epochs
+        gamma_str = f"Using {self.gamma_schedule} schedule for gamma"
+        if self.gamma_schedule == "constant":
+            gamma_str += f" with gamma = {self.gamma_min}"
+        else:
+            gamma_str += f" with gamma_min = {self.gamma_min}, gamma_decay_epochs = {self.gamma_decay_epochs}"
+        print(gamma_str)
+
+    def adjust_gamma(self, epoch: int):
+        """Adjust gamma for the dynamic loss according to its schedule."""
+        if self.gamma_schedule == "constant":
+            if epoch == 0:
+                self.gamma = 1.0
+            else:
+                self.gamma = self.gamma_min
+        elif self.gamma_schedule == "cosine":
+            if epoch < self.gamma_decay_epochs:
+                self.gamma = (1 - self.gamma_min) * 0.5 * (1 + np.cos(np.pi * epoch / self.gamma_decay_epochs)) + self.gamma_min
+            else:
+                self.gamma = self.gamma_min
+        print(f"Epoch: {epoch}, gamma: {self.gamma:.3f}")
+
     def forward(self,
                 hidden1,
                 hidden2,
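The cosine branch above is a standard half-cosine ramp: gamma starts at 1.0 at epoch 0, reaches gamma_min at gamma_decay_epochs, and stays there afterwards. A self-contained sketch of the schedule values, assuming gamma_min=0.9 and gamma_decay_epochs=50:

    import numpy as np

    gamma_min, decay_epochs = 0.9, 50
    for epoch in (0, 25, 50, 75):
        if epoch < decay_epochs:
            g = (1 - gamma_min) * 0.5 * (1 + np.cos(np.pi * epoch / decay_epochs)) + gamma_min
        else:
            g = gamma_min
        print(epoch, round(g, 3))  # prints 1.0, 0.95, 0.9, 0.9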
@@ -133,13 +166,15 @@ def forward(self,
         neg_logits1 = torch.exp(logits_ab_aa/self.tau)*neg_mask #(B, 2B)
         neg_logits2 = torch.exp(logits_ba_bb/self.tau)*neg_mask
 
-        # u init
-        if self.u[index].sum() == 0:
-            u1 = torch.sum(neg_logits1, dim=1, keepdim=True)/(2*(batch_size-1))
-            u2 = torch.sum(neg_logits2, dim=1, keepdim=True)/(2*(batch_size-1))
+        if self.gamma_schedule == "constant" and self.gamma > self.gamma_min:
+            if self.u[index].sum() == 0:
+                gamma = 1.0
+            else:
+                gamma = self.gamma_min
         else:
-            u1 = (1 - self.gamma) * self.u[index].cuda() + self.gamma * torch.sum(neg_logits1, dim=1, keepdim=True)/(2*(batch_size-1))
-            u2 = (1 - self.gamma) * self.u[index].cuda() + self.gamma * torch.sum(neg_logits2, dim=1, keepdim=True)/(2*(batch_size-1))
+            gamma = self.gamma
+        u1 = (1 - gamma) * self.u[index].cuda() + gamma * torch.sum(neg_logits1, dim=1, keepdim=True)/(2*(batch_size-1))
+        u2 = (1 - gamma) * self.u[index].cuda() + gamma * torch.sum(neg_logits2, dim=1, keepdim=True)/(2*(batch_size-1))
 
         # this syncs on all devices (since "hidden" is gathered from all devices) #### maybe we can concat_all_gather index before?
         if self.distributed:
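The removed "u init" branch is subsumed by the schedule: with gamma = 1.0 the moving-average update reduces exactly to the plain batch estimate. A tiny sketch of that identity, with illustrative shapes:

    import torch

    u_prev = torch.zeros(4, 1)             # uninitialized running estimate
    batch_est = torch.rand(4, 1)           # current per-sample batch estimate
    u_new = (1 - 1.0) * u_prev + 1.0 * batch_est
    assert torch.equal(u_new, batch_est)   # gamma = 1.0 recovers the old init path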
@@ -209,6 +244,8 @@ class GCLoss_v2(nn.Module):
         tau_min (float, optional): lower bound of learnable temperature in iSogCLR (default: ``0.005``)
         tau_max (float, optional): upper bound of learnable temperature in iSogCLR (default: ``0.05``)
         beta (float, optional): the momentum parameter for updating temperature parameters in iSogCLR (default: ``0.9``)
+        gamma_schedule (str, optional): the schedule for gamma. Options are 'constant' (fixed ``gamma``) and 'cosine' (decaying from 1.0 to ``gamma``) (default: ``'constant'``)
+        gamma_decay_epochs (int, optional): the number of epochs over which gamma decays from 1.0 to the value set by ``gamma``; used only when ``gamma_schedule`` is 'cosine'. We recommend total_training_epochs // 2 (default: ``-1``)
 
 
     Example:
@@ -237,7 +274,10 @@ def __init__(
                  world_size=1,
                  distributed=False,
                  enable_isogclr=False,
-                 tau_min=0.005, tau_max=0.05, rho=6.0, eta=0.01, beta=0.9):
+                 tau_min=0.005, tau_max=0.05, rho=6.0, eta=0.01, beta=0.9,
+                 gamma_schedule='constant',
+                 gamma_decay_epochs=-1,
+                 ):
         super(GCLoss_v2, self).__init__()
         self.cache_labels = cache_labels
         self.rank = rank
@@ -255,7 +295,8 @@ def __init__(
         # sogclr
         self.u1 = torch.zeros(N).reshape(-1, 1).detach()
         self.u2 = torch.zeros(N).reshape(-1, 1).detach()
-        self.gamma = gamma
+        self.gamma_min = gamma
+        self.gamma = 1.0
         self.tau = tau
 
         self.eps = 1e-20
@@ -272,6 +313,32 @@ def __init__(
             self.grad_tau_img = torch.zeros(N).reshape(-1, 1)
             self.grad_tau_txt = torch.zeros(N).reshape(-1, 1)
 
+        self.gamma_schedule = gamma_schedule
+        assert self.gamma_schedule in ["constant", "cosine"]
+        if self.gamma_schedule == "cosine":
+            assert gamma_decay_epochs > 0
+        self.gamma_decay_epochs = gamma_decay_epochs
+        gamma_str = f"Using {self.gamma_schedule} schedule for gamma"
+        if self.gamma_schedule == "constant":
+            gamma_str += f" with gamma = {self.gamma_min}"
+        else:
+            gamma_str += f" with gamma_min = {self.gamma_min}, gamma_decay_epochs = {self.gamma_decay_epochs}"
+        print(gamma_str)
+
+    def adjust_gamma(self, epoch: int):
+        """Adjust gamma for the dynamic loss according to its schedule."""
+        if self.gamma_schedule == "constant":
+            if epoch == 0:
+                self.gamma = 1.0
+            else:
+                self.gamma = self.gamma_min
+        elif self.gamma_schedule == "cosine":
+            if epoch < self.gamma_decay_epochs:
+                self.gamma = (1 - self.gamma_min) * 0.5 * (1 + np.cos(np.pi * epoch / self.gamma_decay_epochs)) + self.gamma_min
+            else:
+                self.gamma = self.gamma_min
+        print(f"Epoch: {epoch}, gamma: {self.gamma:.3f}")
+
     def forward(self, image_features, text_features, index):
         device = image_features.device
 
@@ -328,17 +395,23 @@ def forward(self, image_features, text_features, index):
         neg_logits_text = torch.exp(logits_text_d_tau - self.b2[index].to(device))*neg_mask #(B, 4B)
 
-        # u init
-        if self.u1[index].sum() == 0:
-            u1 = torch.sum(neg_logits_image, dim=1, keepdim=True)/(large_batch_size-1)
+        if self.gamma_schedule == "constant" and self.gamma > self.gamma_min:
+            if self.u1[index].sum() == 0:
+                gamma1 = 1.0
+            else:
+                gamma1 = self.gamma_min
+            if self.u2[index].sum() == 0:
+                gamma2 = 1.0
+            else:
+                gamma2 = self.gamma_min
         else:
-            u1 = (1 - self.gamma) * self.u1[index].to(device) * torch.exp(old_b1 - self.b1[index].to(device)) \
-                + self.gamma * torch.sum(neg_logits_image, dim=1, keepdim=True)/(large_batch_size-1)
+            gamma1 = self.gamma
+            gamma2 = self.gamma
+        u1 = (1 - gamma1) * self.u1[index].to(device) * torch.exp(old_b1 - self.b1[index].to(device)) \
+            + gamma1 * torch.sum(neg_logits_image, dim=1, keepdim=True)/(large_batch_size-1)
 
-        if self.u2[index].sum() == 0:
-            u2 = torch.sum(neg_logits_text, dim=1, keepdim=True)/(large_batch_size-1)
-        else:
-            u2 = (1 - self.gamma) * self.u2[index].to(device) * torch.exp(old_b2 - self.b2[index].to(device)) \
-                + self.gamma * torch.sum(neg_logits_text, dim=1, keepdim=True)/(large_batch_size-1)
+        u2 = (1 - gamma2) * self.u2[index].to(device) * torch.exp(old_b2 - self.b2[index].to(device)) \
+            + gamma2 * torch.sum(neg_logits_text, dim=1, keepdim=True)/(large_batch_size-1)
 
         u1 = u1.clamp(min=self.eps)
         u2 = u2.clamp(min=self.eps)
@@ -412,6 +485,9 @@ def get_loss(self, mode='unimodal', **kwargs):
 
     def forward(self, hidden1, hidden2, index, **kwargs):
         return self.loss_fn(hidden1, hidden2, index, **kwargs)
+
+    def adjust_gamma(self, epoch: int):
+        self.loss_fn.adjust_gamma(epoch)
 
 
 # utils
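A hedged training-loop sketch showing where adjust_gamma would be called; the loader, optimizer, and all values are placeholders, and we assume the GCLoss wrapper forwards the new keyword arguments to GCLoss_v1:

    loss_fn = GCLoss('unimodal', N=50000, gamma=0.9,
                     gamma_schedule='cosine', gamma_decay_epochs=50)
    for epoch in range(100):
        loss_fn.adjust_gamma(epoch)              # update gamma once per epoch
        for hidden1, hidden2, index in loader:   # placeholder DataLoader
            loss = loss_fn(hidden1, hidden2, index)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()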
