Skip to content

Commit cdcb43c

Browse files
authored
fix no_multiphase sampling interface (#116)
1 parent a965357 commit cdcb43c

File tree

1 file changed

+13
-13
lines changed

1 file changed

+13
-13
lines changed

dingo/bindings/bindings.cpp

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -112,50 +112,50 @@ double HPolytopeCPP::apply_sampling(int walk_len,
112112

113113
NT variance = variance_value;
114114

115-
if (strcmp(method, "cdhr")) { // cdhr
115+
if (strcmp(method, "cdhr") == 0) { // cdhr
116116
uniform_sampling<CDHRWalk>(rand_points, HP, rng, walk_len, number_of_points,
117117
starting_point, number_of_points_to_burn);
118-
} else if (strcmp(method, "rdhr")) { // rdhr
118+
} else if (strcmp(method, "rdhr") == 0) { // rdhr
119119
uniform_sampling<RDHRWalk>(rand_points, HP, rng, walk_len, number_of_points,
120120
starting_point, number_of_points_to_burn);
121-
} else if (strcmp(method, "billiard_walk")) { // accelerated_billiard
121+
} else if (strcmp(method, "billiard_walk") == 0) { // accelerated_billiard
122122
uniform_sampling<AcceleratedBilliardWalk>(rand_points, HP, rng, walk_len,
123123
number_of_points, starting_point,
124124
number_of_points_to_burn);
125-
} else if (strcmp(method, "ball_walk")) { // ball walk
125+
} else if (strcmp(method, "ball_walk") == 0) { // ball walk
126126
uniform_sampling<BallWalk>(rand_points, HP, rng, walk_len, number_of_points,
127127
starting_point, number_of_points_to_burn);
128-
} else if (strcmp(method, "dikin_walk")) { // dikin walk
128+
} else if (strcmp(method, "dikin_walk") == 0) { // dikin walk
129129
uniform_sampling<DikinWalk>(rand_points, HP, rng, walk_len, number_of_points,
130130
starting_point, number_of_points_to_burn);
131-
} else if (strcmp(method, "john_walk")) { // john walk
131+
} else if (strcmp(method, "john_walk") == 0) { // john walk
132132
uniform_sampling<JohnWalk>(rand_points, HP, rng, walk_len, number_of_points,
133133
starting_point, number_of_points_to_burn);
134-
} else if (strcmp(method, "vaidya_walk")) { // vaidya walk
134+
} else if (strcmp(method, "vaidya_walk") == 0) { // vaidya walk
135135
uniform_sampling<VaidyaWalk>(rand_points, HP, rng, walk_len, number_of_points,
136136
starting_point, number_of_points_to_burn);
137-
} else if (strcmp(method, "mmcs")) { // vaidya walk
137+
} else if (strcmp(method, "mmcs") == 0) { // vaidya walk
138138
MT S;
139139
int total_ess;
140140
//TODO: avoid passing polytopes as non-const references
141141
const Hpolytope HP_const = HP;
142142
mmcs(HP_const, ess, S, total_ess, walk_len, rng);
143143
samples = S.data();
144-
} else if (strcmp(method, "gaussian_hmc_walk")) { // Gaussian sampling with exact HMC walk
144+
} else if (strcmp(method, "gaussian_hmc_walk") == 0) { // Gaussian sampling with exact HMC walk
145145
NT a = NT(1)/(NT(2)*variance);
146146
gaussian_sampling<GaussianHamiltonianMonteCarloExactWalk>(rand_points, HP, rng, walk_len, number_of_points, a,
147147
starting_point, number_of_points_to_burn);
148-
} else if (strcmp(method, "exponential_hmc_walk")) { // exponential sampling with exact HMC walk
148+
} else if (strcmp(method, "exponential_hmc_walk") == 0) { // exponential sampling with exact HMC walk
149149
VT c(d);
150150
for (int i = 0; i < d; i++){
151151
c(i) = bias_vector_[i];
152152
}
153153
Point bias_vector(c);
154154
exponential_sampling<ExponentialHamiltonianMonteCarloExactWalk>(rand_points, HP, rng, walk_len, number_of_points, bias_vector, variance,
155155
starting_point, number_of_points_to_burn);
156-
} else if (strcmp(method, "hmc_leapfrog_gaussian")) { // HMC with Gaussian distribution
156+
} else if (strcmp(method, "hmc_leapfrog_gaussian") == 0) { // HMC with Gaussian distribution
157157
rand_points = hmc_leapfrog_gaussian(walk_len, number_of_points, number_of_points_to_burn, variance, starting_point, HP);
158-
} else if (strcmp(method, "hmc_leapfrog_exponential")) { // HMC with exponential distribution
158+
} else if (strcmp(method, "hmc_leapfrog_exponential") == 0) { // HMC with exponential distribution
159159
VT c(d);
160160
for (int i = 0; i < d; i++) {
161161
c(i) = bias_vector_[i];
@@ -170,7 +170,7 @@ double HPolytopeCPP::apply_sampling(int walk_len,
170170
throw std::runtime_error("This function must not be called.");
171171
}
172172

173-
if (!strcmp(method, "mmcs")) {
173+
if (strcmp(method, "mmcs") != 0) {
174174
// The following block of code allows us to copy the sampled points
175175
auto n_si=0;
176176
for (auto it_s = rand_points.cbegin(); it_s != rand_points.cend(); it_s++){

0 commit comments

Comments
 (0)