Skip to content

Commit fae7d19

Browse files
committed
fix unittests
1 parent 8ca5695 commit fae7d19

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

src/lightning/pytorch/callbacks/early_stopping.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ def state_dict(self) -> dict[str, Any]:
191191
"stopped_epoch": self.stopped_epoch,
192192
"best_score": self.best_score,
193193
"patience": self.patience,
194-
"stopping_reason": self.stopping_reason,
194+
"stopping_reason": self.stopping_reason.value,
195195
"stopping_reason_message": self.stopping_reason_message,
196196
}
197197

@@ -201,8 +201,8 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None:
201201
self.stopped_epoch = state_dict["stopped_epoch"]
202202
self.best_score = state_dict["best_score"]
203203
self.patience = state_dict["patience"]
204-
# For backward compatibility, set defaults if not present
205-
self.stopping_reason = state_dict.get("stopping_reason", EarlyStoppingReason.NOT_STOPPED)
204+
stopping_reason_value = state_dict.get("stopping_reason", EarlyStoppingReason.NOT_STOPPED.value)
205+
self.stopping_reason = EarlyStoppingReason(stopping_reason_value)
206206
self.stopping_reason_message = state_dict.get("stopping_reason_message")
207207

208208
def _should_skip_check(self, trainer: "pl.Trainer") -> bool:

0 commit comments

Comments
 (0)