In the GKD and GOLD trainers, when `use_liger_gkd_loss` is enabled, the code explicitly unwraps the student model before the forward pass:
```python
unwrapped_student = self.accelerator.unwrap_model(model)
if hasattr(unwrapped_student, "get_decoder") and unwrapped_student.get_decoder() is not None:
    base_student = unwrapped_student.get_decoder()
else:
    base_student = getattr(
        unwrapped_student, getattr(unwrapped_student, "base_model_prefix", "model"), unwrapped_student
    )
student_outputs = base_student(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    use_cache=False,
)
```
Isn't this problematic for distributed training? By bypassing the DistributedDataParallel wrapper for the student's forward pass, DDP's forward never runs, so its reducer is never armed and the resulting gradients won't be synchronized across ranks via AllReduce — the student weights can then silently diverge between GPUs. Shouldn't the student be called through the wrapper (e.g., `model(...)`) so the DDP graph is preserved?
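To illustrate the mechanics I'm worried about, here is a minimal single-process sketch (gloo backend, `world_size=1`, so there is no real divergence to observe, but it shows the two call paths). Calling the wrapper goes through `DistributedDataParallel.forward`, which prepares the reducer for the next `backward()`; calling the inner `.module` skips that machinery entirely while producing the same outputs:

```python
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Single-process "distributed" setup purely for illustration.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)

inner = torch.nn.Linear(4, 4)
ddp = DDP(inner)
x = torch.randn(2, 4)

out_wrapped = ddp(x)          # goes through DDP.forward -> reducer armed for sync
out_bypassed = ddp.module(x)  # bypasses DDP.forward -> no sync hooks for this pass

# Same parameters, same math: the forward results are identical; only the
# gradient-synchronization behavior in a multi-rank run differs.
print(torch.allclose(out_wrapped, out_bypassed))  # True

dist.destroy_process_group()
```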
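For concreteness, a sketch of the direction I'd expect instead — keep the wrapped `model` in the forward path and take the decoder's last hidden state from its output (assuming that hidden state plus the `lm_head` weight is what the fused Liger loss consumes). `ToyCausalLM` below is a toy stand-in for the HF student, not the actual trainer code:

```python
import torch
import torch.nn as nn

class ToyCausalLM(nn.Module):
    """Toy stand-in for a HF causal LM: returns logits and hidden states."""
    def __init__(self, vocab=10, hidden=8):
        super().__init__()
        self.embed = nn.Embedding(vocab, hidden)
        self.decoder = nn.Linear(hidden, hidden)
        self.lm_head = nn.Linear(hidden, vocab, bias=False)

    def forward(self, input_ids, attention_mask=None,
                output_hidden_states=False, use_cache=False):
        h = self.decoder(self.embed(input_ids))
        return {
            "logits": self.lm_head(h),
            "hidden_states": (h,) if output_hidden_states else None,
        }

# In the trainer this would be the DDP/Accelerate-wrapped student.
model = ToyCausalLM()
inputs = {
    "input_ids": torch.randint(0, 10, (2, 5)),
    "attention_mask": torch.ones(2, 5, dtype=torch.long),
}

# Call the wrapped model directly so DDP.forward stays in the graph,
# then take the last hidden state for the fused loss.
out = model(**inputs, output_hidden_states=True, use_cache=False)
student_hidden = out["hidden_states"][-1]
print(student_hidden.shape)  # torch.Size([2, 5, 8])
```

This keeps the gradient-synchronization hooks in place while still giving the loss the pre-`lm_head` activations that the current unwrapping code appears to be after.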