Skip to content

Commit 662b6d6

Browse files
awaelchlilexierule
authored andcommitted
Fix(Early Stopping): move best score to device (#7959)
1 parent 591d617 commit 662b6d6

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

pytorch_lightning/callbacks/early_stopping.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ def _run_early_stopping_check(self, trainer) -> None:
190190
# when in dev debugging
191191
trainer.dev_debugger.track_early_stopping_history(self, current)
192192

193-
should_stop, reason = self._evalute_stopping_criteria(current)
193+
should_stop, reason = self._evalute_stopping_criteria(current, trainer)
194194

195195
# stop every ddp process if any world process decides to stop
196196
should_stop = trainer.training_type_plugin.reduce_boolean_decision(should_stop)
@@ -200,7 +200,7 @@ def _run_early_stopping_check(self, trainer) -> None:
200200
if reason and self.verbose:
201201
self._log_info(trainer, reason)
202202

203-
def _evalute_stopping_criteria(self, current: torch.Tensor) -> Tuple[bool, str]:
203+
def _evalute_stopping_criteria(self, current: torch.Tensor, trainer: 'pl.Trainer') -> Tuple[bool, str]:
204204
should_stop = False
205205
reason = None
206206
if self.check_finite and not torch.isfinite(current):
@@ -223,7 +223,7 @@ def _evalute_stopping_criteria(self, current: torch.Tensor) -> Tuple[bool, str]:
223223
f" {self.monitor} = {current} {self.order_dict[self.mode]} {self.divergence_threshold}."
224224
" Signaling Trainer to stop."
225225
)
226-
elif self.monitor_op(current - self.min_delta, self.best_score):
226+
elif self.monitor_op(current - self.min_delta, self.best_score.to(trainer.lightning_module.device)):
227227
should_stop = False
228228
reason = self._improvement_message(current)
229229
self.best_score = current

0 commit comments

Comments
 (0)