Skip to content

Commit a24c88a

Browse files
committed
ddp pickle
1 parent 710dbf5 commit a24c88a

File tree

2 files changed

+12
-2
lines changed

2 files changed

+12
-2
lines changed

pytorch_lightning/callbacks/early_stopping.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ def __init__(self, monitor: str = 'val_loss', min_delta: float = 0.0, patience:
5757
self.min_delta = min_delta
5858
self.wait = 0
5959
self.stopped_epoch = 0
60+
self.mode = mode
6061

6162
mode_dict = {
6263
'min': torch.lt,
@@ -67,9 +68,8 @@ def __init__(self, monitor: str = 'val_loss', min_delta: float = 0.0, patience:
6768
if mode not in mode_dict:
6869
if self.verbose > 0:
6970
log.info(f'EarlyStopping mode {mode} is unknown, fallback to auto mode.')
70-
mode = 'auto'
71+
self.mode = 'auto'
7172

72-
self.monitor_op = mode_dict[mode]
7373
self.min_delta *= 1 if self.monitor_op == torch.gt else -1
7474

7575
def _validate_condition_metric(self, logs):
@@ -94,6 +94,15 @@ def _validate_condition_metric(self, logs):
9494

9595
return True
9696

97+
@property
98+
def monitor_op(self):
99+
mode_dict = {
100+
'min': torch.lt,
101+
'max': torch.gt,
102+
'auto': torch.gt if 'acc' in self.monitor else torch.lt
103+
}
104+
return mode_dict[self.mode]
105+
97106
def on_train_start(self, trainer, pl_module):
98107
# Allow instances to be re-used
99108
self.wait = 0

pytorch_lightning/trainer/distrib_data_parallel.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,7 @@ def save_spawn_weights(self, model):
378378
:param model:
379379
:return:
380380
"""
381+
import pdb; pdb.set_trace()
381382
if self.proc_rank == 0:
382383
path = os.path.join(self.default_root_dir, '__temp_weight_ddp_end.ckpt')
383384
self.save_checkpoint(path)

0 commit comments

Comments
 (0)