Commit 606ca4e

Option to return hard and soft loss when using distillation (#895)
## Summary

Proposal to add a `return_soft_hard_loss` parameter to enable logging the soft and hard losses separately. Useful for monitoring and analysis during training.

## Testing Done

- [x] test_jsd_loss.py
- [x] test_cosine_loss.py

---------

Co-authored-by: Shao Tang <tangshao28@gmail.com>
1 parent d5648bf commit 606ca4e
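A minimal usage sketch of the new flag (the `LigerFusedLinearJSDLoss` import below, the argument order, and the tensor shapes are assumptions for illustration; only the `return_soft_hard_loss` flag and the tuple return come from this commit):

```python
import torch

from liger_kernel.chunked_loss.jsd_loss import LigerFusedLinearJSDLoss  # assumed export name

# Illustrative shapes: (tokens, hidden) activations and (vocab, hidden) projection weights.
student_input = torch.randn(8, 64, requires_grad=True)
teacher_input = torch.randn(8, 64)
student_weight = torch.randn(128, 64, requires_grad=True)
teacher_weight = torch.randn(128, 64)
true_labels = torch.randint(0, 128, (8,))

loss_fn = LigerFusedLinearJSDLoss(return_soft_hard_loss=True)

# With the flag enabled, forward returns (combined_loss, soft_loss, hard_loss).
loss, soft_loss, hard_loss = loss_fn(
    student_input, student_weight, teacher_input, teacher_weight, true_labels
)
loss.backward()  # optimization still uses the combined loss
print(f"soft={soft_loss.item():.4f} hard={hard_loss.item():.4f}")  # log for monitoring
```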

File tree

3 files changed: +44 -11 lines changed

src/liger_kernel/chunked_loss/cosine_similarity_loss.py

Lines changed: 13 additions & 4 deletions
@@ -1,3 +1,6 @@
+from typing import Tuple
+from typing import Union
+
 import torch
 import torch.nn.functional as F
 
@@ -41,7 +44,8 @@ def forward(
         temperature: float = 1.0,
         compiled: bool = True,
         chunk_size: int = 1024,
-    ):
+        return_soft_hard_loss: bool = False,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
         return super().forward(
             cls=cls,
             ctx=ctx,
@@ -59,11 +63,12 @@ def forward(
             ignore_index=ignore_index,
             temperature=temperature,
             compiled=compiled,
+            return_soft_hard_loss=return_soft_hard_loss,
         )
 
     @staticmethod
-    def backward(ctx, grad_output):
-        grads = LigerFusedLinearDistillationBase.backward(ctx, grad_output)[:6]
+    def backward(ctx, grad_output, *args):
+        grads = LigerFusedLinearDistillationBase.backward(ctx, grad_output, *args)[:6]
 
         return (
             *grads,
@@ -75,6 +80,7 @@ def backward(ctx, grad_output):
             None,  # temperature
             None,  # compiled
             None,  # chunk_size
+            None,  # return_soft_hard_loss
         )
 
 
@@ -88,6 +94,7 @@ def __init__(
         temperature: float = 1.0,
         compiled: bool = True,
         chunk_size: int = 1024,
+        return_soft_hard_loss: bool = False,
     ):
         super().__init__()
         assert temperature != 0, "Temperature cannot be 0."
@@ -98,6 +105,7 @@ def __init__(
         self.compiled = compiled
         self.beta = beta
         self.chunk_size = chunk_size
+        self.return_soft_hard_loss = return_soft_hard_loss
 
     def forward(
         self,
@@ -108,7 +116,7 @@ def forward(
         true_labels: torch.LongTensor,
         student_bias: torch.Tensor = None,
         teacher_bias: torch.Tensor = None,
-    ) -> torch.Tensor:
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
         return LigerFusedLinearCosineSimilarityFunction.apply(
             student_input,
             student_weight,
@@ -124,4 +132,5 @@ def forward(
             self.temperature,
             self.compiled,
             self.chunk_size,
+            self.return_soft_hard_loss,
         )
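Mirroring the items under Testing Done, a sketch of how the two return modes of the cosine-similarity module could be exercised (the `LigerFusedLinearCosineSimilarityLoss` name, argument order, and shapes are assumptions; only the flag itself comes from this commit):

```python
import torch

from liger_kernel.chunked_loss.cosine_similarity_loss import (
    LigerFusedLinearCosineSimilarityLoss,  # assumed export name
)


def test_return_soft_hard_loss():
    # Illustrative shapes: (tokens, hidden) activations projected to a small vocab.
    student_input = torch.randn(4, 32, requires_grad=True)
    teacher_input = torch.randn(4, 32)
    student_weight = torch.randn(16, 32, requires_grad=True)
    teacher_weight = torch.randn(16, 32)
    true_labels = torch.randint(0, 16, (4,))

    # Default behaviour is unchanged: a single scalar loss.
    loss = LigerFusedLinearCosineSimilarityLoss()(
        student_input, student_weight, teacher_input, teacher_weight, true_labels
    )
    assert isinstance(loss, torch.Tensor) and loss.dim() == 0

    # With the flag: a (combined, soft, hard) tuple of scalar tensors.
    out = LigerFusedLinearCosineSimilarityLoss(return_soft_hard_loss=True)(
        student_input, student_weight, teacher_input, teacher_weight, true_labels
    )
    assert len(out) == 3 and all(t.dim() == 0 for t in out)
```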

src/liger_kernel/chunked_loss/fused_linear_distillation.py

Lines changed: 13 additions & 2 deletions
@@ -1,5 +1,7 @@
 from abc import abstractmethod
 from functools import partial
+from typing import Tuple
+from typing import Union
 
 import torch
 
@@ -157,8 +159,9 @@ def forward(
         compute_ce_loss=True,
         temperature=1.0,
         compiled=True,
+        return_soft_hard_loss=False,
         **loss_kwargs,
-    ):
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
         """
         Base class for fused linear layer with distillation loss.
         Only need to compute gradients for student model.
@@ -180,13 +183,16 @@ def forward(
             compute_ce_loss (bool): Whether to compute CE loss.
             temperature (float): Temperature to control the input probability distribution. Default: `1.0` (i.e. no scale)
             compiled (bool): Whether to use torch compile for chunk accumulation.
+            return_soft_hard_loss (bool): Whether to return soft and hard losses separately. Default: False.
             loss_kwargs (dict): Other possible arguments that a loss function might need
         """
         CHUNK_SIZE = chunk_size
         grad_weight = torch.zeros_like(student_weight)
         grad_inputs = []
         grad_bias = torch.zeros_like(student_bias) if student_bias is not None else None
         loss_acc = torch.zeros((), device=student_input.device)
+        soft_loss_acc = torch.zeros((), device=student_input.device) if return_soft_hard_loss else None
+        hard_loss_acc = torch.zeros((), device=student_input.device) if return_soft_hard_loss else None
 
         loss_func_to_call = partial(
             LigerFusedLinearDistillationBase._compute_loss,
@@ -247,6 +253,9 @@ def accumulate_chunk(student_input_chunk, teacher_input_chunk, target_chunk):
             )
             grad_weight.add_(chunk_grad_weight)
             loss_acc.add_(chunk_loss)
+            if return_soft_hard_loss:
+                soft_loss_acc.add_(chunk_soft_loss)
+                hard_loss_acc.add_(chunk_hard_loss)
             return chunk_grad_input
 
         if compiled:
@@ -268,10 +277,12 @@ def accumulate_chunk(student_input_chunk, teacher_input_chunk, target_chunk):
             grad_weight,
             grad_bias,
         )
+        if return_soft_hard_loss:
+            return loss_acc, soft_loss_acc, hard_loss_acc
         return loss_acc
 
     @staticmethod
-    def backward(ctx, grad_output):
+    def backward(ctx, grad_output, *args):
         grad_input, grad_weight, grad_bias = ctx.saved_tensors
         if torch.ne(grad_output, torch.tensor(1.0, device=grad_output.device)):
             grad_input = grad_input * grad_output
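The `*args` added to `backward` follows from the `torch.autograd.Function` contract: backward receives one incoming gradient per forward output, so once forward can return three tensors the two extra gradients must be accepted, and they can be ignored because the soft and hard values are logging-only. A toy Function (unrelated to the Liger code) illustrating that contract:

```python
import torch


class MultiOutputLoss(torch.autograd.Function):
    """Toy Function: a main loss plus two auxiliary outputs used only for logging."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        main = (x**2).sum()
        aux_a = x.sum()   # logging-only output
        aux_b = x.mean()  # logging-only output
        return main, aux_a, aux_b

    @staticmethod
    def backward(ctx, grad_main, *grad_aux):  # one incoming grad per forward output
        (x,) = ctx.saved_tensors
        # Only the main loss drives the gradient; grads for the aux outputs are dropped.
        return 2 * x * grad_main


x = torch.randn(5, requires_grad=True)
main, aux_a, aux_b = MultiOutputLoss.apply(x)
main.backward()  # the extra gradients arrive in *grad_aux and are simply ignored
```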

src/liger_kernel/chunked_loss/jsd_loss.py

Lines changed: 18 additions & 5 deletions
@@ -1,5 +1,8 @@
 import math
 
+from typing import Tuple
+from typing import Union
+
 import torch
 import torch.nn.functional as F
 
@@ -56,6 +59,7 @@ def forward(
         temperature: float = 1.0,
         compiled: bool = True,
         chunk_size: int = 1024,
+        return_soft_hard_loss: bool = False,
     ):
         """
         Fused linear layer with JSD distillation loss.
@@ -72,8 +76,9 @@ def forward(
             temperature (float): Temperature for softening/sharpening distributions
             compiled (bool): Whether to use torch compile
             chunk_size (int): Size of chunks for processing.
+            return_soft_hard_loss (bool): Whether to return soft and hard losses separately. Default: False.
         Returns:
-            torch.Tensor: Computed loss
+            torch.Tensor: Computed loss, or tuple (loss, soft_loss, hard_loss) if return_soft_hard_loss=True
         """
         return super().forward(
             cls=cls,
@@ -92,11 +97,12 @@ def forward(
             ignore_index=ignore_index,
             temperature=temperature,
             compiled=compiled,
+            return_soft_hard_loss=return_soft_hard_loss,
         )
 
     @staticmethod
-    def backward(ctx, grad_output):
-        grads = LigerFusedLinearDistillationBase.backward(ctx, grad_output)[:6]
+    def backward(ctx, grad_output, *args):
+        grads = LigerFusedLinearDistillationBase.backward(ctx, grad_output, *args)[:6]
 
         return (
             *grads,
@@ -108,6 +114,7 @@ def backward(ctx, grad_output):
             None,  # temperature
             None,  # compiled
             None,  # chunk_size
+            None,  # return_soft_hard_loss
         )
 
 
@@ -125,6 +132,7 @@ def __init__(
         temperature: float = 1.0,
         compiled: bool = True,
         chunk_size: int = 1024,
+        return_soft_hard_loss: bool = False,
     ):
         """
         Args:
@@ -135,6 +143,7 @@ def __init__(
             compiled (bool): Whether to use torch compile
             beta (float): Coefficient beta of generalized JSD in the interval [0, 1]. Default: `0.5`.
             chunk_size (int): Size of chunks for processing.
+            return_soft_hard_loss (bool): Whether to return soft and hard losses separately. Default: False.
         """
         super().__init__()
         assert temperature != 0, "Temperature cannot be 0."
@@ -145,6 +154,7 @@ def __init__(
         self.compiled = compiled
         self.beta = beta
         self.chunk_size = chunk_size
+        self.return_soft_hard_loss = return_soft_hard_loss
 
     def forward(
         self,
@@ -155,7 +165,7 @@ def forward(
         true_labels: torch.LongTensor,
         student_bias: torch.Tensor = None,
         teacher_bias: torch.Tensor = None,
-    ) -> torch.Tensor:
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
         """
         Compute the JSD distillation loss.
 
@@ -167,7 +177,9 @@ def forward(
             true_labels (torch.LongTensor): Target labels tensor
 
         Returns:
-            torch.Tensor: Computed loss
+            torch.Tensor or Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+                If return_soft_hard_loss is False: Computed combined loss
+                If return_soft_hard_loss is True: Tuple of (combined_loss, soft_loss, hard_loss)
         """
         return LigerFusedLinearJSDFunction.apply(
             student_input,
@@ -184,4 +196,5 @@ def forward(
             self.temperature,
             self.compiled,
             self.chunk_size,
+            self.return_soft_hard_loss,
         )
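Because the return type now depends on the flag, shared training code can branch on the result; a small loss-module-agnostic sketch (the helper name is hypothetical):

```python
def distillation_step(loss_fn, student_input, student_weight, teacher_input, teacher_weight, labels):
    """Hypothetical helper that works whether or not return_soft_hard_loss is enabled on loss_fn."""
    out = loss_fn(student_input, student_weight, teacher_input, teacher_weight, labels)
    if isinstance(out, tuple):  # return_soft_hard_loss=True -> (combined, soft, hard)
        loss, soft_loss, hard_loss = out
    else:  # default: a single combined loss tensor, unchanged behaviour
        loss, soft_loss, hard_loss = out, None, None
    loss.backward()
    return loss, soft_loss, hard_loss
```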
