Conversation

@Kuonirad

- Modified `save_train_state` to save a dictionary containing model state, optimizer states, and training step.
- Updated `load_checkpoint` to handle the new dict format while maintaining backward compatibility with old weight-only checkpoints.
- Updated `create_model` to broadcast loaded checkpoint metadata (optimizers/step) from rank 0 to all ranks and restore optimizer states.
- Updated `init_train_state` to resume training step from checkpoint.
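A minimal sketch of what the new checkpoint format could look like. The function names `save_train_state` and `load_checkpoint` come from the PR description; the dict keys, the list-of-optimizers convention, and the legacy-fallback return value are assumptions for illustration. `weights_only=False` is passed explicitly because the checkpoint now contains optimizer state objects, as the later update notes.

```python
import torch


def save_train_state(path, model, optimizers, step):
    """Save a dict checkpoint: model weights, optimizer states, and training step."""
    torch.save({
        "model": model.state_dict(),
        "optimizers": [opt.state_dict() for opt in optimizers],
        "step": step,
    }, path)


def load_checkpoint(path, model):
    """Load the new dict format, falling back to old weight-only checkpoints."""
    ckpt = torch.load(path, map_location="cpu", weights_only=False)
    if isinstance(ckpt, dict) and "model" in ckpt:
        # New format: caller restores optimizer states and step from the dict.
        model.load_state_dict(ckpt["model"])
        return ckpt
    # Legacy format: the file is a bare state_dict with no metadata.
    model.load_state_dict(ckpt)
    return {"model": ckpt, "optimizers": None, "step": 0}
```

A caller resuming training would then feed `ckpt["optimizers"]` into `optimizer.load_state_dict(...)` per optimizer and start the loop at `ckpt["step"]`.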
Kuonirad closed this Dec 17, 2025
Kuonirad reopened this Dec 17, 2025
- Modified `save_train_state` to save a dictionary containing model state, optimizer states, training step, and optional EMA state.
- Updated `load_checkpoint` to handle the new dict format while maintaining backward compatibility.
- Updated `create_model` to broadcast loaded checkpoint metadata (optimizers/step) from rank 0 to all ranks and restore optimizer states.
- Updated `init_train_state` to return loaded checkpoint data and resume training step.
- Updated `launch` to load EMA state if available and save the online state (plus EMA helper) instead of just the EMA weights, ensuring correct resumption.
- Implemented automatic checkpoint detection: scans `checkpoint_path` for the latest `step_X` file if no specific checkpoint is provided.
- Added full RNG state persistence (torch, cuda, numpy, random) to checkpoints to ensure deterministic resumption.
- Modified `save_train_state` and `load_checkpoint` to handle the augmented state dictionary.
- Updated `torch.load` usage to allow complex objects (`weights_only=False`) required for optimizer/RNG states.
- Cleaned up dataset configuration placeholder.
- Verified bitwise-identical resumption via script.
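The last two mechanisms above can be sketched as follows. The `step_X` filename pattern and `checkpoint_path` come from the description; the helper names `find_latest_checkpoint`, `get_rng_state`, and `set_rng_state` are hypothetical, and the exact set of RNG streams captured mirrors the list in the bullet (torch, cuda, numpy, random).

```python
import os
import random
import re

import numpy as np
import torch


def find_latest_checkpoint(checkpoint_path):
    """Scan checkpoint_path for step_X files and return the one with the largest X."""
    pattern = re.compile(r"^step_(\d+)$")
    candidates = []
    for name in os.listdir(checkpoint_path):
        m = pattern.match(name)
        if m:
            candidates.append((int(m.group(1)), name))
    if not candidates:
        return None
    _, latest = max(candidates)  # highest step number wins
    return os.path.join(checkpoint_path, latest)


def get_rng_state():
    """Capture every RNG stream so a resumed run is bitwise-identical."""
    state = {
        "torch": torch.get_rng_state(),
        "numpy": np.random.get_state(),
        "random": random.getstate(),
    }
    if torch.cuda.is_available():
        state["cuda"] = torch.cuda.get_rng_state_all()
    return state


def set_rng_state(state):
    """Restore the RNG streams captured by get_rng_state."""
    torch.set_rng_state(state["torch"])
    np.random.set_state(state["numpy"])
    random.setstate(state["random"])
    if "cuda" in state and torch.cuda.is_available():
        torch.cuda.set_rng_state_all(state["cuda"])
```

At save time the `get_rng_state()` dict would be stored alongside the model/optimizer states in the checkpoint; at load time, restoring it before the first resumed batch is what makes the bitwise-identical verification possible.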
Kuonirad (author) left a comment:


Saved full training state (optimizers, step) in checkpoints.

