@@ -190,7 +190,7 @@ def _run_early_stopping_check(self, trainer) -> None:
190190 # when in dev debugging
191191 trainer .dev_debugger .track_early_stopping_history (self , current )
192192
193- should_stop , reason = self ._evalute_stopping_criteria (current )
193+ should_stop , reason = self ._evalute_stopping_criteria (current , trainer )
194194
195195 # stop every ddp process if any world process decides to stop
196196 should_stop = trainer .training_type_plugin .reduce_boolean_decision (should_stop )
@@ -200,7 +200,7 @@ def _run_early_stopping_check(self, trainer) -> None:
200200 if reason and self .verbose :
201201 self ._log_info (trainer , reason )
202202
203- def _evalute_stopping_criteria (self , current : torch .Tensor ) -> Tuple [bool , str ]:
203+ def _evalute_stopping_criteria (self , current : torch .Tensor , trainer : 'pl.Trainer' ) -> Tuple [bool , str ]:
204204 should_stop = False
205205 reason = None
206206 if self .check_finite and not torch .isfinite (current ):
@@ -223,7 +223,7 @@ def _evalute_stopping_criteria(self, current: torch.Tensor) -> Tuple[bool, str]:
223223 f" { self .monitor } = { current } { self .order_dict [self .mode ]} { self .divergence_threshold } ."
224224 " Signaling Trainer to stop."
225225 )
226- elif self .monitor_op (current - self .min_delta , self .best_score ):
226+ elif self .monitor_op (current - self .min_delta , self .best_score . to ( trainer . lightning_module . device ) ):
227227 should_stop = False
228228 reason = self ._improvement_message (current )
229229 self .best_score = current
0 commit comments