
Commit 6059dfb

Bug fix for missing distillation loss arguments. (#983)
## Summary

Bug fix for missing distillation loss arguments.

## Testing Done

Unit tests.

- Hardware Type: <BLANK>
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence
1 parent 153d226 commit 6059dfb

File tree

3 files changed: +10 / -2 lines changed

src/liger_kernel/chunked_loss/cosine_similarity_loss.py

Lines changed: 7 additions & 1 deletion
@@ -9,7 +9,13 @@
 
 class LigerFusedLinearCosineSimilarityFunction(LigerFusedLinearDistillationBase):
     @staticmethod
-    def distillation_loss_fn(student_logits, teacher_logits, beta=1.0):
+    def distillation_loss_fn(
+        student_logits,
+        teacher_logits,
+        target=None,
+        ignore_index=None,
+        beta=1.0,
+    ):
         """
         Compute Cosine loss (Cosine Similarity Loss).
         Args:
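This change only widens the function's signature: the shared distillation base can now forward `target` and `ignore_index` to the cosine loss, which does not need them. Below is a minimal sketch of a function with this shape, assuming the cosine loss simply ignores the new keywords (illustrative only, not the file's actual body):

```python
import torch.nn.functional as F

def distillation_loss_fn(
    student_logits,
    teacher_logits,
    target=None,        # accepted for interface compatibility; assumed unused here
    ignore_index=None,  # accepted for interface compatibility; assumed unused here
    beta=1.0,
):
    # Normalize along the last (vocabulary) dimension and penalize
    # 1 - cosine similarity between student and teacher logits.
    student_norm = F.normalize(student_logits, p=2, dim=-1)
    teacher_norm = F.normalize(teacher_logits, p=2, dim=-1)
    cosine_sim = (student_norm * teacher_norm).sum(dim=-1)
    return beta * (1.0 - cosine_sim).sum()
```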

src/liger_kernel/chunked_loss/fused_linear_distillation.py

Lines changed: 2 additions & 0 deletions
@@ -13,6 +13,8 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
     def distillation_loss_fn(
         student_logits,
         teacher_logits,
+        target=None,
+        ignore_index=None,
     ):
         """
         Compute distillation loss.
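Adding `target=None` and `ignore_index=None` to the base `distillation_loss_fn` signature makes the contract explicit: concrete losses may receive label information, even if they ignore it. To illustrate why a distillation loss might actually use these arguments, here is a hypothetical label-masked KL variant (not part of this commit; the name, reduction, and masking behavior are assumptions):

```python
import torch
import torch.nn.functional as F

def masked_kl_distillation_loss(
    student_logits,
    teacher_logits,
    target=None,
    ignore_index=None,
    temperature=1.0,
):
    # Soften both distributions and compute per-token KL(teacher || student).
    student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
    teacher_probs = F.softmax(teacher_logits / temperature, dim=-1)
    per_token_kl = (
        teacher_probs * (teacher_probs.clamp_min(1e-12).log() - student_log_probs)
    ).sum(dim=-1)
    # Drop positions whose label equals ignore_index (e.g. padding) when labels are given.
    if target is not None and ignore_index is not None:
        per_token_kl = per_token_kl[target != ignore_index]
    return per_token_kl.sum()
```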

test/chunked_loss/test_cosine_loss.py

Lines changed: 1 addition & 1 deletion
@@ -33,7 +33,7 @@ def __init__(
             temperature=temperature,
         )
 
-    def distillation_loss(self, student_logits, teacher_logits, beta=1.0):
+    def distillation_loss(self, student_logits, teacher_logits, target=None, ignore_index=None, beta=1.0, **kwargs):
         # Compute normalized logits
         print(f"student_logits.shape: {student_logits.shape}")
         student_norm = F.normalize(student_logits, p=2, dim=-1)
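Mirroring the new keywords (plus a `**kwargs` catch-all) in the test's reference implementation keeps the naive path callable with whatever the fused path forwards. A toy, self-contained sketch of the failure mode being avoided (placeholder loss bodies and hypothetical names, for illustration only):

```python
import torch

def old_loss(student_logits, teacher_logits, beta=1.0):
    # Placeholder body; only the signature matters for this illustration.
    return beta * (student_logits - teacher_logits).pow(2).mean()

def new_loss(student_logits, teacher_logits, target=None, ignore_index=None, beta=1.0, **kwargs):
    return beta * (student_logits - teacher_logits).pow(2).mean()

s, t = torch.randn(2, 8), torch.randn(2, 8)
labels = torch.zeros(2, dtype=torch.long)

new_loss(s, t, target=labels, ignore_index=-100)    # accepted after the fix
# old_loss(s, t, target=labels, ignore_index=-100)  # raises TypeError: unexpected keyword argument 'target'
```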
