
[Question] Does use_liger_gkd_loss break DDP gradient sync by unwrapping the student model? #5282

@JeffLee1874

Description


In the GKD and GOLD trainers, when use_liger_gkd_loss is enabled, the code explicitly unwraps the student model before the forward pass:

# Unwrap any DDP/Accelerate wrappers to reach the raw student model
unwrapped_student = self.accelerator.unwrap_model(model)
# Prefer the decoder when the model exposes one
if hasattr(unwrapped_student, "get_decoder") and unwrapped_student.get_decoder() is not None:
    base_student = unwrapped_student.get_decoder()
else:
    # Otherwise fall back to the backbone via base_model_prefix (default "model")
    base_student = getattr(
        unwrapped_student, getattr(unwrapped_student, "base_model_prefix", "model"), unwrapped_student
    )

# The forward pass runs on the unwrapped backbone, not through the DDP wrapper
student_outputs = base_student(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    use_cache=False,
)

Isn't this problematic for distributed training? DistributedDataParallel arms its gradient-reduction (AllReduce) hooks inside its own forward pass, so bypassing the DDP wrapper for the student's forward means gradients are never synchronized across GPUs, and the model weights on different ranks will diverge. Shouldn't the student model be called directly (e.g., model(...)) to preserve the DDP graph?
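To make the concern concrete, here is a minimal sketch with hypothetical stand-in classes (no torch, no process group): a wrapper whose forward arms gradient sync, the way DDP does, and an inner module that can be called directly. Calling the inner module bypasses the wrapper, so the sync step never fires.

```python
class InnerModel:
    """Stand-in for the unwrapped student backbone."""
    def forward(self, x):
        return x * 2

class DDPLikeWrapper:
    """Stand-in for DistributedDataParallel: its forward is what arms
    the AllReduce hooks on the backward pass."""
    def __init__(self, module):
        self.module = module
        self.sync_armed = 0  # counts how many forwards prepared gradient sync

    def forward(self, x):
        self.sync_armed += 1  # real DDP registers autograd hooks here
        return self.module.forward(x)

model = DDPLikeWrapper(InnerModel())

out_wrapped = model.forward(3)           # through the wrapper: sync armed
out_unwrapped = model.module.forward(3)  # bypasses the wrapper: no sync

print(out_wrapped, out_unwrapped, model.sync_armed)  # 6 6 1
```

Both calls produce the same output, but only the wrapped call increments `sync_armed` — which is why an unwrapped forward silently skips gradient synchronization even though the loss values look correct on every rank.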
