
Commit c1c5ca0

KD example fix for new torch/hf causing FSDP save error (NVIDIA#645)
## What does this PR do?

**Type of change:** Bug Fix

**Overview:** The `llm_distill` example was hanging during save because, with newer torch/HF versions, the weights on non-zero ranks are now deleted during `model.export()` too early. Fixed by synchronizing the processes beforehand.

## Usage

<!-- You can potentially add a usage example below. -->

```python
# Add a code snippet demonstrating how to use this
```

## Testing

<!-- Mention how have you tested your change if applicable. -->

## Before your PR is "*Ready for review*"

<!-- If you haven't finished some of the above items you can still open `Draft` PR. -->

- **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed.
- **Is this change backward compatible?**: Yes/No <!--- If No, explain why. -->
- **Did you write any new necessary tests?**: Yes/No
- **Did you add or update any necessary documentation?**: Yes/No
- **Did you update [Changelog](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CHANGELOG.rst)?**: Yes/No <!--- Only for new features, API changes, critical bug fixes or bw breaking changes. -->

## Additional Information

<!-- E.g. related issue. -->

Signed-off-by: Asha Anoosheh <[email protected]>
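
For reference, a minimal sketch of the save ordering this fix enforces, written with Hugging Face `accelerate`. The helper name `save_fsdp_checkpoint` and the `export_fn` hook are illustrative, not part of ModelOpt; the point is that every rank gathers the FSDP state dict and then hits a barrier before any rank runs a step that may free its local shards.

```python
# Illustrative sketch only; not the actual ModelOpt KDTrainer code.
from accelerate import Accelerator


def save_fsdp_checkpoint(accelerator: Accelerator, model, output_dir: str, export_fn=None):
    # Gathering the full state dict under FSDP is a collective operation,
    # so every rank must reach this call.
    state_dict = accelerator.get_state_dict(model)

    # Barrier: wait until all ranks have finished gathering before any rank
    # runs a step (e.g. ModelOpt's `model.export()`) that may delete or
    # rewrite its local FSDP shards.
    accelerator.wait_for_everyone()

    if export_fn is not None:
        model = export_fn(model)  # hypothetical destructive post-processing step

    # Only rank 0 writes the checkpoint to disk, using the gathered weights.
    if accelerator.is_main_process:
        accelerator.unwrap_model(model).save_pretrained(output_dir, state_dict=state_dict)
```
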
1 parent 9409412 · commit c1c5ca0

2 files changed (+5, −4 lines)
Lines changed: 1 addition & 0 deletions

```diff
@@ -1,3 +1,4 @@
 pyarrow
+torchao>=0.14.1
 transformers<5.0
 trl>=0.23.0
```
modelopt/torch/distill/plugins/huggingface.py

Lines changed: 4 additions & 4 deletions

```diff
@@ -64,13 +64,13 @@ def save_model(
         output_dir = self.args.output_dir
         model = self.accelerator.unwrap_model(self.model)
         if not _internal_call and self.is_fsdp_enabled:
-            state_dict = self.accelerator.get_state_dict(self.model)
+            with model.hide_teacher_model(enable=export_student):
+                state_dict = self.accelerator.get_state_dict(self.model)
             modelopt_state = mto.modelopt_state(model)
             if export_student:
+                # Need to wait, otherwise FSDP weights may be deleted before rank 0 can gather them
+                self.accelerator.wait_for_everyone()
                 model = model.export()
-                # remove teacher model from state dict since FSDP forces
-                # expose_minimal_state_dict to be False
-                state_dict = {k: v for k, v in state_dict.items() if "_teacher_model" not in k}

             if self.accelerator.is_main_process:
                 model.save_pretrained(
```
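
Two things change in this hunk. First, `hide_teacher_model(enable=export_student)` wraps the state-dict gathering, so the teacher weights are left out of the gathered state dict, which makes the previous post-hoc filter on `"_teacher_model"` keys unnecessary. Second, `self.accelerator.wait_for_everyone()` is inserted before `model.export()`: export may delete FSDP shards on non-zero ranks, and the barrier guarantees rank 0 has finished gathering the full state dict first.
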
