We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent b5e8877 commit 908355eCopy full SHA for 908355e
src/lightning/pytorch/callbacks/weight_averaging.py
@@ -289,6 +289,8 @@ def on_save_checkpoint(
289
else:
290
average_model_state = self._average_model.state_dict()
291
checkpoint["current_model_state"] = checkpoint["state_dict"]
292
+ # Truncate the "module." prefix (the first 7 characters) from the names of the variables in the
293
+ # AveragedModel state.
294
checkpoint["state_dict"] = {
295
name[7:]: value for name, value in average_model_state.items() if name.startswith("module.")
296
}
0 commit comments