Skip to content

Commit 1be4861

Browse files
Fix na_impute bug bgmCompare.
1 parent 51ef241 commit 1be4861

File tree

2 files changed

+13
-12
lines changed

2 files changed

+13
-12
lines changed

src/bgmCompare_parallel.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ struct ChainResultCompare {
5858
* - Store results into the shared `results` vector at the chain index.
5959
*
6060
* Inputs (stored as const references or values):
61-
* - observations: Input observation matrix (persons × variables).
61+
* - observations_master: Input observation matrix (persons × variables).
6262
* - num_groups: Number of groups.
6363
* - counts_per_category_master: Group-level category counts.
6464
* - blume_capel_stats_master: Group-level Blume–Capel sufficient statistics.
@@ -99,7 +99,7 @@ struct ChainResultCompare {
9999
* - Errors are caught locally so one failing chain does not crash the entire run.
100100
*/
101101
struct GibbsCompareChainRunner : public Worker {
102-
const arma::imat& observations;
102+
const arma::imat& observations_master;
103103
const int num_groups;
104104
const std::vector<arma::imat>& counts_per_category_master;
105105
const std::vector<arma::imat>& blume_capel_stats_master;
@@ -138,7 +138,7 @@ struct GibbsCompareChainRunner : public Worker {
138138
std::vector<ChainResultCompare>& results;
139139

140140
GibbsCompareChainRunner(
141-
const arma::imat& observations,
141+
const arma::imat& observations_master,
142142
int num_groups,
143143
const std::vector<arma::imat>& counts_per_category_master,
144144
const std::vector<arma::imat>& blume_capel_stats_master,
@@ -174,7 +174,7 @@ struct GibbsCompareChainRunner : public Worker {
174174
ProgressManager& pm,
175175
std::vector<ChainResultCompare>& results
176176
) :
177-
observations(observations),
177+
observations_master(observations_master),
178178
num_groups(num_groups),
179179
counts_per_category_master(counts_per_category_master),
180180
blume_capel_stats_master(blume_capel_stats_master),
@@ -226,12 +226,12 @@ struct GibbsCompareChainRunner : public Worker {
226226
std::vector<arma::imat> blume_capel_stats = blume_capel_stats_master;
227227
std::vector<arma::mat> pairwise_stats = pairwise_stats_master;
228228
arma::mat inclusion_probability = inclusion_probability_master;
229-
arma::imat observations_copy = observations;
229+
arma::imat observations = observations_master;
230230

231231
// run sampler (pure C++)
232232
SamplerOutput result = run_gibbs_sampler_bgmCompare(
233233
out.chain_id,
234-
observations_copy,
234+
observations,
235235
num_groups,
236236
counts_per_category,
237237
blume_capel_stats,

src/bgmCompare_sampler.cpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -79,15 +79,15 @@ void impute_missing_bgmcompare(
7979
std::vector<arma::imat>& counts_per_category,
8080
std::vector<arma::imat>& blume_capel_stats,
8181
std::vector<arma::mat>& pairwise_stats,
82-
const arma::imat& num_categories,
82+
const arma::ivec& num_categories,
8383
const arma::imat& missing_data_indices,
8484
const arma::uvec& is_ordinal_variable,
8585
const arma::ivec& baseline_category,
8686
SafeRNG& rng
8787
) {
8888
const int num_variables = observations.n_cols;
8989
const int num_missings = missing_data_indices.n_rows;
90-
const int max_num_categories = arma::max(arma::vectorise(num_categories));
90+
const int max_num_categories = arma::max(num_categories);
9191

9292
arma::vec category_response_probabilities(max_num_categories + 1);
9393
double exponent, cumsum, u;
@@ -120,11 +120,12 @@ void impute_missing_bgmcompare(
120120

121121
double rest_score =
122122
arma::as_scalar(observations.row(person) * group_pairwise_effects.col(variable));
123+
123124
if(is_ordinal_variable[variable] == true) {
124125
// For regular binary or ordinal variables
125126
cumsum = 1.0;
126127
category_response_probabilities[0] = 1.0;
127-
for(int category = 1; category <= num_categories(variable, group); category++) {
128+
for(int category = 1; category <= num_categories(variable); category++) {
128129
exponent = group_main_effects(category - 1);
129130
exponent += category * rest_score;
130131
cumsum += MY_EXP(exponent);
@@ -133,7 +134,7 @@ void impute_missing_bgmcompare(
133134
} else {
134135
// For Blume-Capel variables
135136
cumsum = 0.0;
136-
for(int category = 0; category <= num_categories(variable, group); category++) {
137+
for(int category = 0; category <= num_categories(variable); category++) {
137138
exponent = group_main_effects[0] * category;
138139
exponent += group_main_effects[1] *
139140
(category - baseline_category[variable]) *
@@ -161,9 +162,9 @@ void impute_missing_bgmcompare(
161162
if(is_ordinal_variable[variable] == true) {
162163
arma::imat counts_per_category_group = counts_per_category[group];
163164
if(old_observation > 0)
164-
counts_per_category_group(old_observation, variable)--;
165+
counts_per_category_group(old_observation-1, variable)--;
165166
if(new_observation > 0)
166-
counts_per_category_group(new_observation, variable)++;
167+
counts_per_category_group(new_observation-1, variable)++;
167168
counts_per_category[group] = counts_per_category_group;
168169
} else {
169170
arma::imat blume_capel_stats_group = blume_capel_stats[group];

0 commit comments

Comments
 (0)