Commit 3e38005
Ddp2 fix (#448)
* added training_end
* allow ddp and apex to be configured
* bananas
* added eval and train for redundancy
1 parent 8fbaccd commit 3e38005

File tree

10 files changed: +284 −63 lines

README.md

Lines changed: 1 addition & 0 deletions

@@ -294,6 +294,7 @@ Lightning also adds a text column with all the hyperparameters for this experiment

 #### Distributed training

+- [Implement Your Own Distributed (DDP) training](https://williamfalcon.github.io/pytorch-lightning/Trainer/hooks/#init_ddp_connection)
 - [16-bit mixed precision](https://williamfalcon.github.io/pytorch-lightning/Trainer/Distributed%20training/#16-bit-mixed-precision)
 - [Multi-GPU](https://williamfalcon.github.io/pytorch-lightning/Trainer/Distributed%20training/#Multi-GPU)
 - [Multi-node](https://williamfalcon.github.io/pytorch-lightning/Trainer/Distributed%20training/#Multi-node)

docs/LightningModule/RequiredTrainerInterface.md

Lines changed: 84 additions & 0 deletions

@@ -15,6 +15,7 @@ Otherwise, to Define a Lightning Module, implement the following methods:

 **Optional**:

+- [training_end](RequiredTrainerInterface.md#training_end)
 - [validation_step](RequiredTrainerInterface.md#validation_step)
 - [validation_end](RequiredTrainerInterface.md#validation_end)
 - [test_step](RequiredTrainerInterface.md#test_step)

@@ -178,6 +179,89 @@ def training_step(self, batch, batch_nb, hiddens):

You can also return a -1 instead of a dict to stop the current loop. This is useful if you want to
break out of the current training epoch early.

---
### training_end

``` {.python}
def training_end(self, train_step_outputs)
```

In certain cases (dp, ddp2), you might want to use all outputs of every process to do something.
For instance, if using negative samples, you could run a batch via dp and use ALL the outputs
for a single softmax across the full batch (ie: the denominator would use the full batch).

In this case you should define training_end to perform those calculations.

**Params**

| Param | description |
|---|---|
| outputs | What you return in training_step. |

**Return**

Dictionary or OrderedDict

| key | value | is required |
|---|---|---|
| loss | tensor scalar | Y |
| progress_bar | Dict for progress bar display. Must have only tensors | N |
| log | Dict of metrics to add to logger. Must have only tensors (no images, etc) | N |

**Example**

``` {.python}
# WITHOUT training_end
# if used in DP or DDP2, this batch is 1/nb_gpus large
def training_step(self, batch, batch_nb):
    # batch is 1/nb_gpus big
    x, y = batch

    out = self.forward(x)
    loss = self.softmax(out)
    loss = nce_loss(loss)
    return {'loss': loss}

# --------------
# with training_end to do softmax over the full batch
def training_step(self, batch, batch_nb):
    # batch is 1/nb_gpus big
    x, y = batch

    out = self.forward(x)
    return {'out': out}

def training_end(self, outputs):
    # this out is now the full size of the batch
    out = outputs['out']

    # this softmax now uses the full batch size
    loss = self.softmax(out)
    loss = nce_loss(loss)
    return {'loss': loss}
```

If you define multiple optimizers, this step will also be called with an additional ```optimizer_idx``` param.
``` {.python}
# Multiple optimizers (ie: GANs)
def training_step(self, batch, batch_nb, optimizer_idx):
    if optimizer_idx == 0:
        # do training_step with encoder
    if optimizer_idx == 1:
        # do training_step with decoder
```

If you add truncated back propagation through time you will also get an additional argument with the hidden states of the previous step.
``` {.python}
# Truncated back-propagation through time
def training_step(self, batch, batch_nb, hiddens):
    # hiddens are the hiddens from the previous truncated backprop step
```

You can also return a -1 instead of a dict to stop the current loop. This is useful if you want to
break out of the current training epoch early.
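
A minimal sketch, assuming the same `nce_loss` and `self.softmax` placeholders as the example above, of a training_end that also fills the optional `progress_bar` and `log` keys from the return table:

``` {.python}
def training_end(self, outputs):
    # outputs holds the full-batch results gathered from all GPUs (dp/ddp2)
    out = outputs['out']

    loss = self.softmax(out)
    loss = nce_loss(loss)

    return {
        'loss': loss,                          # required
        'progress_bar': {'train_loss': loss},  # optional: shown in the progress bar
        'log': {'train_loss': loss}            # optional: sent to the logger
    }
```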
---
### train_dataloader

docs/Trainer/hooks.md

Lines changed: 89 additions & 0 deletions

@@ -175,3 +175,92 @@ def tbptt_split_batch(self, batch, split_size):

---
#### configure_apex
Override to define your own Apex initialization.

```python
def configure_apex(self, amp, model, optimizers, amp_level):
    """
    Override to init AMP your own way
    Must return a model and list of optimizers
    :param amp:
    :param model:
    :param optimizers:
    :param amp_level:
    :return: Apex wrapped model and optimizers
    """
    model, optimizers = amp.initialize(
        model, optimizers, opt_level=amp_level,
    )

    return model, optimizers
```
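
A minimal sketch of a user-side override, assuming a hypothetical `CoolModel` module and the `import pytorch_lightning as ptl` style used in the README; it keeps the default call but asks Apex for dynamic loss scaling:

```python
import pytorch_lightning as ptl


class CoolModel(ptl.LightningModule):
    # ... training_step, configure_optimizers, dataloaders omitted ...

    def configure_apex(self, amp, model, optimizers, amp_level):
        # same as the default init, but with dynamic loss scaling
        model, optimizers = amp.initialize(
            model, optimizers,
            opt_level=amp_level,
            loss_scale='dynamic'
        )
        return model, optimizers
```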

---
#### configure_ddp
Override to define your own DDP initialization.
The only requirements are that:
1. On a validation batch, the call goes to model.validation_step.
2. On a training batch, the call goes to model.training_step.
3. On a testing batch, the call goes to model.test_step.

```python
def configure_ddp(self, model, device_ids):
    """
    Override to init DDP in a different way or use your own wrapper.
    Must return model.
    :param model:
    :param device_ids:
    :return: DDP wrapped model
    """
    # Lightning DDP simply routes to test_step, val_step, etc...
    model = LightningDistributedDataParallel(
        model,
        device_ids=device_ids,
        find_unused_parameters=True
    )
    return model
```
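
A hedged sketch of an override, assuming a hypothetical `CoolModel`: it reuses Lightning's `LightningDistributedDataParallel` wrapper (so the step routing above still holds) but disables `find_unused_parameters`, which can be faster when every parameter receives a gradient on each step:

```python
import pytorch_lightning as ptl
from pytorch_lightning.pt_overrides.override_data_parallel import LightningDistributedDataParallel


class CoolModel(ptl.LightningModule):
    # ... other required methods omitted ...

    def configure_ddp(self, model, device_ids):
        # still routes calls to training_step / validation_step / test_step
        model = LightningDistributedDataParallel(
            model,
            device_ids=device_ids,
            find_unused_parameters=False  # assumes every parameter is used each step
        )
        return model
```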

---
#### init_ddp_connection
Override to init DDP in your own way.

```python
def init_ddp_connection(self):
    """
    Connect all procs in the world using the env:// init
    Use the first node as the root address
    """

    # use slurm job id for the port number
    # guarantees unique ports across jobs from same grid search
    try:
        # use the last 4 numbers in the job id as the id
        default_port = os.environ['SLURM_JOB_ID']
        default_port = default_port[-4:]

        # all ports should be in the 10k+ range
        default_port = int(default_port) + 15000

    except Exception as e:
        default_port = 12910

    # if user gave a port number, use that one instead
    try:
        default_port = os.environ['MASTER_PORT']
    except Exception:
        os.environ['MASTER_PORT'] = str(default_port)

    # figure out the root node addr
    try:
        root_node = os.environ['SLURM_NODELIST'].split(' ')[0]
    except Exception:
        root_node = '127.0.0.2'

    root_node = self.trainer.resolve_root_node_address(root_node)
    os.environ['MASTER_ADDR'] = root_node
    dist.init_process_group('nccl', rank=self.proc_rank, world_size=self.world_size)
```
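
A minimal sketch of an override for a cluster without SLURM, assuming the job launcher already exports MASTER_ADDR and MASTER_PORT (a hypothetical setup); the process group init mirrors the default above:

```python
import os

import torch.distributed as dist
import pytorch_lightning as ptl


class CoolModel(ptl.LightningModule):
    # ... other required methods omitted ...

    def init_ddp_connection(self):
        # fall back to sane single-node defaults if the launcher did not set them
        os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
        os.environ.setdefault('MASTER_PORT', '12910')

        # same env:// style init as the default implementation
        dist.init_process_group('nccl', rank=self.proc_rank, world_size=self.world_size)
```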

docs/Trainer/index.md

Lines changed: 1 addition & 0 deletions

@@ -42,6 +42,7 @@ But of course the fun is in all the advanced things it can do:

 **Distributed training**

+- [Implement Your Own Distributed (DDP) training](https://williamfalcon.github.io/pytorch-lightning/Trainer/hooks/#init_ddp_connection)
 - [16-bit mixed precision](https://williamfalcon.github.io/pytorch-lightning/Trainer/Distributed%20training/#16-bit-mixed-precision)
 - [Multi-GPU](https://williamfalcon.github.io/pytorch-lightning/Trainer/Distributed%20training/#Multi-GPU)
 - [Multi-node](https://williamfalcon.github.io/pytorch-lightning/Trainer/Distributed%20training/#Multi-node)

docs/index.md

Lines changed: 1 addition & 0 deletions

@@ -99,6 +99,7 @@ Notice a few things about this flow:

 ###### Distributed training

+- [Implement Your Own Distributed (DDP) training](https://williamfalcon.github.io/pytorch-lightning/Trainer/hooks/#init_ddp_connection)
 - [16-bit mixed precision](https://williamfalcon.github.io/pytorch-lightning/Trainer/Distributed%20training/#16-bit-mixed-precision)
 - [Multi-GPU](https://williamfalcon.github.io/pytorch-lightning/Trainer/Distributed%20training/#Multi-GPU)
 - [Multi-node](https://williamfalcon.github.io/pytorch-lightning/Trainer/Distributed%20training/#Multi-node)

pytorch_lightning/root_module/root_module.py

Lines changed: 79 additions & 1 deletion

@@ -1,8 +1,10 @@
+import os
 import warnings
 import collections
 from argparse import Namespace

 import torch
+import torch.distributed as dist

 from pytorch_lightning.root_module.decorators import data_loader
 from pytorch_lightning.root_module.grads import GradInformation

@@ -11,6 +13,7 @@
 from pytorch_lightning.root_module.model_saving import ModelIO
 from pytorch_lightning.trainer.trainer_io import load_hparams_from_tags_csv
 import logging
+from pytorch_lightning.pt_overrides.override_data_parallel import LightningDistributedDataParallel


 class LightningModule(GradInformation, ModelIO, ModelHooks):

@@ -48,10 +51,19 @@ def training_step(self, *args, **kwargs):
         return loss, dict with metrics for tqdm
         :param called with batch, batch_nb
         additional: optimizer_i if multiple optimizers used
-        :return:
+        :return: dict with loss key and optional log, progress keys
+            if implementing training_step, return whatever you need in that step
         """
         raise NotImplementedError

+    def training_end(self, *args, **kwargs):
+        """
+        return loss, dict with metrics for tqdm
+        :param called with outputs of training_step
+        :return: dict with loss key and optional log, progress keys
+        """
+        pass
+
     def validation_step(self, *args, **kwargs):
         """
         return whatever outputs will need to be aggregated in validation_end

@@ -90,6 +102,72 @@ def test_end(self, outputs):
         """
         pass

+    def configure_ddp(self, model, device_ids):
+        """
+        Override to init DDP in a different way or use your own wrapper.
+        Must return model.
+        :param model:
+        :param device_ids:
+        :return: DDP wrapped model
+        """
+        model = LightningDistributedDataParallel(
+            model,
+            device_ids=device_ids,
+            find_unused_parameters=True
+        )
+        return model
+
+    def init_ddp_connection(self, proc_rank, world_size):
+        """
+        Connect all procs in the world using the env:// init
+        Use the first node as the root address
+        """
+
+        # use slurm job id for the port number
+        # guarantees unique ports across jobs from same grid search
+        try:
+            # use the last 4 numbers in the job id as the id
+            default_port = os.environ['SLURM_JOB_ID']
+            default_port = default_port[-4:]
+
+            # all ports should be in the 10k+ range
+            default_port = int(default_port) + 15000
+
+        except Exception as e:
+            default_port = 12910
+
+        # if user gave a port number, use that one instead
+        try:
+            default_port = os.environ['MASTER_PORT']
+        except Exception:
+            os.environ['MASTER_PORT'] = str(default_port)
+
+        # figure out the root node addr
+        try:
+            root_node = os.environ['SLURM_NODELIST'].split(' ')[0]
+        except Exception:
+            root_node = '127.0.0.2'
+
+        root_node = self.trainer.resolve_root_node_address(root_node)
+        os.environ['MASTER_ADDR'] = root_node
+        dist.init_process_group('nccl', rank=proc_rank, world_size=world_size)
+
+    def configure_apex(self, amp, model, optimizers, amp_level):
+        """
+        Override to init AMP your own way
+        Must return a model and list of optimizers
+        :param amp:
+        :param model:
+        :param optimizers:
+        :param amp_level:
+        :return: Apex wrapped model and optimizers
+        """
+        model, optimizers = amp.initialize(
+            model, optimizers, opt_level=amp_level,
+        )
+
+        return model, optimizers
+
     def configure_optimizers(self):
         """
         Return a list of optimizers and a list of schedulers (could be empty)
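
A hedged usage sketch tying the pieces together, assuming the Trainer flags documented for this era (`gpus`, `nb_gpu_nodes`, `distributed_backend`, `use_amp`, `amp_level`) and a hypothetical `CoolModel` that overrides the hooks sketched above; the trainer picks the overrides up from the LightningModule automatically:

```python
import pytorch_lightning as ptl

# CoolModel is the hypothetical LightningModule sketched in the hooks docs above
model = CoolModel()

trainer = ptl.Trainer(
    gpus=[0, 1, 2, 3, 4, 5, 6, 7],
    nb_gpu_nodes=2,
    distributed_backend='ddp2',  # dp within a node, ddp across nodes
    use_amp=True,                # assumes Apex is installed
    amp_level='O2'
)
trainer.fit(model)
```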

0 commit comments