-
Notifications
You must be signed in to change notification settings - Fork 9
Implement best checkpointer #75
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
nifarn
wants to merge
3
commits into
microsoft:main
Choose a base branch
from
nifarn:nifarn/best-checkpointer
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
Changes from 2 commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,7 +4,7 @@ | |
|
|
||
import os
import re
from typing import Optional, Dict, Tuple, Callable, Union, List
from abc import ABC, abstractmethod
from operator import itemgetter
from dataclasses import dataclass
|
|
@@ -291,3 +291,99 @@ def check_mk_dir(self, dirpath: str) -> None: | |
| os.makedirs(dirpath) | ||
| assert os.path.isdir(dirpath), "supplied checkpoint dirpath "\ | ||
| "is not a directory" | ||
|
|
||
|
|
||
@dataclass
class BestCheckpointerArguments(DefaultCheckpointerArguments):
    """Additional arguments for checkpointer

    metric_name: name of metric where minimal is defined as best. Must be a registered buffer in module interface
    init_metric_val: optional initial value for the best-metric comparison. When None, it is
        derived from ``criteria`` ("min" -> +inf, "max" -> -inf). Required when ``criteria``
        is a custom callable.
    criteria: either the string "min" or "max", or a callable ``(new, old) -> bool`` that
        returns True when ``new`` should replace ``old`` as the best metric.
        (The previous annotation ``Optional[Tuple[str, Callable]]`` did not match the
        default value ``"min"`` or how the value is consumed.)
    save_every_epoch: whether to produce a checkpointer every epoch in addition to latest and best.
    load_best: whether to load best or latest checkpoint. Default behavior is to load latest.
    """
    metric_name: str = "val_perplexity"
    init_metric_val: Optional[float] = None
    criteria: Union[str, Callable] = "min"
    save_every_epoch: bool = False  # not usually necessary in practice
    load_best: bool = False  # default to load latest

    @property
    def init_metric_value(self) -> Optional[float]:
        # Backward-compatible alias: BestCheckpointer.__init__ reads
        # ``args.init_metric_value`` while the field is declared ``init_metric_val``;
        # without this alias construction raises AttributeError.
        return self.init_metric_val
|
|
||
|
|
||
class BestCheckpointer(DefaultCheckpointer):
    """
    Saves best and latest checkpoints. "Best" is decided by comparing the value of
    ``args.metric_name`` read from the module interface state against the best value
    seen so far, using ``args.criteria`` ("min", "max", or a custom comparison
    callable). By default it checks "val_perplexity", which is a registered buffer in
    `AbstractUserMessageReplyModule` that gets updated after every call to
    `on_end_val_epoch`.
    """
    def __init__(self, args: BestCheckpointerArguments):
        super().__init__(args)
        self.best_checkpoint_name = f"{self.args.file_prefix}_best_checkpoint.{self.args.file_ext}"
        self.latest_checkpoint_name = f"{self.args.file_prefix}_latest_checkpoint.{self.args.file_ext}"
        if self.args.criteria == 'min':
            self.criteria_func = lambda new, old: new < old
            self.best_metric = float('inf')
        elif self.args.criteria == 'max':
            self.criteria_func = lambda new, old: new > old
            self.best_metric = -float('inf')
        elif callable(self.args.criteria):
            # Custom comparison: an explicit starting value is mandatory, otherwise
            # the first comparison in save() would be against None and crash.
            # NOTE: the field is ``init_metric_val`` -- the previous code read the
            # nonexistent attribute ``init_metric_value`` and raised AttributeError.
            if self.args.init_metric_val is None:
                raise ValueError(
                    "init_metric_val must be provided when criteria is a callable")
            self.criteria_func = self.args.criteria
            self.best_metric = self.args.init_metric_val
        else:
            raise ValueError(f"unsupported criteria: {self.args.criteria!r}")

        # An explicit initial value overrides the +/-inf defaults for 'min'/'max'.
        if self.args.init_metric_val is not None:
            self.best_metric = self.args.init_metric_val
|
|
||
| def save(self, checkpoint_state: Checkpoint, index: int, force=False) -> str: | ||
| """ | ||
| Saves trainer, optimizer, and module interface state. | ||
|
|
||
| Args: | ||
| checkpoint_state: instance of `Checkpoint` which contains trainer, optimizer, and module interface state | ||
| index: current epoch number | ||
| force: whether to force a save even if period of checkpointing does not line up | ||
|
|
||
| Returns: | ||
| list of paths checkpoint state was saved to | ||
| """ | ||
| paths = [] | ||
| if self.args.save_every_epoch: | ||
| paths.append(super().save(checkpoint_state, index, force)) | ||
| if self.args.checkpoint and ((index % self.args.period == 0) or force): | ||
nifarn marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| # TODO grab this from logged metrics instead | ||
| metric = float(checkpoint_state.module_interface_state[self.args.metric_name]) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this part feels hacky if not pulled from metrics in terms of the design patterns used here, but either way I think adding logging here would be good (available metrics, metric selected and its value) |
||
|
|
||
| # optiionally save best | ||
| if self.criteria_func(metric, self.best_metric): | ||
| self.best_metric = metric | ||
| best_path = os.path.join(self.args.save_dir, self.best_checkpoint_name) | ||
| torch.save(checkpoint_state.__dict__, best_path) | ||
| paths.append(best_path) | ||
|
|
||
| # save latest | ||
| latest_path = os.path.join(self.args.save_dir, self.latest_checkpoint_name) | ||
| torch.save(checkpoint_state.__dict__, latest_path) | ||
| paths.append(latest_path) | ||
| return paths | ||
|
|
||
| def load(self) -> Checkpoint: | ||
| """ | ||
| Optionally loads a checkpoint from a given directory. Either loads a specified filename, the best checkpoint, or | ||
| the latest checkpoint. Raises a `ValueError` upon failure to load checkpoint. | ||
|
|
||
| Returns: | ||
| An instance of `Checkpoint` | ||
| """ | ||
| if self.args.load_dir: | ||
| if self.args.load_filename: | ||
| load_path = os.path.join(self.args.load_dir, self.args.load_filename) | ||
| elif self.args.load_best: | ||
| load_path = os.path.join(self.args.load_dir, self.best_checkpoint_name) | ||
| else: | ||
| load_path = os.path.join(self.args.load_dir, self.latest_checkpoint_name) | ||
|
|
||
| # TODO how to set best metric to match loaded checkpoint? | ||
| self.logger.debug(f"loading checkpoint from {load_path}") | ||
| checkpoint = torch.load(load_path, map_location=torch.device('cpu')) | ||
| self.logger.debug('Checkpoint loaded') | ||
| return Checkpoint(**checkpoint) | ||
|
|
||
| return Checkpoint() | ||
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.