Skip to content

Commit 5db7169

Browse files
authored
Allow KD loss in val mode for MLM plugin (#331)
Signed-off-by: Asha Anoosheh <[email protected]>
1 parent 461980e commit 5db7169

File tree

1 file changed

+4
-23
lines changed

1 file changed

+4
-23
lines changed

modelopt/torch/distill/plugins/megatron.py

Lines changed: 4 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -532,10 +532,6 @@ def _set_input_tensor(self, input_tensors: list[Tensor]):
532532

533533
# HACK: Concatenate output tensors when PP>1 so they can be passed between ranks.
534534
def _forward(self, *args, **kwargs):
535-
if not self.training:
536-
with self.only_student_forward():
537-
return type(self).forward(self, *args, **kwargs)
538-
539535
with torch.no_grad():
540536
self._teacher_model.eval()
541537
teacher_output = self._teacher_model(*args, **kwargs)
@@ -551,20 +547,15 @@ def _forward(self, *args, **kwargs):
551547

552548

553549
def get_tensor_shapes_adjust_fn_for_distillation(
554-
model: torch.nn.Module | list[torch.nn.Module],
555-
seq_length: int,
556-
micro_batch_size: int,
557-
decoder_seq_length: int | None = None,
558-
forward_only: bool = False,
550+
model: torch.nn.Module | list[torch.nn.Module], **kwargs
559551
) -> Callable | None:
560552
"""Return the function to adjust tensor shapes for Distillation in Megatron-Core's forward pass.
561553
562554
Currently only used during non-interleaved pipelining for Distillation.
563555
Concatenates sizes of student and teacher output tensors for inter-process communication.
564556
"""
565557
if (
566-
forward_only
567-
or parallel_state.get_pipeline_model_parallel_world_size() == 1
558+
parallel_state.get_pipeline_model_parallel_world_size() == 1
568559
or parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None
569560
):
570561
return None
@@ -584,20 +575,10 @@ def adjust_tensor_shapes(
584575
cp_group = parallel_state.get_context_parallel_group()
585576

586577
teacher_recv_tensor_shapes = get_tensor_shapes(
587-
seq_length=seq_length,
588-
micro_batch_size=micro_batch_size,
589-
decoder_seq_length=decoder_seq_length,
590-
config=teacher_config,
591-
tp_group=tp_group,
592-
cp_group=cp_group,
578+
config=teacher_config, tp_group=tp_group, cp_group=cp_group, **kwargs
593579
)
594580
teacher_send_tensor_shapes = get_tensor_shapes(
595-
seq_length=seq_length,
596-
micro_batch_size=micro_batch_size,
597-
decoder_seq_length=decoder_seq_length,
598-
config=teacher_config,
599-
tp_group=tp_group,
600-
cp_group=cp_group,
581+
config=teacher_config, tp_group=tp_group, cp_group=cp_group, **kwargs
601582
)
602583
model.set_student_input_tensor_shape(recv_tensor_shapes)
603584

0 commit comments

Comments (0)