Skip to content

Commit 7ade3e9

Browse files
authored
Merge pull request #384 from rcurtin/callbacks-use-results
Fix some callbacks that ignored return values
2 parents cbc90d8 + d361b86 commit 7ade3e9

File tree

12 files changed

+360
-203
lines changed

12 files changed

+360
-203
lines changed

.appveyor.yml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,6 @@ environment:
44
BLAS_LIBRARY_DLL: "%APPVEYOR_BUILD_FOLDER%/OpenBLAS.0.2.14.1/lib/native/lib/x64/libopenblas.dll"
55

66
matrix:
7-
- APPVEYOR_BUILD_WORKER_IMAGE: Visual Studio 2015
8-
VSVER: Visual Studio 14 2015 Win64
9-
MSBUILD: C:\Program Files (x86)\MSBuild\14.0\bin\MSBuild.exe
10-
117
- APPVEYOR_BUILD_WORKER_IMAGE: Visual Studio 2017
128
VSVER: Visual Studio 15 2017 Win64
139
MSBUILD: C:\Program Files (x86)\Microsoft Visual Studio\2017\Community\MSBuild\15.0\Bin\MSBuild.exe
@@ -16,6 +12,10 @@ environment:
1612
VSVER: Visual Studio 16 2019
1713
MSBUILD: C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\MSBuild\Current\Bin\MSBuild.exe
1814

15+
- APPVEYOR_BUILD_WORKER_IMAGE: Visual Studio 2022
16+
VSVER: Visual Studio 17 2022
17+
MSBUILD: C:\Program Files\Microsoft Visual Studio\2022\Community\MSBuild\Current\Bin\MSBuild.exe
18+
1919
configuration: Release
2020

2121
install:

include/ensmallen_bits/callbacks/callbacks.hpp

Lines changed: 152 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -67,9 +67,7 @@ class Callback
6767
typename MatType>
6868
static typename std::enable_if<
6969
callbacks::traits::HasBeginOptimizationSignature<
70-
// Check for boolean return values anyway, for older ensmallen callbacks.
71-
// (The return value is ignored.)
72-
CallbackType, OptimizerType, FunctionType, MatType>::hasBool,
70+
CallbackType, OptimizerType, FunctionType, MatType>::value,
7371
void>::type
7472
BeginOptimizationFunction(CallbackType& callback,
7573
OptimizerType& optimizer,
@@ -85,25 +83,8 @@ class Callback
8583
typename FunctionType,
8684
typename MatType>
8785
static typename std::enable_if<
88-
callbacks::traits::HasBeginOptimizationSignature<
89-
CallbackType, OptimizerType, FunctionType, MatType>::hasVoid,
90-
void>::type
91-
BeginOptimizationFunction(CallbackType& callback,
92-
OptimizerType& optimizer,
93-
FunctionType& function,
94-
MatType& coordinates)
95-
{
96-
const_cast<CallbackType&>(callback).BeginOptimization(optimizer, function,
97-
coordinates);
98-
}
99-
100-
template<typename CallbackType,
101-
typename OptimizerType,
102-
typename FunctionType,
103-
typename MatType>
104-
static typename std::enable_if<
105-
callbacks::traits::HasBeginOptimizationSignature<
106-
CallbackType, OptimizerType, FunctionType, MatType>::hasNone,
86+
!callbacks::traits::HasBeginOptimizationSignature<
87+
CallbackType, OptimizerType, FunctionType, MatType>::value,
10788
void>::type
10889
BeginOptimizationFunction(CallbackType& /* callback */,
10990
OptimizerType& /* optimizer */,
@@ -175,7 +156,8 @@ class Callback
175156
typename OptimizerType,
176157
typename FunctionType,
177158
typename MatType>
178-
static typename std::enable_if<!callbacks::traits::HasEndOptimizationSignature<
159+
static typename std::enable_if<
160+
!callbacks::traits::HasEndOptimizationSignature<
179161
CallbackType, OptimizerType, FunctionType, MatType>::value,
180162
void>::type
181163
EndOptimizationFunction(CallbackType& /* callback */,
@@ -234,24 +216,42 @@ class Callback
234216
typename FunctionType,
235217
typename MatType>
236218
static typename std::enable_if<callbacks::traits::HasEvaluateSignature<
237-
CallbackType, OptimizerType, FunctionType, MatType>::value,
219+
CallbackType, OptimizerType, FunctionType, MatType>::hasBool,
238220
bool>::type
239221
EvaluateFunction(CallbackType& callback,
240222
OptimizerType& optimizer,
241223
FunctionType& function,
242224
const MatType& coordinates,
243225
const double objective)
244226
{
245-
return (const_cast<CallbackType&>(callback).Evaluate(
246-
optimizer, function, coordinates, objective), false);
227+
return const_cast<CallbackType&>(callback).Evaluate(optimizer, function,
228+
coordinates, objective);
247229
}
248230

249231
template<typename CallbackType,
250232
typename OptimizerType,
251233
typename FunctionType,
252234
typename MatType>
253-
static typename std::enable_if<!callbacks::traits::HasEvaluateSignature<
254-
CallbackType, OptimizerType, FunctionType, MatType>::value,
235+
static typename std::enable_if<callbacks::traits::HasEvaluateSignature<
236+
CallbackType, OptimizerType, FunctionType, MatType>::hasVoid,
237+
bool>::type
238+
EvaluateFunction(CallbackType& callback,
239+
OptimizerType& optimizer,
240+
FunctionType& function,
241+
const MatType& coordinates,
242+
const double objective)
243+
{
244+
const_cast<CallbackType&>(callback).Evaluate(optimizer, function,
245+
coordinates, objective);
246+
return false;
247+
}
248+
249+
template<typename CallbackType,
250+
typename OptimizerType,
251+
typename FunctionType,
252+
typename MatType>
253+
static typename std::enable_if<callbacks::traits::HasEvaluateSignature<
254+
CallbackType, OptimizerType, FunctionType, MatType>::hasNone,
255255
bool>::type
256256
EvaluateFunction(CallbackType& /* callback */,
257257
OptimizerType& /* optimizer */,
@@ -304,7 +304,7 @@ class Callback
304304
typename MatType>
305305
static typename std::enable_if<
306306
callbacks::traits::HasEvaluateConstraintSignature<
307-
CallbackType, OptimizerType, FunctionType, MatType>::value,
307+
CallbackType, OptimizerType, FunctionType, MatType>::hasBool,
308308
bool>::type
309309
EvaluateConstraintFunction(CallbackType& callback,
310310
OptimizerType& optimizer,
@@ -313,17 +313,37 @@ class Callback
313313
const size_t constraint,
314314
const double constraintValue)
315315
{
316-
return (const_cast<CallbackType&>(callback).EvaluateConstraint(
317-
optimizer, function, coordinates, constraint, constraintValue), false);
316+
return const_cast<CallbackType&>(callback).EvaluateConstraint(
317+
optimizer, function, coordinates, constraint, constraintValue);
318318
}
319319

320320
template<typename CallbackType,
321321
typename OptimizerType,
322322
typename FunctionType,
323323
typename MatType>
324324
static typename std::enable_if<
325-
!callbacks::traits::HasEvaluateConstraintSignature<
326-
CallbackType, OptimizerType, FunctionType, MatType>::value,
325+
callbacks::traits::HasEvaluateConstraintSignature<
326+
CallbackType, OptimizerType, FunctionType, MatType>::hasVoid,
327+
bool>::type
328+
EvaluateConstraintFunction(CallbackType& callback,
329+
OptimizerType& optimizer,
330+
FunctionType& function,
331+
const MatType& coordinates,
332+
const size_t constraint,
333+
const double constraintValue)
334+
{
335+
const_cast<CallbackType&>(callback).EvaluateConstraint(
336+
optimizer, function, coordinates, constraint, constraintValue);
337+
return false;
338+
}
339+
340+
template<typename CallbackType,
341+
typename OptimizerType,
342+
typename FunctionType,
343+
typename MatType>
344+
static typename std::enable_if<
345+
callbacks::traits::HasEvaluateConstraintSignature<
346+
CallbackType, OptimizerType, FunctionType, MatType>::hasNone,
327347
bool>::type
328348
EvaluateConstraintFunction(CallbackType& /* callback */,
329349
OptimizerType& /* optimizer */,
@@ -380,25 +400,44 @@ class Callback
380400
typename MatType,
381401
typename GradType>
382402
static typename std::enable_if<callbacks::traits::HasGradientSignature<
383-
CallbackType, OptimizerType, FunctionType, MatType, GradType>::value,
403+
CallbackType, OptimizerType, FunctionType, MatType, GradType>::hasBool,
384404
bool>::type
385405
GradientFunction(CallbackType& callback,
386406
OptimizerType& optimizer,
387407
FunctionType& function,
388408
const MatType& coordinates,
389409
GradType& gradient)
390410
{
391-
return (const_cast<CallbackType&>(callback).Gradient(
392-
optimizer, function, coordinates, gradient), false);
411+
return const_cast<CallbackType&>(callback).Gradient(optimizer, function,
412+
coordinates, gradient);
393413
}
394414

395415
template<typename CallbackType,
396416
typename OptimizerType,
397417
typename FunctionType,
398418
typename MatType,
399419
typename GradType>
400-
static typename std::enable_if<!callbacks::traits::HasGradientSignature<
401-
CallbackType, OptimizerType, FunctionType, MatType, GradType>::value,
420+
static typename std::enable_if<callbacks::traits::HasGradientSignature<
421+
CallbackType, OptimizerType, FunctionType, MatType, GradType>::hasVoid,
422+
bool>::type
423+
GradientFunction(CallbackType& callback,
424+
OptimizerType& optimizer,
425+
FunctionType& function,
426+
const MatType& coordinates,
427+
GradType& gradient)
428+
{
429+
const_cast<CallbackType&>(callback).Gradient(
430+
optimizer, function, coordinates, gradient);
431+
return false;
432+
}
433+
434+
template<typename CallbackType,
435+
typename OptimizerType,
436+
typename FunctionType,
437+
typename MatType,
438+
typename GradType>
439+
static typename std::enable_if<callbacks::traits::HasGradientSignature<
440+
CallbackType, OptimizerType, FunctionType, MatType, GradType>::hasNone,
402441
bool>::type
403442
GradientFunction(CallbackType& /* callback */,
404443
OptimizerType& /* optimizer */,
@@ -451,7 +490,27 @@ class Callback
451490
typename GradType>
452491
static typename std::enable_if<
453492
callbacks::traits::HasGradientConstraintSignature<
454-
CallbackType, OptimizerType, FunctionType, MatType, GradType>::value,
493+
CallbackType, OptimizerType, FunctionType, MatType, GradType>::hasBool,
494+
bool>::type
495+
GradientConstraintFunction(CallbackType& callback,
496+
OptimizerType& optimizer,
497+
FunctionType& function,
498+
const MatType& coordinates,
499+
const size_t constraint,
500+
GradType& gradient)
501+
{
502+
return const_cast<CallbackType&>(callback).GradientConstraint(optimizer,
503+
function, coordinates, constraint, gradient);
504+
}
505+
506+
template<typename CallbackType,
507+
typename OptimizerType,
508+
typename FunctionType,
509+
typename MatType,
510+
typename GradType>
511+
static typename std::enable_if<
512+
callbacks::traits::HasGradientConstraintSignature<
513+
CallbackType, OptimizerType, FunctionType, MatType, GradType>::hasVoid,
455514
bool>::type
456515
GradientConstraintFunction(CallbackType& callback,
457516
OptimizerType& optimizer,
@@ -460,8 +519,9 @@ class Callback
460519
const size_t constraint,
461520
GradType& gradient)
462521
{
463-
return (const_cast<CallbackType&>(callback).GradientConstraint(
464-
optimizer, function, coordinates, constraint, gradient), false);
522+
const_cast<CallbackType&>(callback).GradientConstraint(
523+
optimizer, function, coordinates, constraint, gradient);
524+
return false;
465525
}
466526

467527
template<typename CallbackType,
@@ -470,8 +530,8 @@ class Callback
470530
typename MatType,
471531
typename GradType>
472532
static typename std::enable_if<
473-
!callbacks::traits::HasGradientConstraintSignature<
474-
CallbackType, OptimizerType, FunctionType, MatType, GradType>::value,
533+
callbacks::traits::HasGradientConstraintSignature<
534+
CallbackType, OptimizerType, FunctionType, MatType, GradType>::hasNone,
475535
bool>::type
476536
GradientConstraintFunction(CallbackType& /* callback */,
477537
OptimizerType& /* optimizer */,
@@ -563,24 +623,42 @@ class Callback
563623
typename FunctionType,
564624
typename MatType>
565625
static typename std::enable_if<callbacks::traits::HasBeginEpochSignature<
566-
CallbackType, OptimizerType, FunctionType, MatType>::value, bool>::type
626+
CallbackType, OptimizerType, FunctionType, MatType>::hasBool, bool>::type
627+
BeginEpochFunction(CallbackType& callback,
628+
OptimizerType& optimizer,
629+
FunctionType& function,
630+
const MatType& coordinates,
631+
const size_t epoch,
632+
const double objective)
633+
{
634+
return const_cast<CallbackType&>(callback).BeginEpoch(
635+
optimizer, function, coordinates, epoch, objective);
636+
}
637+
638+
template<typename CallbackType,
639+
typename OptimizerType,
640+
typename FunctionType,
641+
typename MatType>
642+
static typename std::enable_if<callbacks::traits::HasBeginEpochSignature<
643+
CallbackType, OptimizerType, FunctionType, MatType>::hasVoid, bool>::type
567644
BeginEpochFunction(CallbackType& callback,
568645
OptimizerType& optimizer,
569646
FunctionType& function,
570647
const MatType& coordinates,
571648
const size_t epoch,
572649
const double objective)
573650
{
574-
return (const_cast<CallbackType&>(callback).BeginEpoch(
575-
optimizer, function, coordinates, epoch, objective), false);
651+
const_cast<CallbackType&>(callback).BeginEpoch(
652+
optimizer, function, coordinates, epoch, objective);
653+
return false;
576654
}
577655

578656
template<typename CallbackType,
579657
typename OptimizerType,
580658
typename FunctionType,
581659
typename MatType>
582-
static typename std::enable_if<!callbacks::traits::HasBeginEpochSignature<
583-
CallbackType, OptimizerType, FunctionType, MatType>::value, bool>::type
660+
static typename std::enable_if<callbacks::traits::HasBeginEpochSignature<
661+
CallbackType, OptimizerType, FunctionType, MatType>::hasNone, bool>::type
584662
BeginEpochFunction(CallbackType& /* callback */,
585663
OptimizerType& /* optimizer */,
586664
FunctionType& /* function */,
@@ -768,6 +846,32 @@ class Callback
768846
MatType& /* coordinates */)
769847
{ return false; }
770848

849+
/**
850+
* Iterate over the callbacks and invoke the StepTaken() callback if it
851+
* exists.
852+
*
853+
* @param optimizer The optimizer used to update the function.
854+
* @param function Function to optimize.
855+
* @param coordinates Starting point.
856+
* @param callbacks The callbacks container.
857+
*/
858+
template<typename OptimizerType,
859+
typename FunctionType,
860+
typename MatType,
861+
typename... CallbackTypes>
862+
static bool StepTaken(OptimizerType& optimizer,
863+
FunctionType& function,
864+
MatType& coordinates,
865+
CallbackTypes&... callbacks)
866+
{
867+
// This will return immediately once a callback returns true.
868+
bool result = false;
869+
(void)std::initializer_list<bool>{ result =
870+
result || Callback::StepTakenFunction(callbacks, optimizer,
871+
function, coordinates)... };
872+
return result;
873+
}
874+
771875
/**
772876
* Invoke the GenerationalStepTaken() callback if it exists.
773877
* Specialization for MultiObjective case.
@@ -819,7 +923,6 @@ class Callback
819923
{
820924
const_cast<CallbackType&>(callback).GenerationalStepTaken(
821925
optimizer, function, coordinates, objectives, frontIndices);
822-
823926
return false;
824927
}
825928

@@ -841,32 +944,6 @@ class Callback
841944
IndicesType& /* frontIndices */)
842945
{ return false; }
843946

844-
/**
845-
* Iterate over the callbacks and invoke the StepTaken() callback if it
846-
* exists.
847-
*
848-
* @param optimizer The optimizer used to update the function.
849-
* @param function Function to optimize.
850-
* @param coordinates Starting point.
851-
* @param callbacks The callbacks container.
852-
*/
853-
template<typename OptimizerType,
854-
typename FunctionType,
855-
typename MatType,
856-
typename... CallbackTypes>
857-
static bool StepTaken(OptimizerType& optimizer,
858-
FunctionType& function,
859-
MatType& coordinates,
860-
CallbackTypes&... callbacks)
861-
{
862-
// This will return immediately once a callback returns true.
863-
bool result = false;
864-
(void)std::initializer_list<bool>{ result =
865-
result || Callback::StepTakenFunction(callbacks, optimizer,
866-
function, coordinates)... };
867-
return result;
868-
}
869-
870947
/**
871948
* Iterate over the callbacks and invoke the GenerationalStepTaken() callback if it
872949
* exists.

0 commit comments

Comments
 (0)