state_dict saving #19
base: master
Conversation
samgd left a comment
I suggest making the following changes to Saver/CallbackHandler/fit in order to
make the code more modular by adhering to the separation of concerns design
principle. This will make the code simpler to modify and maintain in the
future!
Each of the three components has a reasonably well defined purpose and
interface:
- The Saver (should) handle everything to do with saving and restoring
state to and from disk.
- The CallbackHandler is responsible for creating the initial state and
running the Callbacks at the appropriate times.
- The fit function orchestrates the training and inference process.
The fit function uses the CallbackHandler's interface to allow
modification to this process where the modifications are Callbacks. The
CallbackHandler uses the Callback interface to run each Callback in turn.
The changes suggested aim to keep these interfaces distinct:
- The CallbackHandler can initialise the initial_epoch and
total_train_batches to 0. The `Saver.on_train_begin` method can overwrite
these values if it has loaded state from disk using the standard
`Callback` functionality. This means the CallbackHandler's init arguments
and logic remain unchanged (no need for `epoch` or
`total_train_batches`).
- fit does not need to know anything about the `training_state`. The Saver
Callback is responsible for it and the `for epoch in...` loop becomes
something like:
`for epoch in range(cb_handler.state_dict["epoch"], cb_handler.state_dict["epochs"])`
- Saver needs to be able to load state from disk in `on_train_begin` and dump
state in `on_epoch_end` (a minimal sketch follows this list).
In `on_train_begin` the state to load could either be explicitly
specified in the init function or implicitly chosen based on the last
state_dict dumped in the `log_dir`. The former is probably simpler for
now; the latter was useful in the deepspeech repo as we were using
ephemeral instances that kept restarting.
`on_epoch_end` essentially becomes `load_seq_to_seq`. It might be worth
removing the parsing of the epoch from the filename and only allowing it
to be parsed from the loaded state dict.
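A minimal sketch of what this could look like, assuming a Callback base class and a hook signature in which each callback receives the CallbackHandler's state_dict (the attribute names, hook signatures, and state keys here are guesses, not the repo's actual API):

    import os

    import torch


    class Saver(Callback):
        """Restore state in on_train_begin and dump it in on_epoch_end."""

        def __init__(self, log_dir, seq_to_seq, state_dict_path=None):
            self.log_dir = log_dir
            self.seq_to_seq = seq_to_seq
            self.state_dict_path = state_dict_path  # explicit file to restore, if any

        def on_train_begin(self, state):
            # "epoch" and "total_train_batches" were initialised to 0 by the
            # CallbackHandler; overwrite them only if restoring from disk.
            if self.state_dict_path is None:
                return
            loaded = torch.load(self.state_dict_path)
            self.seq_to_seq.load_state_dict(loaded["seq_to_seq"])
            state["epoch"] = loaded["epoch"]
            state["total_train_batches"] = loaded["total_train_batches"]

        def on_epoch_end(self, state):
            torch.save(
                {
                    "seq_to_seq": self.seq_to_seq.state_dict(),
                    "epoch": state["epoch"],
                    "total_train_batches": state["total_train_batches"],
                },
                os.path.join(self.log_dir, "state_dict_{}.pt".format(state["epoch"])),
            )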
I've followed all suggested changes except for the fact that there is an […]. Putting up for re-review.
samgd left a comment
Is the lr_scheduler in master? If it is, can this PR be rebased on it (I can't see it!)? If it isn't, can the lr_scheduler code be pushed to a different PR?
Now that the Saver functionality is more complex, should we add a test or two? Something simple: create a tiny module and a temporary directory, run a few epochs with a tiny amount of data, and make sure the right state_dict is restored?
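For example, something roughly along these lines (a sketch using plain torch calls instead of the repo's fit/Saver API, just to show the shape of the test):

    import os
    from tempfile import TemporaryDirectory

    import torch


    def test_last_state_dict_is_restored():
        torch.manual_seed(0)
        model = torch.nn.Linear(2, 2)                       # tiny module
        optim = torch.optim.SGD(model.parameters(), lr=0.1)

        with TemporaryDirectory() as tmpdir:                # temporary directory
            for epoch in range(3):                          # a few epochs, tiny data
                loss = model(torch.randn(4, 2)).pow(2).mean()
                optim.zero_grad()
                loss.backward()
                optim.step()
                torch.save(
                    {"model": model.state_dict(), "epoch": epoch},
                    os.path.join(tmpdir, "state_dict_{}.pt".format(epoch)),
                )

            expected = {k: v.clone() for k, v in model.state_dict().items()}

            # Restore the last dump into a fresh model and compare.
            restored = torch.load(os.path.join(tmpdir, "state_dict_2.pt"))
            assert restored["epoch"] == 2
            fresh = torch.nn.Linear(2, 2)
            fresh.load_state_dict(restored["model"])
            for key in expected:
                assert torch.allclose(expected[key], fresh.state_dict()[key])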
I think my first comment on the […]: master does not have my […]. I will respond to the rest of the above comments shortly, but just flagging this high-level point now.
Ah I see now, thanks. The […]
Responded to changes and added a test. I've added a […]
Also on the above point - in the […]
samgd left a comment
Minor comments, but nearly there!
| """ | ||
| fnames = [] | ||
| for fname in os.listdir(self.log_dir): | ||
| match = re.findall(r"state_dict_(\d+)\.pt", fname) |
The following is my mistake, sorry, but findall should actually be fullmatch:
    - match = re.findall(r"state_dict_(\d+)\.pt", fname)
    + match = re.fullmatch(r"state_dict_(\d+)\.pt", fname)
I think I actually want the findall version since this specifically returns the epoch number (i.e. it matches the whole regex but returns the matches in parentheses).
with this I can do:
    if match:
        fnames.append((int(match[0]), fname))
    fnames.sort()

And I later check that the epoch in the filename is the same as the epoch in the state dict and raise a warning if not.
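For reference, both calls expose the captured epoch, but only fullmatch anchors on the whole filename (a standard-library illustration, nothing repo-specific):

    import re

    fname = "state_dict_12.pt"

    # findall returns the captured groups directly.
    print(re.findall(r"state_dict_(\d+)\.pt", fname))              # ['12']

    # fullmatch returns a Match object (or None); the epoch is group(1).
    print(re.fullmatch(r"state_dict_(\d+)\.pt", fname).group(1))   # '12'

    # The difference shows up for names that merely contain the pattern:
    other = "old_state_dict_12.pt.bak"
    print(re.findall(r"state_dict_(\d+)\.pt", other))              # ['12']
    print(re.fullmatch(r"state_dict_(\d+)\.pt", other))            # None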
tests/run/test_Saver.py
Outdated
    # Utilities -------------------------------------------------------------------


    def change_seq_to_seq_state_dict(seq_to_seq: SeqToSeq):
Name is a little opaque; could be made clearer if you want to update it (I can't think of any suggestions, naming is hard!). Name change aside, perhaps add an underscore to signify in-place and update the top-level description?
    - def change_seq_to_seq_state_dict(seq_to_seq: SeqToSeq):
    + def change_seq_to_seq_state_dict_(seq_to_seq: SeqToSeq) -> None:
    +     """Modifies all values in ``seq_to_seq`` such that they are different"""
Changed to alter_seq2seq_attributes_?
        seq_to_seq_cfg
    )

    with TemporaryDirectory() as tmpdir:
pytest has a built-in fixture for temporary directories: https://docs.pytest.org/en/latest/tmpdir.html
I was using tmpdir originally but it interacts with hypothesis in an undesirable way. Namely, a single tmpdir is shared between all hypothesis calls to the same function. In this case that causes the test to fail when it shouldn't, as the directory contents affect the Saver's loading behaviour.
I could add teardown code but I think that removes the simplicity that tmpdir usually provides.
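To illustrate the pattern used instead (hypothesis and pytest assumed; a fresh directory is created per generated example, whereas a function-scoped tmpdir fixture would be shared across all of them):

    from pathlib import Path
    from tempfile import TemporaryDirectory

    from hypothesis import given
    from hypothesis import strategies as st


    @given(epoch=st.integers(min_value=0, max_value=5))
    def test_saver_gets_a_fresh_dir(epoch):
        # New, empty directory for every example hypothesis generates,
        # so earlier examples cannot leave state_dicts behind.
        with TemporaryDirectory() as tmpdir:
            assert list(Path(tmpdir).iterdir()) == []
            (Path(tmpdir) / "state_dict_{}.pt".format(epoch)).touch()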
tests/run/test_Saver.py
Outdated
    with pytest.raises(AssertionError):
        check_state_dicts_match(
            expected_sd["model"], seq_to_seq.state_dict()["model"]
        )
    if seq_to_seq.optim:
        with pytest.raises(AssertionError):
            check_state_dicts_match(
                expected_sd["optim"], seq_to_seq.state_dict()["optim"]
            )
    if seq_to_seq.lr_scheduler:
        with pytest.raises(AssertionError):
            check_state_dicts_match(
                expected_sd["lr_scheduler"],
                seq_to_seq.state_dict()["lr_scheduler"],
            )
The check_state_dicts_match function is recursive so these three calls could be merged into one? It could also return a bool rather than throwing an assertion, making it more natural to call:

    assert state_dicts_match(expected_sd, seq_to_seq.state_dict())

with check_state_dicts_match becoming something like (untested):
    def state_dicts_match(dict1, dict2) -> bool:
        """Returns True if dicts have same keys and values."""
        if not dict1.keys() == dict2.keys():
            return False
        for key in dict1.keys():
            val1 = dict1[key]
            val2 = dict2[key]
            if isinstance(val1, dict):
                if not (isinstance(val2, dict) and state_dicts_match(val1, val2)):
                    return False
            if isinstance(val1, float):
                if not (isinstance(val2, float) and math.isclose(val1, val2)):
                    return False
            if isinstance(val1, torch.Tensor):
                if not (isinstance(val2, torch.Tensor) and torch.allclose(val1, val2)):
                    return False
        return True
I've updated it to return a bool but I'm intentionally checking that all of the subdicts don't match, whereas assert state_dicts_match(expected_sd, seq_to_seq.state_dict()) would fail as soon as just one of them was changed.
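i.e. the test keeps three separate negative checks, presumably along these lines:

    # Each sub-dict must differ, not merely the combined dict.
    assert not state_dicts_match(expected_sd["model"], seq_to_seq.state_dict()["model"])
    if seq_to_seq.optim:
        assert not state_dicts_match(
            expected_sd["optim"], seq_to_seq.state_dict()["optim"]
        )
    if seq_to_seq.lr_scheduler:
        assert not state_dicts_match(
            expected_sd["lr_scheduler"], seq_to_seq.state_dict()["lr_scheduler"]
        )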
Responded to comments. A few outstanding discussions.
Small PR updating the seq_to_seq state dict to include the learning-rate scheduler and optimizer state so that crashed training runs can be resumed w/o disruption.

I've also added a helper function that, given a state_dict filepath, loads the dict into a seq_to_seq instance and returns the training state necessary for the CallbackHandler. It needs this state so that the Saver callback can save with the naming convention state_dict_x and the TensorBoardLogger can resume from the same step.

I have not added tests for this functionality here - but I have added one in the lr scheduler PR: #22 (which is blocked by the rnnt PR). The reason for this is that the state_dict restoration is relatively simple in this case but much more involved once learning-rate warm-up is added.
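Roughly, the helper's shape (the key names and the SeqToSeq attributes here are assumptions for illustration, not the exact code in this PR):

    import torch


    def load_seq_to_seq(seq_to_seq, state_dict_path):
        """Load a dumped state_dict into ``seq_to_seq`` and return the training state."""
        state = torch.load(state_dict_path)

        seq_to_seq.model.load_state_dict(state["model"])
        if seq_to_seq.optim is not None:
            seq_to_seq.optim.load_state_dict(state["optim"])
        if seq_to_seq.lr_scheduler is not None:
            seq_to_seq.lr_scheduler.load_state_dict(state["lr_scheduler"])

        # Training state the CallbackHandler needs so that the Saver keeps the
        # state_dict_<epoch> naming and the TensorBoardLogger resumes at the same step.
        return {
            "epoch": state["epoch"],
            "total_train_batches": state["total_train_batches"],
        }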