diff --git a/modelopt/torch/distill/plugins/megatron.py b/modelopt/torch/distill/plugins/megatron.py
index 6e712fcb..7078cca3 100644
--- a/modelopt/torch/distill/plugins/megatron.py
+++ b/modelopt/torch/distill/plugins/megatron.py
@@ -532,10 +532,6 @@ def _set_input_tensor(self, input_tensors: list[Tensor]):
 
     # HACK: Concatenate output tensors when PP>1 so they can be passed between ranks.
     def _forward(self, *args, **kwargs):
-        if not self.training:
-            with self.only_student_forward():
-                return type(self).forward(self, *args, **kwargs)
-
         with torch.no_grad():
             self._teacher_model.eval()
             teacher_output = self._teacher_model(*args, **kwargs)
@@ -551,11 +547,7 @@ def _forward(self, *args, **kwargs):
 
 
 def get_tensor_shapes_adjust_fn_for_distillation(
-    model: torch.nn.Module | list[torch.nn.Module],
-    seq_length: int,
-    micro_batch_size: int,
-    decoder_seq_length: int | None = None,
-    forward_only: bool = False,
+    model: torch.nn.Module | list[torch.nn.Module], **kwargs
 ) -> Callable | None:
     """Return the function to adjust tensor shapes for Distillation in Megatron-Core's forward pass.
 
@@ -563,8 +555,7 @@ def get_tensor_shapes_adjust_fn_for_distillation(
     Concatenates sizes of student and teacher output tensors for inter-process communication.
     """
     if (
-        forward_only
-        or parallel_state.get_pipeline_model_parallel_world_size() == 1
+        parallel_state.get_pipeline_model_parallel_world_size() == 1
         or parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None
     ):
         return None
@@ -584,20 +575,10 @@ def adjust_tensor_shapes(
         cp_group = parallel_state.get_context_parallel_group()
 
         teacher_recv_tensor_shapes = get_tensor_shapes(
-            seq_length=seq_length,
-            micro_batch_size=micro_batch_size,
-            decoder_seq_length=decoder_seq_length,
-            config=teacher_config,
-            tp_group=tp_group,
-            cp_group=cp_group,
+            config=teacher_config, tp_group=tp_group, cp_group=cp_group, **kwargs
         )
         teacher_send_tensor_shapes = get_tensor_shapes(
-            seq_length=seq_length,
-            micro_batch_size=micro_batch_size,
-            decoder_seq_length=decoder_seq_length,
-            config=teacher_config,
-            tp_group=tp_group,
-            cp_group=cp_group,
+            config=teacher_config, tp_group=tp_group, cp_group=cp_group, **kwargs
         )
 
         model.set_student_input_tensor_shape(recv_tensor_shapes)
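
Call-site note (not part of the patch): after this refactor, the shape-related arguments that were previously explicit parameters of get_tensor_shapes_adjust_fn_for_distillation now ride through **kwargs into Megatron-Core's get_tensor_shapes, and forward_only is dropped along with its early return. A minimal sketch of what a caller might look like; the variable names and values below are illustrative assumptions, not taken from the repository:

    # Hypothetical call site after this refactor. The kwargs here are
    # forwarded verbatim to Megatron-Core's get_tensor_shapes(); note
    # that forward_only is no longer an accepted parameter.
    from modelopt.torch.distill.plugins.megatron import (
        get_tensor_shapes_adjust_fn_for_distillation,
    )

    adjust_fn = get_tensor_shapes_adjust_fn_for_distillation(
        model,  # assumed: the DistillationModel (or list of pipeline chunks)
        seq_length=4096,
        micro_batch_size=2,
        decoder_seq_length=None,
    )
    # Returns None (no shape adjustment needed) when pipeline parallelism
    # is off (PP == 1) or virtual pipeline parallelism is in use.

A side effect of the **kwargs pass-through is that the helper no longer needs updating when get_tensor_shapes gains or renames keyword arguments; whatever the caller supplies is forwarded as-is alongside the teacher's config, tp_group, and cp_group.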