@@ -532,10 +532,6 @@ def _set_input_tensor(self, input_tensors: list[Tensor]):
 
     # HACK: Concatenate output tensors when PP>1 so they can be passed between ranks.
     def _forward(self, *args, **kwargs):
-        if not self.training:
-            with self.only_student_forward():
-                return type(self).forward(self, *args, **kwargs)
-
         with torch.no_grad():
             self._teacher_model.eval()
             teacher_output = self._teacher_model(*args, **kwargs)
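Note: removing this early-exit means `_forward` no longer bypasses the teacher in eval mode, so the concatenated student+teacher output is produced whenever the model runs. As a rough sketch of the concatenation the HACK comment refers to (the helper name and the choice of `dim=0`, the sequence dimension in Megatron's `[seq, batch, hidden]` layout, are assumptions for illustration, not this file's actual code):

import torch
from torch import Tensor

def _concat_for_pp(student_output: Tensor, teacher_output: Tensor) -> Tensor:
    # With PP>1, Megatron-Core's p2p communication sends one tensor per
    # slot, so both activations must travel as a single tensor and be
    # split back apart on the receiving rank.
    return torch.cat([student_output, teacher_output], dim=0)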
@@ -551,20 +547,15 @@ def _forward(self, *args, **kwargs):
 
 
 def get_tensor_shapes_adjust_fn_for_distillation(
-    model: torch.nn.Module | list[torch.nn.Module],
-    seq_length: int,
-    micro_batch_size: int,
-    decoder_seq_length: int | None = None,
-    forward_only: bool = False,
+    model: torch.nn.Module | list[torch.nn.Module], **kwargs
 ) -> Callable | None:
     """Return the function to adjust tensor shapes for Distillation in Megatron-Core's forward pass.
 
     Currently only used during non-interleaved pipelining for Distillation.
     Concatenates sizes of student and teacher output tensors for inter-process communication.
     """
     if (
-        forward_only
-        or parallel_state.get_pipeline_model_parallel_world_size() == 1
+        parallel_state.get_pipeline_model_parallel_world_size() == 1
         or parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None
     ):
         return None
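With `forward_only` dropped from both the signature and this guard, the helper now short-circuits only for PP == 1 or interleaved (virtual) pipelining, and everything `get_tensor_shapes` needs arrives via `**kwargs`. A hypothetical call site under that assumption (keyword names follow the parameters removed above; values are illustrative):

adjust_fn = get_tensor_shapes_adjust_fn_for_distillation(
    model,
    seq_length=4096,          # forwarded untouched to get_tensor_shapes
    micro_batch_size=2,
    decoder_seq_length=None,
)
# adjust_fn is None when no shape adjustment is needed.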
@@ -584,20 +575,10 @@ def adjust_tensor_shapes(
         cp_group = parallel_state.get_context_parallel_group()
 
         teacher_recv_tensor_shapes = get_tensor_shapes(
-            seq_length=seq_length,
-            micro_batch_size=micro_batch_size,
-            decoder_seq_length=decoder_seq_length,
-            config=teacher_config,
-            tp_group=tp_group,
-            cp_group=cp_group,
+            config=teacher_config, tp_group=tp_group, cp_group=cp_group, **kwargs
         )
         teacher_send_tensor_shapes = get_tensor_shapes(
-            seq_length=seq_length,
-            micro_batch_size=micro_batch_size,
-            decoder_seq_length=decoder_seq_length,
-            config=teacher_config,
-            tp_group=tp_group,
-            cp_group=cp_group,
+            config=teacher_config, tp_group=tp_group, cp_group=cp_group, **kwargs
         )
         model.set_student_input_tensor_shape(recv_tensor_shapes)
 
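The hunk cuts off before the shape arithmetic, but a hypothetical continuation, assuming each shape is a `(seq, batch, hidden)` tuple and the same sequence-dim concatenation sketched for `_forward` above:

# Hypothetical: widen each p2p slot so it can carry the concatenated
# student+teacher tensor between pipeline ranks.
recv_tensor_shapes = [
    (s[0] + t[0], s[1], s[2])
    for s, t in zip(recv_tensor_shapes, teacher_recv_tensor_shapes)
]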