
Conversation

@julianmack (Contributor) commented Jan 9, 2020:

Small PR updating the seq_to_seq state dict to include the learning-rate scheduler and optimizer state so that crashed training runs can be resumed without disruption.

I've also added a helper function that, given a state_dict filepath, loads the dict into a seq_to_seq instance and returns the training state needed by the CallbackHandler. The CallbackHandler needs this state so that the Saver callback can save with the `state_dict_x` naming convention and the TensorBoardLogger can resume from the same step.

I have not added tests for this functionality here, but I have added one in the lr scheduler PR (#22, which is blocked by the rnnt PR). The reason is that state_dict restoration is relatively simple in this case but becomes much more involved once learning-rate warm-up is added.
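
For illustration, roughly the shape of the extended state dict (a sketch only - `seq_to_seq.model` and the helper name are assumptions; `optim` and `lr_scheduler` are the attributes discussed below):

def seq_to_seq_state_dict(seq_to_seq):
    """Sketch: bundle model, optimizer, and scheduler state for resuming."""
    sd = {"model": seq_to_seq.model.state_dict()}
    if seq_to_seq.optim is not None:
        # torch.optim.Optimizer.state_dict() captures step counts, momentum
        # buffers, etc., so training resumes without disruption.
        sd["optim"] = seq_to_seq.optim.state_dict()
    if seq_to_seq.lr_scheduler is not None:
        sd["lr_scheduler"] = seq_to_seq.lr_scheduler.state_dict()
    return sd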

@julianmack julianmack requested a review from samgd January 15, 2020 10:35
@julianmack julianmack changed the title Adding full state saving state_dict saving Jan 15, 2020
@samgd (Contributor) left a comment:

I suggest making the following changes to Saver/CallbackHandler/fit in order to
make the code more modular by adhering to the separation-of-concerns design
principle. This will make the code simpler to modify and maintain in the
future!

Each of the three components has a reasonably well defined purpose and
interface:

- The Saver (should) handle everything to do with saving and restoring
  state to and from disk.

- The CallbackHandler is responsible for creating the initial state and
  running the Callbacks at the appropriate times.

- The fit function orchestrates the training and inference process.

  The fit function uses the CallbackHandler's interface to allow
  modification to this process where the modifications are Callbacks. The
  CallbackHandler uses the Callback interface to run each Callback in turn.

The changes suggested aim to keep these interfaces distinct (a rough sketch follows the list):

- The CallbackHandler can initialise the initial_epoch and
  total_train_batches to 0. The `Saver.on_train_begin` method can overwrite
  these values if it has loaded state from disk using the standard
  `Callback` functionality. This means the CallbackHandler's init arguments
  and logic remains unchanged (no need for `epoch` or
  `total_train_batches`).

- fit does not need to know anything about the `training_state`. The Saver
  Callback is responsible for it and the `for epoch in...` loop becomes
  something like:

  `for epoch in range(cb_handler.state_dict["epoch"], cb_handler.state_dict["epochs"])`

- Saver needs to be able to load state from disk in `on_train_begin` and dump
  state in `on_epoch_end`.

  In `on_train_begin` the state to load could either be explicitly
  specified in the init function or implicitly chosen based on the last
  state_dict dumped in the `log_dir`. The former is probably simpler for
  now; the latter was useful in the deepspeech repo as we were using
  ephemeral instances that kept restarting.

  `on_train_begin` essentially becomes `load_seq_to_seq`. It might be worth
  removing the parsing of the epoch from the filename and only allowing it
  to be parsed from the loaded state dict.
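
Roughly the shape I have in mind (a sketch only - method and attribute names here are illustrative, not the repo's actual API):

import os

import torch

class Saver:  # would subclass the Callback base class in practice
    """Sketch of the restore/dump flow described above."""

    def __init__(self, log_dir, state_dict_path=None):
        self.log_dir = log_dir
        # Explicitly specified state to restore (the simpler option for now).
        self.state_dict_path = state_dict_path

    def on_train_begin(self, state):
        # `state` is the CallbackHandler's state dict, initialised with
        # epoch=0 and total_train_batches=0; overwrite it if restoring.
        if self.state_dict_path is not None:
            saved = torch.load(self.state_dict_path)
            state["epoch"] = saved["epoch"]
            state["total_train_batches"] = saved["total_train_batches"]

    def on_epoch_end(self, state):
        # Dump everything needed to resume, named by epoch.
        fname = f"state_dict_{state['epoch']}.pt"
        torch.save(state, os.path.join(self.log_dir, fname))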

@julianmack (Contributor, Author) commented:

I've followed all the suggested changes, except that there is an lr_scheduler in master, so I've kept it and treated it accordingly.

Putting up for re-review.

@julianmack julianmack requested a review from samgd January 28, 2020 10:11
@samgd (Contributor) left a comment:

Is the lr_scheduler in master? If it is, can this PR be rebased on it (I can't see it!)? If it isn't, can the lr_scheduler code be pushed to a different PR?

Now that the Saver functionality is more complex, should we add a test (or tests)? Something simple: create a tiny module, create a temporary directory, run a bunch of epochs with a tiny amount of data, and make sure the right state_dict is restored?

@julianmack (Contributor, Author) commented Jan 29, 2020:

> Is the lr_scheduler in master? If it is, can this PR be rebased on it (I can't see it!)? If it isn't, can the lr_scheduler code be pushed to a different PR?

I think my first comment on the lr_scheduler was unclear @samgd. There is already an lr_scheduler in master, though it wasn't added by me (Giuseppe added it a while back):

# create learning rate scheduler

master does not have my lr_warmup and polynomial_lr changes: these are in an incoming PR (#22) and I agree that they should be considered separately!

I will respond to the rest of the above comments shortly, but I'm just flagging this high-level point now.

@samgd (Contributor) commented Jan 29, 2020:

> I think my first comment on the lr_scheduler was unclear @samgd. There is already an lr_scheduler in master, though it wasn't added by me (Giuseppe added it a while back):

Ah, I see now, thanks. What was throwing me off was that the SeqToSeq class has no lr_scheduler parameter in its init, nor is the parameter documented in the docstring - I had forgotten it was being set in the builder. I'll make an issue for this as it should be explicit and obvious!

@julianmack (Contributor, Author) commented Jan 30, 2020:

Responded to changes and added a test.

I've added a check_state_dicts_match function in test.utils.utils. Although it is currently used only in the new test, I've placed it there because I want to use it in another test in my incoming lr_warmup PR.
I've also disabled the builders/test_task_config.py test, as this functionality is necessarily tested at the start of the new Saver test.

@julianmack julianmack requested a review from samgd January 30, 2020 09:35
@julianmack (Contributor, Author) commented:

> Ah, I see now, thanks. What was throwing me off was that the SeqToSeq class has no lr_scheduler parameter in its init, nor is the parameter documented in the docstring - I had forgotten it was being set in the builder. I'll make an issue for this as it should be explicit and obvious!

Also on the above point: in the lr_warmup PR, I've refactored the lr_scheduler creation and added a proper builder with Sphinx docs.

@samgd (Contributor) left a comment:

Minor comments, but nearly there!

"""
fnames = []
for fname in os.listdir(self.log_dir):
match = re.findall(r"state_dict_(\d+)\.pt", fname)
@samgd (Contributor) commented:

The following is my mistake, sorry, but findall should actually be fullmatch:

Suggested change:
- match = re.findall(r"state_dict_(\d+)\.pt", fname)
+ match = re.fullmatch(r"state_dict_(\d+)\.pt", fname)

@julianmack (Contributor, Author) replied:

I think I actually want the findall version, since this specifically returns the epoch number (i.e. it matches the whole regex but returns only the parenthesised group).

with this I can do:

if match:
    fnames.append((int(match[0]), fname))
fnames.sort()

And I later check that the epoch in the filename is the same as the epoch in the state dict and raise a warning if not.
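
For reference, how the two behave on example filenames (standard library behaviour, not code from this PR):

import re

# findall returns the captured group for every substring match:
re.findall(r"state_dict_(\d+)\.pt", "state_dict_17.pt")            # ['17']
re.findall(r"state_dict_(\d+)\.pt", "old_state_dict_17.pt.bak")    # ['17'] - matches too!

# fullmatch anchors the whole filename, and the epoch is still
# available via the capture group:
m = re.fullmatch(r"state_dict_(\d+)\.pt", "state_dict_17.pt")
m.group(1)                                                         # '17'
re.fullmatch(r"state_dict_(\d+)\.pt", "old_state_dict_17.pt.bak")  # None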

# Utilities -------------------------------------------------------------------


def change_seq_to_seq_state_dict(seq_to_seq: SeqToSeq):
@samgd (Contributor) commented:

The name is a little opaque - it could be made clearer if you want to update it or have any suggestions (I can't think of any, naming is hard!). Name change aside, perhaps add a trailing underscore to signify in-place and update the top-level description?

Suggested change:
- def change_seq_to_seq_state_dict(seq_to_seq: SeqToSeq):
+ def change_seq_to_seq_state_dict_(seq_to_seq: SeqToSeq) -> None:
+     """Modifies all values in ``seq_to_seq`` such that they are different"""

@julianmack (Contributor, Author) replied:

Changed to alter_seq2seq_attributes_?

seq_to_seq_cfg
)

with TemporaryDirectory() as tmpdir:
@samgd (Contributor) commented:

pytest has a built-in fixture for temporary directories: https://docs.pytest.org/en/latest/tmpdir.html

@julianmack (Contributor, Author) commented Feb 12, 2020:

I was using tmpdir originally but it interacts with hypothesis in an undesirable way. Namely, a single tmpdir is shared between all hypothesis calls to the same function. In this case that causes the test to fail when it shouldn't, as the directory contents affect the Saver's loading behaviour.

I could add teardown code but I think that removes the simplicity that tmpdir usually provides.
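
i.e. something along these lines, with a fresh directory per hypothesis example (a simplified sketch - the test name and body here are placeholders):

from tempfile import TemporaryDirectory

from hypothesis import given, strategies as st

@given(epochs=st.integers(min_value=1, max_value=3))
def test_saver_restores_state(epochs):
    # Each hypothesis example gets its own directory; pytest's tmpdir
    # fixture would be created once and shared across all examples.
    with TemporaryDirectory() as tmpdir:
        ...  # run `epochs` epochs, dump state dicts into tmpdir, restore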

Comment on lines 72 to 86
with pytest.raises(AssertionError):
    check_state_dicts_match(
        expected_sd["model"], seq_to_seq.state_dict()["model"]
    )
if seq_to_seq.optim:
    with pytest.raises(AssertionError):
        check_state_dicts_match(
            expected_sd["optim"], seq_to_seq.state_dict()["optim"]
        )
if seq_to_seq.lr_scheduler:
    with pytest.raises(AssertionError):
        check_state_dicts_match(
            expected_sd["lr_scheduler"],
            seq_to_seq.state_dict()["lr_scheduler"],
        )
@samgd (Contributor) commented Feb 12, 2020:

The check_state_dicts_match function is recursive, so these three calls could be merged into one? The check_state_dicts_match function could also return a bool rather than throwing an assertion, making it more natural to call:

assert state_dicts_match(expected_sd, seq_to_seq.state_dict())

with check_state_dicts_match becoming something like (untested):

import math

import torch

def state_dicts_match(dict1, dict2) -> bool:
    """Returns True if dicts have the same keys and (recursively) equal values."""
    if dict1.keys() != dict2.keys():
        return False
    for key in dict1.keys():
        val1 = dict1[key]
        val2 = dict2[key]
        if isinstance(val1, dict):
            if not (isinstance(val2, dict) and state_dicts_match(val1, val2)):
                return False
        elif isinstance(val1, float):
            if not (isinstance(val2, float) and math.isclose(val1, val2)):
                return False
        elif isinstance(val1, torch.Tensor):
            if not (isinstance(val2, torch.Tensor) and torch.allclose(val1, val2)):
                return False
        else:
            # e.g. ints, strs, or lists in optimizer param_groups
            if val1 != val2:
                return False
    return True

@julianmack (Contributor, Author) replied:

I've updated it to return a bool, but I'm intentionally checking that each of the sub-dicts doesn't match, whereas `assert state_dicts_match(expected_sd, seq_to_seq.state_dict())` would fail if just one of them had changed.
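
i.e. the bool version keeps the three per-sub-dict checks, roughly:

assert not state_dicts_match(expected_sd["model"], seq_to_seq.state_dict()["model"])
if seq_to_seq.optim:
    assert not state_dicts_match(
        expected_sd["optim"], seq_to_seq.state_dict()["optim"]
    )
if seq_to_seq.lr_scheduler:
    assert not state_dicts_match(
        expected_sd["lr_scheduler"], seq_to_seq.state_dict()["lr_scheduler"]
    )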

@julianmack (Contributor, Author) commented:

Responded to comments. A few outstanding discussions.

@julianmack julianmack requested a review from samgd February 12, 2020 12:53