Skip to content

Commit 1f24add

Browse files
lint fix
Signed-off-by: Zhiyu Li <zhiyul@nvidia.com>
1 parent 1fa3b30 commit 1f24add

File tree

2 files changed

+11
-5
lines changed

2 files changed

+11
-5
lines changed

nemo/collections/llm/modelopt/distill/loss.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,11 @@ def forward(self, predictions: Tensor, targets: Tensor) -> Tensor:
238238

239239

240240
class MFTLoss(BaseLoss):
241-
"""Calculates the Minifinetuning loss between two logits tensors and with the presence of labels without reducing the sequence dim. This function implements the distillation loss found in the paper: https://arxiv.org/abs/2506.15702."""
241+
"""
242+
Calculates the Minifinetuning loss between two logits tensors and with the presence of labels
243+
without reducing the sequence dim. This function implements the distillation loss found in the
244+
paper: https://arxiv.org/abs/2506.15702.
245+
"""
242246

243247
def __init__(
244248
self, model_config: "TransformerConfig", threshold: float, temperature: float = 1.0, reverse: bool = False
@@ -247,7 +251,8 @@ def __init__(
247251
248252
Args:
249253
model_config: MCore transformer config.
250-
threshold: Threshold for the MFT loss, used to determine the correction factor for the teacher probability given the ground truth labels.
254+
threshold: Threshold for the MFT loss, used to determine the correction factor
255+
for the teacher probability given the ground truth labels.
251256
temperature: Divide tensors by this value prior to calculating loss.
252257
reverse: Whether to reverse the loss as KLD(teacher, student) instead of KLD(student, teacher)
253258
"""
@@ -266,8 +271,8 @@ def _prepare_corrected_distributions(
266271
"""Prepare the corrected distributions for MFT loss.
267272
268273
Args:
269-
logits: The logits from the teacher model, shape (batch, channels) # e.g. (batch_size * seq_len, vocab_size)
270-
in case of LMs
274+
logits: The logits from the teacher model, shape (batch, channels)
275+
# e.g. (batch_size * seq_len, vocab_size) in case of LMs
271276
labels: The ground truth labels, shape (batch) # e.g. (batch_size * seq_len) in case of LMs
272277
threshold: The threshold value for the MFT correction.
273278
apply_threshold_to_all: If True, apply the threshold correction to all tokens,

nemo/collections/llm/modelopt/distill/utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,8 @@ class DistillationConfig:
5454
skip_lm_loss: Whether to skip computing the standard language model loss (default: ``True``).
5555
kd_loss_scale: Relative scaling factor for the distillation loss if ``skip_lm_loss`` is ``False``.
5656
use_mft: Whether to use MFT (Minifinetuning) for distillation.
57-
mft_threshold: Threshold for MFT loss, used to determine the correction factor for the teacher probability given the ground truth labels.
57+
mft_threshold: Threshold for MFT loss, used to determine the correction factor
58+
for the teacher probability given the ground truth labels.
5859
"""
5960

6061
intermediate_layer_pairs: List[Tuple[str, str]] = field(default_factory=list)

0 commit comments

Comments (0)