@@ -532,10 +532,6 @@ def _set_input_tensor(self, input_tensors: list[Tensor]):
532
532
533
533
# HACK: Concatenate output tensors when PP>1 so they can be passed between ranks.
534
534
def _forward (self , * args , ** kwargs ):
535
- if not self .training :
536
- with self .only_student_forward ():
537
- return type (self ).forward (self , * args , ** kwargs )
538
-
539
535
with torch .no_grad ():
540
536
self ._teacher_model .eval ()
541
537
teacher_output = self ._teacher_model (* args , ** kwargs )
@@ -551,20 +547,15 @@ def _forward(self, *args, **kwargs):
551
547
552
548
553
549
def get_tensor_shapes_adjust_fn_for_distillation (
554
- model : torch .nn .Module | list [torch .nn .Module ],
555
- seq_length : int ,
556
- micro_batch_size : int ,
557
- decoder_seq_length : int | None = None ,
558
- forward_only : bool = False ,
550
+ model : torch .nn .Module | list [torch .nn .Module ], ** kwargs
559
551
) -> Callable | None :
560
552
"""Return the function to adjust tensor shapes for Distillation in Megatron-Core's forward pass.
561
553
562
554
Currently only used during non-interleaved pipelining for Distillation.
563
555
Concatenates sizes of student and teacher output tensors for inter-process communication.
564
556
"""
565
557
if (
566
- forward_only
567
- or parallel_state .get_pipeline_model_parallel_world_size () == 1
558
+ parallel_state .get_pipeline_model_parallel_world_size () == 1
568
559
or parallel_state .get_virtual_pipeline_model_parallel_world_size () is not None
569
560
):
570
561
return None
@@ -584,20 +575,10 @@ def adjust_tensor_shapes(
584
575
cp_group = parallel_state .get_context_parallel_group ()
585
576
586
577
teacher_recv_tensor_shapes = get_tensor_shapes (
587
- seq_length = seq_length ,
588
- micro_batch_size = micro_batch_size ,
589
- decoder_seq_length = decoder_seq_length ,
590
- config = teacher_config ,
591
- tp_group = tp_group ,
592
- cp_group = cp_group ,
578
+ config = teacher_config , tp_group = tp_group , cp_group = cp_group , ** kwargs
593
579
)
594
580
teacher_send_tensor_shapes = get_tensor_shapes (
595
- seq_length = seq_length ,
596
- micro_batch_size = micro_batch_size ,
597
- decoder_seq_length = decoder_seq_length ,
598
- config = teacher_config ,
599
- tp_group = tp_group ,
600
- cp_group = cp_group ,
581
+ config = teacher_config , tp_group = tp_group , cp_group = cp_group , ** kwargs
601
582
)
602
583
model .set_student_input_tensor_shape (recv_tensor_shapes )
603
584
0 commit comments