Sortagrad #13
base: master
Conversation
julianmack
left a comment
A small number of minor refactoring and docstring comments
batch_sampler=SortaGrad(
    indices=range(len(train_dataset)),
    batch_size=task_config.train_config.batch_size,
    shuffle=shuffle,
    drop_last=False,
),
I think this was in deepspeech_internal but could we avoid repetition by changing to:
if task_config.train_config.sortagrad:
    batch_sampler = SortaGrad(...)
    collate_fn = seq_to_seq_collate_fn_sorted
else:
    batch_sampler = SequentialRandomSampler(...)
    collate_fn = seq_to_seq_collate_fn
And then construct the same train_loader once for both cases?
See next comment for seq_to_seq_collate_fn vs seq_to_seq_collate_fn_sorted
I did exactly what you have written, but unfortunately black was complaining during the pre-commit because it thinks the variable batch_sampler should always be a SortaGrad or always a SequentialRandomSampler (it reports a type inconsistency). I couldn't find any solution to this problem other than duplicating the code in the if-else branches.
I think you can define the types in the main body in this case:
from typing import Union

batch_sampler: Union[SortaGrad, SequentialRandomSampler]
if task_config.train_config.sortagrad:
    batch_sampler = SortaGrad(...)
    collate_fn = seq_to_seq_collate_fn_sorted
else:
    batch_sampler = SequentialRandomSampler(...)
    collate_fn = seq_to_seq_collate_fn
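For reference, a minimal self-contained sketch of this annotation pattern (the classes and collate functions below are stand-ins, not the real myrtlespeech APIs):

from typing import Callable, Union

class SortaGrad:  # stand-in for myrtlespeech.data.sampler.SortaGrad
    def __init__(self, indices, batch_size):
        self.indices, self.batch_size = indices, batch_size

class SequentialRandomSampler:  # stand-in sampler
    def __init__(self, indices, batch_size):
        self.indices, self.batch_size = indices, batch_size

def seq_to_seq_collate_fn_sorted(batch): return batch
def seq_to_seq_collate_fn(batch): return batch

use_sortagrad = True  # e.g. task_config.train_config.sortagrad

# Annotating the variables up front lets mypy accept either branch,
# so the DataLoader construction can be shared after the if/else.
batch_sampler: Union[SortaGrad, SequentialRandomSampler]
collate_fn: Callable

if use_sortagrad:
    batch_sampler = SortaGrad(indices=range(64), batch_size=8)
    collate_fn = seq_to_seq_collate_fn_sorted
else:
    batch_sampler = SequentialRandomSampler(indices=range(64), batch_size=8)
    collate_fn = seq_to_seq_collate_fn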
It seems to be working, but I still couldn't avoid repeating the batch sampler initialization.
src/myrtlespeech/data/batch.py
Outdated
# Sort the samples
samples = [
    (input, in_seq_len, target, target_seq_len)
    for input, in_seq_len, target, target_seq_len in zip(
        inputs, in_seq_lens, targets, target_seq_lens
    )
]
Can this function be deleted so that there is just seq_to_seq_collate_fn() with a bool argument for sorting, as I think these lines are the only addition?
Collate functions in PyTorch receive just the batch argument at runtime by default. I've added a commit with a small hack to make it work with two arguments (batch and the bool sort argument): it uses a lambda function when creating the DataLoader. Let me know if this solution looks clearer or not.
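A small sketch of that approach with a hypothetical collate function (functools.partial is shown as an equivalent alternative to the lambda):

from functools import partial

import torch
from torch.utils.data import DataLoader, TensorDataset

def seq_to_seq_collate_fn(batch, sort=False):
    # Hypothetical collate function: the DataLoader only ever passes `batch`,
    # so the extra `sort` flag has to be bound before the loader is created.
    if sort:
        batch = sorted(batch, key=lambda sample: sample[0].size(0), reverse=True)
    return batch

dataset = TensorDataset(torch.randn(16, 10), torch.randint(0, 5, (16,)))

# Bind `sort` with a lambda, as described above ...
loader = DataLoader(
    dataset, batch_size=4,
    collate_fn=lambda batch: seq_to_seq_collate_fn(batch, sort=True),
)
# ... or equivalently with functools.partial.
loader = DataLoader(
    dataset, batch_size=4,
    collate_fn=partial(seq_to_seq_collate_fn, sort=True),
)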
I think it should be possible to deal with the black errors - I have suggested a way that I think works.
edit: (I didn't mean to request more changes here - the same ones are outstanding)
julianmack
left a comment
Nice - all looks good! Great.
shuffle_batches_before_every_epoch: true;
sortagrad: true;
The Deep Speech 2 paper describes SortaGrad as:
Specifically, in the first training epoch we iterate through minibatches in the training set in increasing order of the length of the longest utterance in the minibatch. After the first epoch training reverts back to a random order over minibatches.
Having both a shuffle_batches_before_every_epoch and a sortagrad option is not consistent with this, e.g. what does shuffle_batches_before_every_epoch: false; sortagrad: true mean?
A potential alternative is to have a shuffle_strategy-like field:
oneof shuffle_strategy {
  Unshuffled unshuffled = 1;
  Random random = 2;
  SortaGrad sorta_grad = 3;
}
Unshuffled and Random are not good names but hopefully this gives an idea.
Thoughts on this?
Basically, if sortagrad: true then the shuffle_batches_before_every_epoch flag will be ignored in the first epoch, so I agree that having two separate flags could be a bit confusing for the end user.
package myrtlespeech.protos;

import "google/protobuf/wrappers.proto";
Now an unused import :-)
Removed!
indices=range(len(train_dataset)),
batch_size=task_config.train_config.batch_size,
shuffle=shuffle,
shuffle=True,
Does sorting within a single batch matter? Why does SortaGrad require it and the other two cases not?
What I thought was that it could matter, especially when using big batch sizes. Moreover, I saw that in deepspeech_internal the collate function also sorts every single batch, so I stuck with that implementation.
The value of the sort variable is only used inside the collate function and it is True only when we want to sort every single batch. It is set to False in the sequential_batches case because I interpreted it as "we want to go through every batch in a sequential way but we don't care about sorting single batches".
On the other hand, I set it to True in the sorta_grad case because I interpreted it as "we want to go through every batch in a sequential way for the first epoch and we also care about sorting each single batch".
Let me know if my interpretation was wrong and something should be changed.
What would change when the batch size increases?
The order of samples within a batch has no effect on the training process unless I'm overlooking something - i.e. the loss, gradient, etc. will be the same.
If the above is true, sort should just be fixed to one value for all cases to simplify the logic? True?
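A toy check of that claim (a plain linear model and mean cross-entropy, not the myrtlespeech model or loss): permuting the samples within a batch leaves the loss, and therefore the gradients, unchanged.

import torch

torch.manual_seed(0)
model = torch.nn.Linear(10, 2)
inputs = torch.randn(8, 10)
targets = torch.randint(0, 2, (8,))

perm = torch.randperm(8)  # random within-batch reordering

loss_a = torch.nn.functional.cross_entropy(model(inputs), targets)
loss_b = torch.nn.functional.cross_entropy(model(inputs[perm]), targets[perm])

assert torch.allclose(loss_a, loss_b)  # the mean loss is permutation-invariant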
Makes sense - I misinterpreted the sorting inside the collate function in deepspeech_internal. I have removed the sort variable and changed the collate function accordingly.
What are the init parameters for which I should add documentation?
@@ -29,10 +30,12 @@ message TrainConfig {
Nit: messages defined using CamelCase have field names written using snake_case. This convention makes this line SortaGrad sorta_grad.
Changed it!
I actually thought Sortagrad was a single word instead of two, that's why I didn't use an underscore.
src/myrtlespeech/data/sampler.py
Outdated
random.shuffle(indices)
for index in indices:
    yield self.batch_indices[index]
self._n_iterators += 1
This should be moved before the for loop to match the semantics described in the comment. For instance, consider the following:
indices = list(range(64))
srs = SequentialRandomSampler(
    indices=indices,
    batch_size=8,
    shuffle=True,
    sequential={0},
)
iter_1 = iter(srs)
iter_2 = iter(srs)
print(next(iter_1))
print(next(iter_2))

This will output:

[0, 1, 2, 3, 4, 5, 6, 7]
[0, 1, 2, 3, 4, 5, 6, 7]

when it should be [0, 1, 2, 3, 4, 5, 6, 7] followed by a random batch of indices?
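A minimal, self-contained sketch of an __iter__ that bumps the counter before the loop (a simplified stand-in for the real sampler, with the sequential/shuffle semantics assumed from this thread):

import random
from typing import Iterable, List, Optional, Set

class TinySequentialRandomSampler:
    """Simplified stand-in: yields batches of indices, sequentially for epochs
    whose count is in `sequential`, shuffled otherwise."""

    def __init__(self, indices: Iterable[int], batch_size: int,
                 shuffle: bool, sequential: Optional[Set[int]] = None) -> None:
        idx = list(indices)
        self.batch_indices: List[List[int]] = [
            idx[i:i + batch_size] for i in range(0, len(idx), batch_size)
        ]
        self._shuffle = shuffle
        self._sequential = sequential or set()
        self._n_iterators = 0

    def __iter__(self):
        # Read and bump the counter *before* looping so a second iterator
        # created while the first is still live sees the updated count.
        sequential_pass = self._n_iterators in self._sequential
        self._n_iterators += 1
        order = list(range(len(self.batch_indices)))
        if self._shuffle and not sequential_pass:
            random.shuffle(order)
        for index in order:
            yield self.batch_indices[index]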
Done!
samgd
left a comment
👍 Nearly there!
Can the function/method/class init parameters be typed and documentation added?
samgd
left a comment
Adding and checking the Sphinx documentation, plus updating the deepspeech_internal tests file for myrtlespeech, are the final hurdles 👍
src/myrtlespeech/data/sampler.py
Outdated
indices,
batch_size,
shuffle,
drop_last=False,
n_iterators=0,
sequential=None,
The parameters should be typed and then the types in the docstring can be removed.
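A sketch of what a typed signature could look like (parameter names taken from the diff above; the exact type choices are an assumption, not the final code):

from typing import Iterable, Optional, Set

class SequentialRandomSampler:  # signature sketch only; body elided
    def __init__(
        self,
        indices: Iterable[int],
        batch_size: int,
        shuffle: bool,
        drop_last: bool = False,
        n_iterators: int = 0,
        sequential: Optional[Set[int]] = None,
    ) -> None:
        ...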
self._n_iterators = n_iterators
self._sequential = sequential or {}

def _batch_indices(self, indices, batch_size, drop_last):
Add types.
Types?
src/myrtlespeech/data/sampler.py
Outdated
| """ | ||
|
|
||
| def __init__( | ||
| self, indices, batch_size, shuffle, drop_last=False, start_epoch=0 |
Type the arguments and remove types from docstring.
src/myrtlespeech/data/sampler.py
Outdated
other passes. See Deep Speech 2 paper for more information on this:
https://arxiv.org/abs/1512.02595
reStructuredText has syntax for creating a hyperlink, see here:
`Deep Speech 2 <https://arxiv.org/abs/1512.02595>`_
@@ -0,0 +1,99 @@
from myrtlespeech.data.sampler import SequentialRandomSampler
from myrtlespeech.data.sampler import SortaGrad
These tests are copied over from deepspeech_internal, which is fine, but they should be updated to use Hypothesis.
Mainly, dataset_gen, n_batches, batch_size, full_last_batch, etc. can become parameters of each test that are generated by a search strategy. The benefit is that failing tests then search over a wider range and Hypothesis shrinks to the minimal failing test case. Currently the tests use arbitrary, fixed values for each of the above.
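A sketch of the style being suggested (the constructor arguments and the exact assertions are assumptions about the sampler's behaviour, for illustration only):

from hypothesis import given, strategies as st

from myrtlespeech.data.sampler import SequentialRandomSampler


@given(
    dataset_len=st.integers(min_value=1, max_value=256),
    batch_size=st.integers(min_value=1, max_value=32),
    drop_last=st.booleans(),
)
def test_batches_are_well_formed(dataset_len, batch_size, drop_last):
    sampler = SequentialRandomSampler(
        indices=range(dataset_len),
        batch_size=batch_size,
        shuffle=False,
        drop_last=drop_last,
    )
    batches = list(sampler)
    # Every batch fits the batch size and every index comes from the dataset.
    assert all(len(batch) <= batch_size for batch in batches)
    assert all(0 <= idx < dataset_len for batch in batches for idx in batch)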
src/myrtlespeech/data/sampler.py
Outdated
def __init__(self, indices, batch_size, shuffle, drop_last=False):
The iterator used each time this iterable is iterated over will yield
batches either sequentially (i.e. in-order) or randomly (uniform without
replacement) from `batches`.
Just re-reading the docstring via Sphinx and realised this may be out of date: what is batches? Should this now read "yield batches of indices either sequentially ... without replacement"?
Changed
src/myrtlespeech/data/sampler.py
Outdated
sequential iterator is returned if the current count is in `sequential`.

Args:
    indices: data with which batches are created.
Nit: d -> D?
Done
src/myrtlespeech/data/sampler.py
Outdated
drop_last: Optional[bool] = False,
n_iterators: Optional[int] = 0,
Optional[T] means either type T or None is OK. It is equivalent to Union[T, None] - see the docs.
Both drop_last and n_iterators have concrete values - i.e. they are never None - so they can be bool and int respectively.
Done
src/myrtlespeech/data/sampler.py
Outdated
self.batch_indices = self._batch_indices(
    indices, batch_size, drop_last
)
self._n_iterators: Optional[int] = n_iterators
Mypy should infer the type as int here after the update above.
Done
src/myrtlespeech/data/sampler.py
Outdated
    indices, batch_size, drop_last
)
self._n_iterators: Optional[int] = n_iterators
self._sequential: Union[Set, Dict] = sequential or {}
Adding types actually caught a bug here: {} is a dictionary rather than a set and hence mypy was complaining the type should be Union[Set, Dict].
This should be self._sequential = sequential or set() and mypy should infer the type OK.
Changed
src/myrtlespeech/data/sampler.py
Outdated
def __init__(
    self,
    indices: Union[range, List],
indices is actually an Iterable[int]. This is based on the API Python uses for for loops: https://treyhunner.com/2016/12/python-iterator-protocol-how-for-loops-work/
(or maybe it's Sequence[int], to enforce a known maximum size?)
I changed it to Iterable
src/myrtlespeech/data/sampler.py
Outdated
def __init__(
    self,
    indices: Union[range, List],
See above comment.
Done
tests/data/test_sampler.py
Outdated
sequential = set(
    sorted(random.sample(range(max_sequential), n_sequential))
)
max_sequential should actually be a set generated by Hypothesis rather than an integer that is then used internally to generate a set? This way Hypothesis can control n_sequential and shrink both values down to the minimal possible failing test case.
I have created a separate function that returns a SearchStrategy for the set of sequential epoch numbers, so max_sequential and n_sequential are no longer needed in the test parameters. Let me know if this solution is fine or if I need to change it.
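A sketch of what such a helper might look like (the name and bounds here are assumptions, not the actual test code):

from typing import Set

from hypothesis import strategies as st
from hypothesis.strategies import SearchStrategy


def sequential_epochs(max_epoch: int = 10, max_size: int = 5) -> SearchStrategy[Set[int]]:
    """Strategy generating the set of epoch numbers that should be sequential."""
    return st.sets(
        st.integers(min_value=0, max_value=max_epoch),
        max_size=max_size,
    )

It can then be used directly in a @given(...) decorator, e.g. @given(sequential=sequential_epochs()).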
@@ -0,0 +1,12 @@
============
Nit: make these 1 longer than the title:
=========
sampler
=========
Done
.. autoclass:: myrtlespeech.data.sampler.SequentialRandomSampler
    :members:
    :show-inheritance:

.. autoclass:: myrtlespeech.data.sampler.SortaGrad
    :members:
    :show-inheritance:
Can this be auto-generated to reduce chance of forgetting to update it in the future?
.. automodule:: myrtlespeech.data.sampler
    :members:
    :show-inheritance:
I have added this at the beginning of the file
Added the SortaGrad training strategy plus additional tests. The code has mainly been taken from the deepspeech_internal repo and slightly modified to make it work within myrtlespeech.