Merged
27 changes: 4 additions & 23 deletions modelopt/torch/distill/plugins/megatron.py
@@ -532,10 +532,6 @@ def _set_input_tensor(self, input_tensors: list[Tensor]):
 
     # HACK: Concatenate output tensors when PP>1 so they can be passed between ranks.
     def _forward(self, *args, **kwargs):
-        if not self.training:
-            with self.only_student_forward():
-                return type(self).forward(self, *args, **kwargs)
-
         with torch.no_grad():
             self._teacher_model.eval()
             teacher_output = self._teacher_model(*args, **kwargs)
@@ -551,20 +547,15 @@ def _forward(self, *args, **kwargs):
 
 
 def get_tensor_shapes_adjust_fn_for_distillation(
-    model: torch.nn.Module | list[torch.nn.Module],
-    seq_length: int,
-    micro_batch_size: int,
-    decoder_seq_length: int | None = None,
-    forward_only: bool = False,
+    model: torch.nn.Module | list[torch.nn.Module], **kwargs
 ) -> Callable | None:
     """Return the function to adjust tensor shapes for Distillation in Megatron-Core's forward pass.
Comment on lines +550 to 552
⚠️ Potential issue

Avoid breaking API: keep positional args (backward‑compat) and funnel into kwargs.

Switching to only **kwargs will break any existing positional call sites with a TypeError ("takes 1 positional argument but N were given"). Keep the old parameters (defaulted/optional), merge them into kwargs, and deprecate later.

Apply this diff to the signature:

-def get_tensor_shapes_adjust_fn_for_distillation(
-    model: torch.nn.Module | list[torch.nn.Module], **kwargs
-) -> Callable | None:
+def get_tensor_shapes_adjust_fn_for_distillation(
+    model: torch.nn.Module | list[torch.nn.Module],
+    seq_length: int | None = None,
+    micro_batch_size: int | None = None,
+    decoder_seq_length: int | None = None,
+    forward_only: bool | None = None,
+    **kwargs,
+) -> Callable | None:

Add this merge shim at the top of the function body:

# Back‑compat: funnel explicit args into kwargs if provided.
if seq_length is not None:
    kwargs.setdefault("seq_length", seq_length)
if micro_batch_size is not None:
    kwargs.setdefault("micro_batch_size", micro_batch_size)
if decoder_seq_length is not None:
    kwargs.setdefault("decoder_seq_length", decoder_seq_length)
if forward_only is not None:
    kwargs.setdefault("forward_only", forward_only)
🤖 Prompt for AI Agents
In modelopt/torch/distill/plugins/megatron.py around lines 550-552, the function
signature was changed to only accept **kwargs which breaks backward
compatibility for callers using positional parameters; restore the original
explicit parameters (seq_length, micro_batch_size, decoder_seq_length,
forward_only) as optional/defaulted parameters in the signature, and at the top
of the function body add a back-compat shim that funnels any provided explicit
args into kwargs (using kwargs.setdefault) so existing positional call sites
continue to work; mark these explicit params as deprecated in a comment for
future removal.

 
     Currently only used during non-interleaved pipelining for Distillation.
     Concatenates sizes of student and teacher output tensors for inter-process communication.
     """
     if (
-        forward_only
-        or parallel_state.get_pipeline_model_parallel_world_size() == 1
+        parallel_state.get_pipeline_model_parallel_world_size() == 1
         or parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None
     ):
         return None
@@ -584,20 +575,10 @@ def adjust_tensor_shapes(
         cp_group = parallel_state.get_context_parallel_group()
 
         teacher_recv_tensor_shapes = get_tensor_shapes(
-            seq_length=seq_length,
-            micro_batch_size=micro_batch_size,
-            decoder_seq_length=decoder_seq_length,
-            config=teacher_config,
-            tp_group=tp_group,
-            cp_group=cp_group,
+            config=teacher_config, tp_group=tp_group, cp_group=cp_group, **kwargs
         )
         teacher_send_tensor_shapes = get_tensor_shapes(
-            seq_length=seq_length,
-            micro_batch_size=micro_batch_size,
-            decoder_seq_length=decoder_seq_length,
-            config=teacher_config,
-            tp_group=tp_group,
-            cp_group=cp_group,
+            config=teacher_config, tp_group=tp_group, cp_group=cp_group, **kwargs
         )
         model.set_student_input_tensor_shape(recv_tensor_shapes)
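
For context on what the adjustment callable above is doing, here is a hedged, standalone sketch of the general pattern (hypothetical helper names and made-up shapes; not the modelopt/Megatron-Core implementation): the student's recv/send shape lists are extended with the teacher's shapes so a single pipeline-parallel P2P exchange can carry both models' output tensors.

# Hypothetical sketch of the shape-adjustment pattern -- not the real code.
def make_shape_adjust_fn(teacher_shapes, set_student_input_shapes):
    """Return a callable that widens recv/send shape lists for distillation."""

    def adjust_tensor_shapes(recv_tensor_shapes, send_tensor_shapes):
        # The student still records its own (unmodified) input shapes.
        set_student_input_shapes(list(recv_tensor_shapes))
        # Concatenate student and teacher shapes for inter-rank communication.
        return (
            recv_tensor_shapes + teacher_shapes,
            send_tensor_shapes + teacher_shapes,
        )

    return adjust_tensor_shapes


# Toy usage with made-up (seq_len, micro_batch, hidden) shapes:
adjust = make_shape_adjust_fn(
    teacher_shapes=[(4096, 2, 6144)],
    set_student_input_shapes=lambda shapes: None,
)
recv, send = adjust([(4096, 2, 4096)], [(4096, 2, 4096)])
assert recv == [(4096, 2, 4096), (4096, 2, 6144)]
assert send == [(4096, 2, 4096), (4096, 2, 6144)]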