11/* *
22 * @file early_stop_at_min_loss.hpp
33 * @author Marcus Edel
4+ * @author Omar Shrit
45 *
56 * Implementation of the early stop at minimum loss callback function.
67 *
1213#ifndef ENSMALLEN_CALLBACKS_EARLY_STOP_AT_MIN_LOSS_HPP
1314#define ENSMALLEN_CALLBACKS_EARLY_STOP_AT_MIN_LOSS_HPP
1415
16+ #include < functional>
17+
1518namespace ens {
1619
1720/* *
1821 * Early stopping to terminate the optimization process early if the loss stops
1922 * decreasing.
2023 */
21- class EarlyStopAtMinLoss
24+ template <typename MatType = arma::mat>
25+ class EarlyStopAtMinLossType
2226{
2327 public:
2428 /* *
@@ -28,12 +32,33 @@ class EarlyStopAtMinLoss
2832 * @param patienceIn The number of epochs to wait after the minimum loss has
2933 * been reached or no improvement has been made (Default: 10).
3034 */
31- EarlyStopAtMinLoss (const size_t patienceIn = 10 ) :
32- patience (patienceIn),
35+ EarlyStopAtMinLossType<MatType>(const size_t patienceIn = 10 ) :
36+ callbackUsed (false ),
37+ patience (patienceIn),
3338 bestObjective (std::numeric_limits<double >::max()),
3439 steps (0 )
3540 { /* Nothing to do here */ }
3641
42+ /* *
43+ * Set up the early stop at min loss class, which keeps track of the minimum
44+ * loss and stops the optimization process if the loss stops decreasing.
45+ *
46+ * @param func, callback to return immediate loss evaluated by the function
47+ * @param patienceIn The number of epochs to wait after the minimum loss has
48+ * been reached or no improvement has been made (Default: 10).
49+ */
50+ EarlyStopAtMinLossType<MatType>(
51+ std::function<double (const MatType&)> func,
52+ const size_t patienceIn = 10 )
53+ : callbackUsed(true ),
54+ patience (patienceIn),
55+ bestObjective(std::numeric_limits<double >::max()),
56+ steps(0 ),
57+ localFunc(func)
58+ {
59+ // Nothing to do here
60+ }
61+
3762 /* *
3863 * Callback function called at the end of a pass over the data.
3964 *
@@ -43,13 +68,18 @@ class EarlyStopAtMinLoss
4368 * @param epoch The index of the current epoch.
4469 * @param objective Objective value of the current point.
4570 */
46- template <typename OptimizerType, typename FunctionType, typename MatType >
71+ template <typename OptimizerType, typename FunctionType>
4772 bool EndEpoch (OptimizerType& /* optimizer */ ,
4873 FunctionType& /* function */ ,
49- const MatType& /* coordinates */ ,
74+ const MatType& coordinates,
5075 const size_t /* epoch */ ,
51- const double objective)
76+ double objective)
5277 {
78+ if (callbackUsed)
79+ {
80+ objective = localFunc (coordinates);
81+ }
82+
5383 if (objective < bestObjective)
5484 {
5585 steps = 0 ;
@@ -68,6 +98,9 @@ class EarlyStopAtMinLoss
6898 }
6999
70100 private:
101+ // ! False if the first constructor is called, true if the user passed a lambda.
102+ bool callbackUsed;
103+
71104 // ! The number of epochs to wait before terminating the optimization process.
72105 size_t patience;
73106
@@ -76,8 +109,19 @@ class EarlyStopAtMinLoss
76109
77110 // ! Locally-stored number of steps since the loss improved.
78111 size_t steps;
112+
113+ // ! Function to call at the end of the epoch.
114+ std::function<double (const MatType&)> localFunc;
79115};
80116
117+ /*
118+ * Note that the using definition is temporary, this definition should
119+ * be removed when releasing ensmallen 3.0
120+ * The renaming of the class is only to avoid a major version bump
121+ * because if the template type added to this class
122+ */
123+ using EarlyStopAtMinLoss = EarlyStopAtMinLossType<arma::mat>;
124+
81125} // namespace ens
82126
83127#endif
0 commit comments