Skip to content

Commit ab12fa3

Browse files
committed
add SPARSE boolean for clarity
1 parent f37b234 commit ab12fa3

File tree

2 files changed

+23
-19
lines changed

2 files changed

+23
-19
lines changed

include/random_walks/gaussian_accelerated_billiard_walk.hpp

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -73,10 +73,12 @@ struct GaussianAcceleratedBilliardWalk
7373
typedef typename Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic> DenseMT;
7474
typedef typename Polytope::VT VT;
7575
typedef typename Point::FT NT;
76+
// We do sparse computations iff MT is sparse rowMajor
77+
static constexpr bool SPARSE = std::is_same_v<MT, Eigen::SparseMatrix<NT, Eigen::RowMajor>>;
7678
// AA is sparse colMajor if MT is sparse rowMajor, and Dense otherwise
77-
using AA_type = std::conditional_t< std::is_same_v<MT, typename Eigen::SparseMatrix<NT, Eigen::RowMajor>>, typename Eigen::SparseMatrix<NT>, DenseMT >;
79+
using AA_type = std::conditional_t< SPARSE, typename Eigen::SparseMatrix<NT>, DenseMT >;
7880
// AE is sparse rowMajor if (MT is sparse rowMajor and E is sparse), and Dense otherwise
79-
using AE_type = std::conditional_t< std::is_same_v<MT, typename Eigen::SparseMatrix<NT, Eigen::RowMajor>> && std::is_base_of<Eigen::SparseMatrixBase<E_type>, E_type>::value, typename Eigen::SparseMatrix<NT, Eigen::RowMajor>, DenseMT >;
81+
using AE_type = std::conditional_t< SPARSE && std::is_base_of<Eigen::SparseMatrixBase<E_type>, E_type>::value, typename Eigen::SparseMatrix<NT, Eigen::RowMajor>, DenseMT >;
8082

8183
void computeLcov(const E_type E)
8284
{
@@ -117,10 +119,10 @@ struct GaussianAcceleratedBilliardWalk
117119
_L = compute_diameter<GenericPolytope>::template compute<NT>(P);
118120
computeLcov(E);
119121
_E = E;
120-
if constexpr (std::is_same<AA_type, Eigen::SparseMatrix<NT>>::value) {
122+
if constexpr (SPARSE) {
121123
_AA = (P.get_mat() * P.get_mat().transpose());
122124
} else {
123-
_AA.noalias() = (DenseMT)(P.get_mat() * P.get_mat().transpose());
125+
_AA.noalias() = (P.get_mat() * P.get_mat().transpose());
124126
}
125127
_rho = 1000 * P.dimension(); // upper bound for the number of reflections (experimental)
126128
initialize(P, p, rng);
@@ -142,7 +144,7 @@ struct GaussianAcceleratedBilliardWalk
142144
::template compute<NT>(P);
143145
computeLcov(E);
144146
_E = E;
145-
if constexpr (std::is_same<AA_type, Eigen::SparseMatrix<NT>>::value) {
147+
if constexpr (SPARSE) {
146148
_AA = (P.get_mat() * P.get_mat().transpose());
147149
} else {
148150
_AA.noalias() = (DenseMT)(P.get_mat() * P.get_mat().transpose());
@@ -186,7 +188,7 @@ struct GaussianAcceleratedBilliardWalk
186188
}
187189

188190
_lambda_prev = dl * pbpair.first;
189-
if constexpr (std::is_same<MT, Eigen::SparseMatrix<NT, Eigen::RowMajor>>::value) {
191+
if constexpr (SPARSE) {
190192
typename Point::Coeff b;
191193
NT* b_data;
192194
b = P.get_vec();
@@ -213,7 +215,7 @@ struct GaussianAcceleratedBilliardWalk
213215
while (it < _rho)
214216
{
215217
std::pair<NT, int> pbpair;
216-
if constexpr (std::is_same<MT, Eigen::SparseMatrix<NT, Eigen::RowMajor>>::value) {
218+
if constexpr (SPARSE) {
217219
pbpair = P.line_positive_intersect(_p, _lambdas, _Av, _lambda_prev,
218220
_distances_set, _AA, _update_parameters);
219221
} else {
@@ -227,7 +229,7 @@ struct GaussianAcceleratedBilliardWalk
227229
break;
228230
}
229231
_lambda_prev = dl * pbpair.first;
230-
if constexpr (std::is_same<MT, Eigen::SparseMatrix<NT, Eigen::RowMajor>>::value) {
232+
if constexpr (SPARSE) {
231233
_update_parameters.moved_dist += _lambda_prev;
232234
} else {
233235
_p += (_lambda_prev * _v);

include/random_walks/uniform_accelerated_billiard_walk.hpp

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,10 @@ struct AcceleratedBilliardWalk
6363
typedef typename Polytope::MT MT;
6464
typedef typename Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic> DenseMT;
6565
typedef typename Point::FT NT;
66-
using AA_type = std::conditional_t< std::is_same_v<MT, typename Eigen::SparseMatrix<NT, Eigen::RowMajor>>, typename Eigen::SparseMatrix<NT>, DenseMT >;
66+
// We do sparse computations iff MT is sparse rowMajor
67+
static constexpr bool SPARSE = std::is_same_v<MT, Eigen::SparseMatrix<NT, Eigen::RowMajor>>;
6768
// AA is sparse colMajor if MT is sparse rowMajor, and Dense otherwise
69+
using AA_type = std::conditional_t< SPARSE, typename Eigen::SparseMatrix<NT>, DenseMT >;
6870

6971
template <typename GenericPolytope>
7072
Walk(GenericPolytope &P, Point const& p, RandomNumberGenerator &rng)
@@ -75,7 +77,7 @@ struct AcceleratedBilliardWalk
7577
_update_parameters = update_parameters();
7678
_L = compute_diameter<GenericPolytope>
7779
::template compute<NT>(P);
78-
if constexpr (std::is_same<AA_type, Eigen::SparseMatrix<NT>>::value) {
80+
if constexpr (SPARSE) {
7981
_AA = (P.get_mat() * P.get_mat().transpose());
8082
} else {
8183
_AA.noalias() = (DenseMT)(P.get_mat() * P.get_mat().transpose());
@@ -95,7 +97,7 @@ struct AcceleratedBilliardWalk
9597
_L = params.set_L ? params.m_L
9698
: compute_diameter<GenericPolytope>
9799
::template compute<NT>(P);
98-
if constexpr (std::is_same<AA_type, Eigen::SparseMatrix<NT>>::value) {
100+
if constexpr (SPARSE) {
99101
_AA = (P.get_mat() * P.get_mat().transpose());
100102
} else {
101103
_AA.noalias() = (DenseMT)(P.get_mat() * P.get_mat().transpose());
@@ -119,7 +121,7 @@ struct AcceleratedBilliardWalk
119121
int it;
120122
typename Point::Coeff b;
121123
NT* b_data;
122-
if constexpr (std::is_same<MT, Eigen::SparseMatrix<NT, Eigen::RowMajor>>::value) {
124+
if constexpr (SPARSE) {
123125
b = P.get_vec();
124126
b_data = b.data();
125127
}
@@ -140,7 +142,7 @@ struct AcceleratedBilliardWalk
140142
}
141143

142144
_lambda_prev = dl * pbpair.first;
143-
if constexpr (std::is_same<MT, Eigen::SparseMatrix<NT, Eigen::RowMajor>>::value) {
145+
if constexpr (SPARSE) {
144146
_update_parameters.moved_dist = _lambda_prev;
145147
NT* Ar_data = _lambdas.data();
146148
NT* Av_data = _Av.data();
@@ -153,13 +155,13 @@ struct AcceleratedBilliardWalk
153155
_p += (_lambda_prev * _v);
154156
}
155157
T -= _lambda_prev;
156-
P.compute_reflection(_v, _p, _update_parameters);
158+
P.compute_reflection_abw(_v, _p, _update_parameters);
157159
it++;
158160

159161
while (it < _rho)
160162
{
161163
std::pair<NT, int> pbpair;
162-
if constexpr (std::is_same<MT, Eigen::SparseMatrix<NT, Eigen::RowMajor>>::value) {
164+
if constexpr (SPARSE) {
163165
pbpair = P.line_positive_intersect(_p, _lambdas, _Av, _lambda_prev,
164166
_distances_set, _AA, _update_parameters);
165167
} else {
@@ -172,13 +174,13 @@ struct AcceleratedBilliardWalk
172174
break;
173175
}
174176
_lambda_prev = dl * pbpair.first;
175-
if constexpr (std::is_same<MT, Eigen::SparseMatrix<NT, Eigen::RowMajor>>::value) {
177+
if constexpr (SPARSE) {
176178
_update_parameters.moved_dist += _lambda_prev;
177179
} else {
178180
_p += (_lambda_prev * _v);
179181
}
180182
T -= _lambda_prev;
181-
P.compute_reflection(_v, _p, _update_parameters);
183+
P.compute_reflection_abw(_v, _p, _update_parameters);
182184
it++;
183185
}
184186
_p += _update_parameters.moved_dist * _v;
@@ -300,7 +302,7 @@ struct AcceleratedBilliardWalk
300302
_lambda_prev = dl * pbpair.first;
301303
_p += (_lambda_prev * _v);
302304
T -= _lambda_prev;
303-
P.compute_reflection(_v, _p, _update_parameters);
305+
P.compute_reflection_abw(_v, _p, _update_parameters);
304306

305307
while (it <= _rho)
306308
{
@@ -318,7 +320,7 @@ struct AcceleratedBilliardWalk
318320
_lambda_prev = dl * pbpair.first;
319321
_p += (_lambda_prev * _v);
320322
T -= _lambda_prev;
321-
P.compute_reflection(_v, _p, _update_parameters);
323+
P.compute_reflection_abw(_v, _p, _update_parameters);
322324
it++;
323325
}
324326
}

0 commit comments

Comments
 (0)