Skip to content

Commit 4a23972

Browse files
authored
[Main] Add the missing part to support 1F1B overlap for Qwen3-Next (#2997)
1 parent 473e283 commit 4a23972

File tree

1 file changed

+13
-0
lines changed

1 file changed

+13
-0
lines changed

megatron/core/ssm/gated_delta_net.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -508,6 +508,19 @@ def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None, tp_gr
508508

509509
return sharded_state_dict
510510

511+
def backward_dw(self):
512+
"""Execute weight gradient computation for all linear layers."""
513+
self._backward_in_proj()
514+
self._backward_out_proj()
515+
516+
def _backward_in_proj(self):
517+
"""Computes weight gradients of input projection layer."""
518+
self.in_proj.backward_dw()
519+
520+
def _backward_out_proj(self):
521+
"""Computes weight gradients of output projection layer."""
522+
self.out_proj.backward_dw()
523+
511524

512525
def _split_tensor_factory(
513526
orig_sh_ten: ShardedTensor, split_sections: List[int], split_names: List[str], split_dim: int

0 commit comments

Comments
 (0)