@@ -67,9 +67,7 @@ class Callback
6767 typename MatType>
6868 static typename std::enable_if<
6969 callbacks::traits::HasBeginOptimizationSignature<
70- // Check for boolean return values anyway, for older ensmallen callbacks.
71- // (The return value is ignored.)
72- CallbackType, OptimizerType, FunctionType, MatType>::hasBool,
70+ CallbackType, OptimizerType, FunctionType, MatType>::value,
7371 void >::type
7472 BeginOptimizationFunction (CallbackType& callback,
7573 OptimizerType& optimizer,
@@ -85,25 +83,8 @@ class Callback
8583 typename FunctionType,
8684 typename MatType>
8785 static typename std::enable_if<
88- callbacks::traits::HasBeginOptimizationSignature<
89- CallbackType, OptimizerType, FunctionType, MatType>::hasVoid,
90- void >::type
91- BeginOptimizationFunction (CallbackType& callback,
92- OptimizerType& optimizer,
93- FunctionType& function,
94- MatType& coordinates)
95- {
96- const_cast <CallbackType&>(callback).BeginOptimization (optimizer, function,
97- coordinates);
98- }
99-
100- template <typename CallbackType,
101- typename OptimizerType,
102- typename FunctionType,
103- typename MatType>
104- static typename std::enable_if<
105- callbacks::traits::HasBeginOptimizationSignature<
106- CallbackType, OptimizerType, FunctionType, MatType>::hasNone,
86+ !callbacks::traits::HasBeginOptimizationSignature<
87+ CallbackType, OptimizerType, FunctionType, MatType>::value,
10788 void >::type
10889 BeginOptimizationFunction (CallbackType& /* callback */ ,
10990 OptimizerType& /* optimizer */ ,
@@ -175,7 +156,8 @@ class Callback
175156 typename OptimizerType,
176157 typename FunctionType,
177158 typename MatType>
178- static typename std::enable_if<!callbacks::traits::HasEndOptimizationSignature<
159+ static typename std::enable_if<
160+ !callbacks::traits::HasEndOptimizationSignature<
179161 CallbackType, OptimizerType, FunctionType, MatType>::value,
180162 void >::type
181163 EndOptimizationFunction (CallbackType& /* callback */ ,
@@ -234,24 +216,42 @@ class Callback
234216 typename FunctionType,
235217 typename MatType>
236218 static typename std::enable_if<callbacks::traits::HasEvaluateSignature<
237- CallbackType, OptimizerType, FunctionType, MatType>::value ,
219+ CallbackType, OptimizerType, FunctionType, MatType>::hasBool ,
238220 bool >::type
239221 EvaluateFunction (CallbackType& callback,
240222 OptimizerType& optimizer,
241223 FunctionType& function,
242224 const MatType& coordinates,
243225 const double objective)
244226 {
245- return ( const_cast <CallbackType&>(callback).Evaluate (
246- optimizer, function, coordinates, objective), false );
227+ return const_cast <CallbackType&>(callback).Evaluate (optimizer, function,
228+ coordinates, objective);
247229 }
248230
249231 template <typename CallbackType,
250232 typename OptimizerType,
251233 typename FunctionType,
252234 typename MatType>
253- static typename std::enable_if<!callbacks::traits::HasEvaluateSignature<
254- CallbackType, OptimizerType, FunctionType, MatType>::value,
235+ static typename std::enable_if<callbacks::traits::HasEvaluateSignature<
236+ CallbackType, OptimizerType, FunctionType, MatType>::hasVoid,
237+ bool >::type
238+ EvaluateFunction (CallbackType& callback,
239+ OptimizerType& optimizer,
240+ FunctionType& function,
241+ const MatType& coordinates,
242+ const double objective)
243+ {
244+ const_cast <CallbackType&>(callback).Evaluate (optimizer, function,
245+ coordinates, objective);
246+ return false ;
247+ }
248+
249+ template <typename CallbackType,
250+ typename OptimizerType,
251+ typename FunctionType,
252+ typename MatType>
253+ static typename std::enable_if<callbacks::traits::HasEvaluateSignature<
254+ CallbackType, OptimizerType, FunctionType, MatType>::hasNone,
255255 bool >::type
256256 EvaluateFunction (CallbackType& /* callback */ ,
257257 OptimizerType& /* optimizer */ ,
@@ -304,7 +304,7 @@ class Callback
304304 typename MatType>
305305 static typename std::enable_if<
306306 callbacks::traits::HasEvaluateConstraintSignature<
307- CallbackType, OptimizerType, FunctionType, MatType>::value ,
307+ CallbackType, OptimizerType, FunctionType, MatType>::hasBool ,
308308 bool >::type
309309 EvaluateConstraintFunction (CallbackType& callback,
310310 OptimizerType& optimizer,
@@ -313,17 +313,37 @@ class Callback
313313 const size_t constraint,
314314 const double constraintValue)
315315 {
316- return ( const_cast <CallbackType&>(callback).EvaluateConstraint (
317- optimizer, function, coordinates, constraint, constraintValue), false ) ;
316+ return const_cast <CallbackType&>(callback).EvaluateConstraint (
317+ optimizer, function, coordinates, constraint, constraintValue);
318318 }
319319
320320 template <typename CallbackType,
321321 typename OptimizerType,
322322 typename FunctionType,
323323 typename MatType>
324324 static typename std::enable_if<
325- !callbacks::traits::HasEvaluateConstraintSignature<
326- CallbackType, OptimizerType, FunctionType, MatType>::value,
325+ callbacks::traits::HasEvaluateConstraintSignature<
326+ CallbackType, OptimizerType, FunctionType, MatType>::hasVoid,
327+ bool >::type
328+ EvaluateConstraintFunction (CallbackType& callback,
329+ OptimizerType& optimizer,
330+ FunctionType& function,
331+ const MatType& coordinates,
332+ const size_t constraint,
333+ const double constraintValue)
334+ {
335+ const_cast <CallbackType&>(callback).EvaluateConstraint (
336+ optimizer, function, coordinates, constraint, constraintValue);
337+ return false ;
338+ }
339+
340+ template <typename CallbackType,
341+ typename OptimizerType,
342+ typename FunctionType,
343+ typename MatType>
344+ static typename std::enable_if<
345+ callbacks::traits::HasEvaluateConstraintSignature<
346+ CallbackType, OptimizerType, FunctionType, MatType>::hasNone,
327347 bool >::type
328348 EvaluateConstraintFunction (CallbackType& /* callback */ ,
329349 OptimizerType& /* optimizer */ ,
@@ -380,25 +400,44 @@ class Callback
380400 typename MatType,
381401 typename GradType>
382402 static typename std::enable_if<callbacks::traits::HasGradientSignature<
383- CallbackType, OptimizerType, FunctionType, MatType, GradType>::value ,
403+ CallbackType, OptimizerType, FunctionType, MatType, GradType>::hasBool ,
384404 bool >::type
385405 GradientFunction (CallbackType& callback,
386406 OptimizerType& optimizer,
387407 FunctionType& function,
388408 const MatType& coordinates,
389409 GradType& gradient)
390410 {
391- return ( const_cast <CallbackType&>(callback).Gradient (
392- optimizer, function, coordinates, gradient), false );
411+ return const_cast <CallbackType&>(callback).Gradient (optimizer, function,
412+ coordinates, gradient);
393413 }
394414
395415 template <typename CallbackType,
396416 typename OptimizerType,
397417 typename FunctionType,
398418 typename MatType,
399419 typename GradType>
400- static typename std::enable_if<!callbacks::traits::HasGradientSignature<
401- CallbackType, OptimizerType, FunctionType, MatType, GradType>::value,
420+ static typename std::enable_if<callbacks::traits::HasGradientSignature<
421+ CallbackType, OptimizerType, FunctionType, MatType, GradType>::hasVoid,
422+ bool >::type
423+ GradientFunction (CallbackType& callback,
424+ OptimizerType& optimizer,
425+ FunctionType& function,
426+ const MatType& coordinates,
427+ GradType& gradient)
428+ {
429+ const_cast <CallbackType&>(callback).Gradient (
430+ optimizer, function, coordinates, gradient);
431+ return false ;
432+ }
433+
434+ template <typename CallbackType,
435+ typename OptimizerType,
436+ typename FunctionType,
437+ typename MatType,
438+ typename GradType>
439+ static typename std::enable_if<callbacks::traits::HasGradientSignature<
440+ CallbackType, OptimizerType, FunctionType, MatType, GradType>::hasNone,
402441 bool >::type
403442 GradientFunction (CallbackType& /* callback */ ,
404443 OptimizerType& /* optimizer */ ,
@@ -451,7 +490,27 @@ class Callback
451490 typename GradType>
452491 static typename std::enable_if<
453492 callbacks::traits::HasGradientConstraintSignature<
454- CallbackType, OptimizerType, FunctionType, MatType, GradType>::value,
493+ CallbackType, OptimizerType, FunctionType, MatType, GradType>::hasBool,
494+ bool >::type
495+ GradientConstraintFunction (CallbackType& callback,
496+ OptimizerType& optimizer,
497+ FunctionType& function,
498+ const MatType& coordinates,
499+ const size_t constraint,
500+ GradType& gradient)
501+ {
502+ return const_cast <CallbackType&>(callback).GradientConstraint (optimizer,
503+ function, coordinates, constraint, gradient);
504+ }
505+
506+ template <typename CallbackType,
507+ typename OptimizerType,
508+ typename FunctionType,
509+ typename MatType,
510+ typename GradType>
511+ static typename std::enable_if<
512+ callbacks::traits::HasGradientConstraintSignature<
513+ CallbackType, OptimizerType, FunctionType, MatType, GradType>::hasVoid,
455514 bool >::type
456515 GradientConstraintFunction (CallbackType& callback,
457516 OptimizerType& optimizer,
@@ -460,8 +519,9 @@ class Callback
460519 const size_t constraint,
461520 GradType& gradient)
462521 {
463- return (const_cast <CallbackType&>(callback).GradientConstraint (
464- optimizer, function, coordinates, constraint, gradient), false );
522+ const_cast <CallbackType&>(callback).GradientConstraint (
523+ optimizer, function, coordinates, constraint, gradient);
524+ return false ;
465525 }
466526
467527 template <typename CallbackType,
@@ -470,8 +530,8 @@ class Callback
470530 typename MatType,
471531 typename GradType>
472532 static typename std::enable_if<
473- ! callbacks::traits::HasGradientConstraintSignature<
474- CallbackType, OptimizerType, FunctionType, MatType, GradType>::value ,
533+ callbacks::traits::HasGradientConstraintSignature<
534+ CallbackType, OptimizerType, FunctionType, MatType, GradType>::hasNone ,
475535 bool >::type
476536 GradientConstraintFunction (CallbackType& /* callback */ ,
477537 OptimizerType& /* optimizer */ ,
@@ -563,24 +623,42 @@ class Callback
563623 typename FunctionType,
564624 typename MatType>
565625 static typename std::enable_if<callbacks::traits::HasBeginEpochSignature<
566- CallbackType, OptimizerType, FunctionType, MatType>::value, bool >::type
626+ CallbackType, OptimizerType, FunctionType, MatType>::hasBool, bool >::type
627+ BeginEpochFunction (CallbackType& callback,
628+ OptimizerType& optimizer,
629+ FunctionType& function,
630+ const MatType& coordinates,
631+ const size_t epoch,
632+ const double objective)
633+ {
634+ return const_cast <CallbackType&>(callback).BeginEpoch (
635+ optimizer, function, coordinates, epoch, objective);
636+ }
637+
638+ template <typename CallbackType,
639+ typename OptimizerType,
640+ typename FunctionType,
641+ typename MatType>
642+ static typename std::enable_if<callbacks::traits::HasBeginEpochSignature<
643+ CallbackType, OptimizerType, FunctionType, MatType>::hasVoid, bool >::type
567644 BeginEpochFunction (CallbackType& callback,
568645 OptimizerType& optimizer,
569646 FunctionType& function,
570647 const MatType& coordinates,
571648 const size_t epoch,
572649 const double objective)
573650 {
574- return (const_cast <CallbackType&>(callback).BeginEpoch (
575- optimizer, function, coordinates, epoch, objective), false );
651+ const_cast <CallbackType&>(callback).BeginEpoch (
652+ optimizer, function, coordinates, epoch, objective);
653+ return false ;
576654 }
577655
578656 template <typename CallbackType,
579657 typename OptimizerType,
580658 typename FunctionType,
581659 typename MatType>
582- static typename std::enable_if<! callbacks::traits::HasBeginEpochSignature<
583- CallbackType, OptimizerType, FunctionType, MatType>::value , bool >::type
660+ static typename std::enable_if<callbacks::traits::HasBeginEpochSignature<
661+ CallbackType, OptimizerType, FunctionType, MatType>::hasNone , bool >::type
584662 BeginEpochFunction (CallbackType& /* callback */ ,
585663 OptimizerType& /* optimizer */ ,
586664 FunctionType& /* function */ ,
@@ -768,6 +846,32 @@ class Callback
768846 MatType& /* coordinates */ )
769847 { return false ; }
770848
849+ /* *
850+ * Iterate over the callbacks and invoke the StepTaken() callback if it
851+ * exists.
852+ *
853+ * @param optimizer The optimizer used to update the function.
854+ * @param function Function to optimize.
855+ * @param coordinates Starting point.
856+ * @param callbacks The callbacks container.
857+ */
858+ template <typename OptimizerType,
859+ typename FunctionType,
860+ typename MatType,
861+ typename ... CallbackTypes>
862+ static bool StepTaken (OptimizerType& optimizer,
863+ FunctionType& function,
864+ MatType& coordinates,
865+ CallbackTypes&... callbacks)
866+ {
867+ // This will return immediately once a callback returns true.
868+ bool result = false ;
869+ (void )std::initializer_list<bool >{ result =
870+ result || Callback::StepTakenFunction (callbacks, optimizer,
871+ function, coordinates)... };
872+ return result;
873+ }
874+
771875 /* *
772876 * Invoke the GenerationalStepTaken() callback if it exists.
773877 * Specialization for MultiObjective case.
@@ -819,7 +923,6 @@ class Callback
819923 {
820924 const_cast <CallbackType&>(callback).GenerationalStepTaken (
821925 optimizer, function, coordinates, objectives, frontIndices);
822-
823926 return false ;
824927 }
825928
@@ -841,32 +944,6 @@ class Callback
841944 IndicesType& /* frontIndices */ )
842945 { return false ; }
843946
844- /* *
845- * Iterate over the callbacks and invoke the StepTaken() callback if it
846- * exists.
847- *
848- * @param optimizer The optimizer used to update the function.
849- * @param function Function to optimize.
850- * @param coordinates Starting point.
851- * @param callbacks The callbacks container.
852- */
853- template <typename OptimizerType,
854- typename FunctionType,
855- typename MatType,
856- typename ... CallbackTypes>
857- static bool StepTaken (OptimizerType& optimizer,
858- FunctionType& function,
859- MatType& coordinates,
860- CallbackTypes&... callbacks)
861- {
862- // This will return immediately once a callback returns true.
863- bool result = false ;
864- (void )std::initializer_list<bool >{ result =
865- result || Callback::StepTakenFunction (callbacks, optimizer,
866- function, coordinates)... };
867- return result;
868- }
869-
870947 /* *
871948 * Iterate over the callbacks and invoke the GenerationalStepTaken() callback if it
872949 * exists.
0 commit comments