Clean-up unnecessary warnings (including update to PyTorch 2.0) (#581)
Summary:
This PR is a collection of smaller fixes that will save us some deprecation issues in the future.
## 1. Updating to PyTorch 2.0
**Key files: grad_sample/functorch.py, requirements.txt**
`functorch` has been a part of core PyTorch since 1.13.
Now they're going a step further and changing the API, while deprecating the old one.
There's a [guide](https://pytorch.org/docs/master/func.migrating.html) on how to migrate. TL;DR - `make_functional` will no longer be part of the API, with `torch.func.functional_call()` being a (non-drop-in) replacement.
The key difference for us is that `make_functional()` creates a fresh copy of the module, while `functional_call()` uses the existing module. As a matter of fact, we need the fresh copy (otherwise all the hooks start firing and you enter nested madness), so I've copy-pasted a [gist](https://gist.github.com/zou3519/7769506acc899d83ef1464e28f22e6cf) from the official guide on how to get a full replacement for `make_functional`.
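For reference, the replacement follows this pattern (a sketch in the spirit of the linked gist, not Opacus's exact code; it handles parameters only, not buffers): deep-copy the module and move the copy to the `meta` device, so the returned function never touches the live module or triggers its hooks.

```python
import copy

import torch
import torch.nn as nn


def make_functional(mod: nn.Module):
    """Sketch of a drop-in for the removed functorch ``make_functional``.

    Returns ``(fmodel, params)`` where ``fmodel(params, *args)`` runs the
    module's forward with the given parameters. The deep copy lives on the
    'meta' device, so its own (placeholder) parameters are never used.
    """
    params_dict = dict(mod.named_parameters())
    param_names = tuple(params_dict.keys())
    param_values = tuple(params_dict.values())

    # Fresh, stateless copy: hooks on the original module never fire.
    stateless_mod = copy.deepcopy(mod)
    stateless_mod.to("meta")

    def fmodel(new_param_values, *args, **kwargs):
        new_params = dict(zip(param_names, new_param_values))
        return torch.func.functional_call(stateless_mod, new_params, args, kwargs)

    return fmodel, param_values
```

Usage mirrors the old functorch API: `fmodel, params = make_functional(model)` followed by `fmodel(params, x)`.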
## 2. New mechanism for gradient accumulation detection
**Key file: privacy_engine.py, grad_sample_module.py**
As [reported](https://discuss.pytorch.org/t/gan-raises-userwarning-using-a-non-full-backward-hook-when-the-forward-contains-multiple/175638/2) on the forum, clients still get the "non-full backward hook" warning even when using `grad_sample_mode="ew"`. Naturally, `functorch` and `hooks` modes rely on backward hooks and can't be migrated to full hooks because [reasons](#328 (comment)). However, `ew` doesn't rely on hooks, so it's unclear why the message appears.
The reason, however, is simple. If the client is using Poisson sampling, we add an extra check to prohibit gradient accumulation (two Poisson batches combined is not a Poisson batch), and we do that by means of backward hooks.
~In this case, the backward hook serves a simple purpose and there shouldn't be any problems with migrating to the new method; however, that involved changing the checking method. That's because `register_backward_hook` is called *after* hooks on submodules, but `register_full_backward_hook` is called before.~
The struck-through solution didn't work, because hook execution order is weird for complex graphs, e.g. for GANs. For example, if your forward call looks like this:
```
Discriminator(Generator(x))
```
then the top-level module's hook will precede the submodule hooks for `Generator`, but not for `Discriminator`.
As such, I've realised that gradient accumulation is not even supported in `ExpandedWeights`, so we don't have to worry about that mode. The other two modes are both hooks-based, so we can just check for accumulation in the existing backward hook, with no need for an extra one. Deleted some code, profit.
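The shape of that check looks roughly like this (an illustrative sketch with made-up names, not Opacus's actual internals): inside the existing per-parameter backward hook, refuse to compute a second per-sample gradient before the previous one has been cleared.

```python
import torch
import torch.nn as nn


def check_no_accumulation(module: nn.Module) -> None:
    """Illustrative stand-in for the accumulation check.

    If a parameter still carries a per-sample gradient from a previous
    backward pass, the client is accumulating gradients across Poisson
    batches, which is unsound: the union of two Poisson-sampled batches
    is not itself a Poisson-sampled batch.
    """
    for p in module.parameters():
        if getattr(p, "grad_sample", None) is not None:
            raise ValueError(
                "Poisson sampling is not compatible with gradient "
                "accumulation; step the optimizer or clear per-sample "
                "gradients before the next backward pass."
            )
```

In the real code this logic runs inside the hook that computes per-sample gradients, so no separate hook registration is needed.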
## 3. Refactoring `wrap_collate_with_empty` to please pickle
Now here are two facts I didn't know before:
1) You can't pickle a nested function, e.g. you can't do the following
```python
def foo():
    def bar():
        ...
    return bar

pickle.dump(foo(), ...)
```
2) Whether or not `multiprocessing` uses pickle is Python- and platform-dependent.
This affects our tests when we test `DataLoader` with multiple workers. As such, our data loader tests:
* Pass on CircleCI with python3.9
* Fail on my local machine with python3.9
* Pass on my local machine with python3.7
I'm not sure how common the issue is, but it's safer to just refactor `wrap_collate_with_empty` to avoid nested functions.
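The shape of the refactor (illustrative names, not the actual Opacus code): replace the closure with a module-level callable class, which pickle stores by reference and can therefore serialize.

```python
import pickle


def wrap_collate_nested(collate_fn):
    """Closure-based version: NOT picklable, so DataLoader workers that
    receive their collate function via pickle (e.g. with the 'spawn'
    start method) fail with "Can't pickle local object"."""
    def collate(batch):
        return collate_fn(batch) if batch else []
    return collate


class CollateWrapper:
    """Module-level callable doing the same job; picklable because
    pickle records a reference to the top-level class plus its state."""

    def __init__(self, collate_fn):
        self.collate_fn = collate_fn

    def __call__(self, batch):
        return self.collate_fn(batch) if batch else []
```

Only the second form survives a round trip through `pickle.dumps`/`pickle.loads`, regardless of the `multiprocessing` start method.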
## 4. Fix benchmark tests
We don't really run `benchmarks/tests` on a regular basis, and some of them have been broken since we upgraded to PyTorch 1.13 (`API_CUTOFF_VERSION` doesn't exist anymore).
## 5. Fix flake8 config
The flake8 config no [longer supports](https://flake8.pycqa.org/en/latest/user/configuration.html) inline comments, so a fix is due.
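For illustration (the option names here are just examples, not the actual config), the fix is moving each comment onto its own line:

```ini
[flake8]
# Broken under current flake8 -- inline comment on the option line:
#   max-line-length = 80  # keep in sync with the formatter
# Works -- comment on its own line:
# keep in sync with the formatter
max-line-length = 80
```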
Pull Request resolved: #581
Reviewed By: alexandresablayrolles
Differential Revision: D44749760
Pulled By: ffuuugor
fbshipit-source-id: cf225f4134c049da4ee2eef53e1af3ef54d090bf