Skip to content

Commit 9fa90f1

Browse files
authored
Merge Bandicoot tests with other tests (#438)
* Normalize step size to the batch size used. * Update history. * Don't use if constexpr. * Fix all step sizes in tests to account for new behavior. * Retune DemonSGD tests to reduce failure probability. * Merge Bandicoot tests into existing TEMPLATE_TEST_CASEs. * Fix merge artifact. * Fix compilation errors. * Fix MOEAD test to be more robust. * Oops, fix missing colon. * Fix compilation issues. * Remove arma:: for randu call. * Try to fix AppVeyor build. * Bump version number to fix mlpack integration build. * Adapt FBS/FISTA/FASTA to bandicoot tests. * Keep bandicoot implementation separate. * Print which tests take a long time so I can prune down the runtime. * Remove some tests that take a long long time. * Filter out some more long-running GPU tests.
1 parent 37b057d commit 9fa90f1

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

60 files changed

+220
-1301
lines changed

.github/workflows/bandicoot-test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,4 +51,4 @@ jobs:
5151
- name: Test ensmallen
5252
run: |
5353
cd build/
54-
./ensmallen_tests
54+
./ensmallen_tests -d yes

include/ensmallen_bits/fasta/fasta.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ class FASTA
102102
*/
103103
template<typename FunctionType, typename MatType, typename GradType,
104104
typename... CallbackTypes>
105-
typename std::enable_if<IsArmaType<GradType>::value,
105+
typename std::enable_if<IsMatrixType<GradType>::value,
106106
typename MatType::elem_type>::type
107107
Optimize(FunctionType& function,
108108
MatType& iterate,

include/ensmallen_bits/fasta/fasta_impl.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ FASTA<BackwardStepType>::FASTA(BackwardStepType backwardStep,
9090
template<typename BackwardStepType>
9191
template<typename FunctionType, typename MatType, typename GradType,
9292
typename... CallbackTypes>
93-
typename std::enable_if<IsArmaType<GradType>::value,
93+
typename std::enable_if<IsMatrixType<GradType>::value,
9494
typename MatType::elem_type>::type
9595
FASTA<BackwardStepType>::Optimize(FunctionType& function,
9696
MatType& iterateIn,
@@ -358,7 +358,7 @@ FASTA<BackwardStepType>::Optimize(FunctionType& function,
358358
// proximal step.
359359

360360
// Compute residual. This is Eq. (40) in the paper.
361-
const ElemType residual = arma::norm(g + (xHat - x) / currentStepSize, 2);
361+
const ElemType residual = norm(g + (xHat - x) / currentStepSize, 2);
362362

363363
// If this is the first iteration, store the residual as the first residual.
364364
if (i == 1)

include/ensmallen_bits/fbs/fbs.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ class FBS
9292
*/
9393
template<typename FunctionType, typename MatType, typename GradType,
9494
typename... CallbackTypes>
95-
typename std::enable_if<IsArmaType<GradType>::value,
95+
typename std::enable_if<IsMatrixType<GradType>::value,
9696
typename MatType::elem_type>::type
9797
Optimize(FunctionType& function,
9898
MatType& iterate,

include/ensmallen_bits/fbs/fbs_impl.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ FBS<BackwardStepType>::FBS(BackwardStepType backwardStep,
4444
template<typename BackwardStepType>
4545
template<typename FunctionType, typename MatType, typename GradType,
4646
typename... CallbackTypes>
47-
typename std::enable_if<IsArmaType<GradType>::value,
47+
typename std::enable_if<IsMatrixType<GradType>::value,
4848
typename MatType::elem_type>::type
4949
FBS<BackwardStepType>::Optimize(FunctionType& function,
5050
MatType& iterateIn,

include/ensmallen_bits/fbs/l1_constraint_impl.hpp

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -134,14 +134,19 @@ void L1Constraint::ProximalStep(MatType& coordinates,
134134
}
135135

136136
const eT theta = (s - eT(lambda)) / rho;
137-
coordinates.transform(
138-
[theta](eT val)
139-
{
140-
if (val > 0)
141-
return std::max(val - theta, eT(0));
142-
else
143-
return std::min(val + theta, eT(0));
144-
});
137+
// This is a single-line implementation of the .transform() below; we use the
138+
// single-line implementation so it works with Bandicoot.
139+
//
140+
// coordinates.transform(
141+
// [theta](eT val)
142+
// {
143+
// if (val > 0)
144+
// return std::max(val - theta, eT(0));
145+
// else
146+
// return std::min(val + theta, eT(0));
147+
// });
148+
coordinates = sign(coordinates) % clamp(
149+
abs(coordinates) - theta, eT(0), std::numeric_limits<eT>::max());
145150

146151
// Sanity check: ensure we actually ended up inside the L1 ball. This might
147152
// not happen due to floating-point inaccuracies. If so, try again.
@@ -168,7 +173,8 @@ template<typename MatType>
168173
inline arma::Col<typename MatType::elem_type> L1Constraint::ExtractNonzeros(
169174
const MatType& coordinates) const
170175
{
171-
return arma::Col<typename MatType::elem_type>(vectorise(abs(coordinates)));
176+
typedef typename MatType::elem_type ElemType;
177+
return conv_to<arma::Col<ElemType>>::from(vectorise(abs(coordinates)));
172178
}
173179

174180
template<typename eT>

include/ensmallen_bits/fbs/l1_penalty_impl.hpp

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,24 @@ void L1Penalty::ProximalStep(MatType& coordinates,
3838
// Apply the backwards step coordinate-wise. If `MatType` is sparse, this
3939
// only applies to nonzero elements, which is just fine.
4040
typedef typename MatType::elem_type eT;
41-
coordinates.transform([this, stepSize](eT val) { return (val > eT(0)) ?
42-
(std::max(eT(0), val - eT(lambda * stepSize))) :
43-
(std::min(eT(0), val + eT(lambda * stepSize))); });
41+
42+
// This is equivalent to the following .transform() implementation (which is
43+
// easier to read but will not work with Bandicoot):
44+
//
45+
//arma::Mat<typename MatType::elem_type> c2 = conv_to<arma::Mat<typename MatType::elem_type>>::from(coordinates);
46+
//c2.transform([this, stepSize](eT val) { return (val > eT(0)) ?
47+
// (std::max(eT(0), val - eT(lambda * stepSize))) :
48+
// (std::min(eT(0), val + eT(lambda * stepSize))); });
49+
// coordinates.transform([this, stepSize](eT val) { return (val > eT(0)) ?
50+
// (std::max(eT(0), val - eT(lambda * stepSize))) :
51+
// (std::min(eT(0), val + eT(lambda * stepSize))); });
52+
//
53+
coordinates = sign(coordinates) % clamp(
54+
abs(coordinates) - eT(lambda * stepSize), eT(0),
55+
std::numeric_limits<eT>::max());
56+
57+
//coordinates.print("coordinates");
58+
//c2.print("c2");
4459
}
4560

4661
} // namespace ens

include/ensmallen_bits/fista/fista.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ class FISTA
104104
*/
105105
template<typename FunctionType, typename MatType, typename GradType,
106106
typename... CallbackTypes>
107-
typename std::enable_if<IsArmaType<GradType>::value,
107+
typename std::enable_if<IsMatrixType<GradType>::value,
108108
typename MatType::elem_type>::type
109109
Optimize(FunctionType& function,
110110
MatType& iterate,

include/ensmallen_bits/fista/fista_impl.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ FISTA<BackwardStepType>::FISTA(BackwardStepType backwardStep,
7474
template<typename BackwardStepType>
7575
template<typename FunctionType, typename MatType, typename GradType,
7676
typename... CallbackTypes>
77-
typename std::enable_if<IsArmaType<GradType>::value,
77+
typename std::enable_if<IsMatrixType<GradType>::value,
7878
typename MatType::elem_type>::type
7979
FISTA<BackwardStepType>::Optimize(FunctionType& function,
8080
MatType& iterateIn,

tests/ada_belief_test.cpp

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,19 +20,8 @@ using namespace ens;
2020
using namespace ens::test;
2121

2222
TEMPLATE_TEST_CASE("AdaBelief_LogisticRegressionFunction", "[AdaBelief]",
23-
ENS_ALL_TEST_TYPES)
23+
ENS_ALL_CPU_TEST_TYPES)
2424
{
2525
AdaBelief adaBelief(0.032);
26-
LogisticRegressionFunctionTest<TestType, arma::Row<size_t>>(adaBelief);
26+
LogisticRegressionFunctionTest<TestType>(adaBelief);
2727
}
28-
29-
#ifdef ENS_HAVE_COOT
30-
31-
TEMPLATE_TEST_CASE("AdaBelief_LogisticRegressionFunction", "[AdaBelief]",
32-
coot::mat, coot::fmat)
33-
{
34-
AdaBelief adaBelief(0.032);
35-
LogisticRegressionFunctionTest<TestType, coot::Row<size_t>>(adaBelief);
36-
}
37-
38-
#endif

0 commit comments

Comments
 (0)