
Commit 9e280f4

Div-by-zero KD fix (#639)
## What does this PR do?

**Type of change:** Bug fix

**Overview:** Fix a rare case of zero loss in the KD loss balancer for Megatron, where rescaling by a zero-valued loss caused a division by zero.

---------

Signed-off-by: Asha Anoosheh <[email protected]>
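As a quick illustration of the failure mode (a minimal sketch, not code from this PR; the loss values are invented), the old `.item()`-based rescale raises `ZeroDivisionError` as soon as the KD loss on a rank is exactly zero, while the guarded, `.detach()`-based rescale simply skips the step:

```python
import torch

# Hypothetical loss values; in the real code these come from the
# Megatron KD loss balancer's forward() patched in the diff below.
original_loss = torch.tensor(2.5)
kd_loss = torch.tensor(0.0)  # e.g. a CP rank that saw only context tokens

# Old behavior: .item() converts to Python floats, so a zero kd_loss
# raises ZeroDivisionError rather than producing inf.
try:
    kd_loss *= original_loss.item() / kd_loss.item()
except ZeroDivisionError as err:
    print("old rescale crashes:", err)

# New behavior: rescale only when both losses are positive, using
# detached tensors (no host sync, no gradient through the scale).
if kd_loss > 0 and original_loss > 0:
    kd_loss *= original_loss.detach() / kd_loss.detach()
print("kd_loss after guarded rescale:", kd_loss.item())  # stays 0.0
```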
1 parent ba19328 commit 9e280f4

1 file changed: modelopt/torch/distill/plugins/megatron.py (+3 −2 lines)
```diff
@@ -423,7 +423,7 @@ def forward(self, loss_dict: dict[str, Tensor]) -> Tensor:
         intermediate_loss = sum(loss_dict.values()) / max(len(loss_dict), 1)

         if intermediate_loss > 0:
-            dynamic_scale = logits_loss.item() / intermediate_loss.item()
+            dynamic_scale = logits_loss.detach() / intermediate_loss.detach()
             intermediate_loss_scaled = intermediate_loss * dynamic_scale
         else:
             intermediate_loss = logits_loss.new_tensor(intermediate_loss)
@@ -433,7 +433,8 @@ def forward(self, loss_dict: dict[str, Tensor]) -> Tensor:
             total_loss = logits_loss + intermediate_loss_scaled
         else:
             kd_loss = logits_loss + intermediate_loss_scaled
-            kd_loss *= original_loss.item() / kd_loss.item()
+            if kd_loss > 0 and original_loss > 0:  # zero when one CP rank has only context tokens
+                kd_loss *= original_loss.detach() / kd_loss.detach()
             total_loss = original_loss + kd_loss * self._kd_loss_scale

         out_dict = {
```
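For context, here is a minimal, self-contained sketch of how the patched balancing logic fits together. Only the identifiers visible in the diff are taken from the source; the function signature, the `kd_loss_scale` argument, and the omission of the logits-only branch are assumptions for illustration:

```python
import torch

def balance_losses(original_loss, logits_loss, intermediate_losses, kd_loss_scale=1.0):
    """Sketch of the KD loss balancing with the division-by-zero guards."""
    # Average the intermediate (hidden-state) KD losses; 0 if there are none.
    intermediate_loss = sum(intermediate_losses) / max(len(intermediate_losses), 1)

    if intermediate_loss > 0:
        # Match the intermediate-loss magnitude to the logits KD loss.
        dynamic_scale = logits_loss.detach() / intermediate_loss.detach()
        intermediate_loss_scaled = intermediate_loss * dynamic_scale
    else:
        intermediate_loss_scaled = logits_loss.new_tensor(intermediate_loss)

    kd_loss = logits_loss + intermediate_loss_scaled
    # Rescale the KD loss to the student's original loss magnitude, but only
    # when both are positive (kd_loss can be zero when one context-parallel
    # rank holds only context tokens).
    if kd_loss > 0 and original_loss > 0:
        kd_loss *= original_loss.detach() / kd_loss.detach()
    return original_loss + kd_loss * kd_loss_scale
```

With an empty `intermediate_losses` list or an all-zero `kd_loss`, the guards leave the tensors untouched instead of dividing by zero, which is the rare crash this commit removes.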
