Skip to content

Commit 6845f24

Browse files
Init bc model update.
* Does not use ExpBeGone trick. * Metropolis works but NUTS not yet.
1 parent f37ba52 commit 6845f24

File tree

11 files changed

+279
-172
lines changed

11 files changed

+279
-172
lines changed

R/RcppExports.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@ sample_omrf_gibbs <- function(no_states, no_variables, no_categories, interactio
2525
.Call(`_bgms_sample_omrf_gibbs`, no_states, no_variables, no_categories, interactions, thresholds, iter)
2626
}
2727

28-
sample_bcomrf_gibbs <- function(no_states, no_variables, no_categories, interactions, thresholds, variable_type, reference_category, iter) {
29-
.Call(`_bgms_sample_bcomrf_gibbs`, no_states, no_variables, no_categories, interactions, thresholds, variable_type, reference_category, iter)
28+
sample_bcomrf_gibbs <- function(no_states, no_variables, no_categories, interactions, thresholds, variable_type, baseline_category, iter) {
29+
.Call(`_bgms_sample_bcomrf_gibbs`, no_states, no_variables, no_categories, interactions, thresholds, variable_type, baseline_category, iter)
3030
}
3131

3232
compute_Vn_mfm_sbm <- function(no_variables, dirichlet_alpha, t_max, lambda) {

R/bgm.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -524,8 +524,9 @@ bgm = function(
524524
# Ordinal (variable_bool == TRUE) or Blume-Capel (variable_bool == FALSE)
525525
bc_vars = which(!variable_bool)
526526
for(i in bc_vars) {
527-
blume_capel_stats[1, i] = sum(x[, i])
527+
blume_capel_stats[1, i] = sum(x[, i] - baseline_category[i])
528528
blume_capel_stats[2, i] = sum((x[, i] - baseline_category[i]) ^ 2)
529+
x[, i] = x[, i] - baseline_category[i]
529530
}
530531
}
531532
pairwise_stats = t(x) %*% x
@@ -588,7 +589,6 @@ bgm = function(
588589
nThreads = cores, seed = seed, progress_type = progress_type
589590
)
590591

591-
592592
userInterrupt = any(vapply(out, FUN = `[[`, FUN.VALUE = logical(1L), "userInterrupt"))
593593
if (userInterrupt) {
594594
warning("Stopped sampling after user interrupt, results are likely uninterpretable.")

R/bgmCompare.R

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -402,6 +402,9 @@ bgmCompare = function(
402402
blume_capel_stats = compute_blume_capel_stats(
403403
x, baseline_category, ordinal_variable, group
404404
)
405+
for (i in which(!ordinal_variable)) {
406+
x[, i] = sum(x[, i] - baseline_category[i])
407+
}
405408

406409
# Compute sufficient statistics for pairwise interactions
407410
pairwise_stats = compute_pairwise_stats(

R/data_utils.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ compute_blume_capel_stats = function(x, baseline_category, ordinal_variable, gro
210210
sufficient_stats = matrix(0, nrow = 2, ncol = ncol(x))
211211
bc_vars = which(!ordinal_variable)
212212
for (i in bc_vars) {
213-
sufficient_stats[1, i] = sum(x[, i])
213+
sufficient_stats[1, i] = sum(x[, i] - baseline_category[i])
214214
sufficient_stats[2, i] = sum((x[, i] - baseline_category[i]) ^ 2)
215215
}
216216
return(sufficient_stats)
@@ -220,7 +220,7 @@ compute_blume_capel_stats = function(x, baseline_category, ordinal_variable, gro
220220
sufficient_stats_gr = matrix(0, nrow = 2, ncol = ncol(x))
221221
bc_vars = which(!ordinal_variable)
222222
for (i in bc_vars) {
223-
sufficient_stats_gr[1, i] = sum(x[group == g, i])
223+
sufficient_stats_gr[1, i] = sum(x[group == g, i] - baseline_category[i])
224224
sufficient_stats_gr[2, i] = sum((x[group == g, i] - baseline_category[i]) ^ 2)
225225
}
226226
sufficient_stats[[g]] = sufficient_stats_gr

R/sampleMRF.R

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
#' in specifying their model.
1414
#'
1515
#' The Blume-Capel option is specifically designed for ordinal variables that
16-
#' have a special type of reference_category category, such as the neutral
16+
#' have a special type of baseline_category category, such as the neutral
1717
#' category in a Likert scale. The Blume-Capel model specifies the following
1818
#' quadratic model for the threshold parameters:
1919
#' \deqn{\mu_{\text{c}} = \alpha \times \text{c} + \beta \times (\text{c} - \text{r})^2,}{{\mu_{\text{c}} = \alpha \times \text{c} + \beta \times (\text{c} - \text{r})^2,}}
@@ -23,8 +23,8 @@
2323
#' \eqn{\alpha > 0}{\alpha > 0} and decreasing threshold values if
2424
#' \eqn{\alpha <0}{\alpha <0}), if \eqn{\beta < 0}{\beta < 0}, it offers an
2525
#' increasing penalty for responding in a category further away from the
26-
#' reference_category category r, while \eqn{\beta > 0}{\beta > 0} suggests a
27-
#' preference for responding in the reference_category category.
26+
#' baseline_category category r, while \eqn{\beta > 0}{\beta > 0} suggests a
27+
#' preference for responding in the baseline_category category.
2828
#'
2929
#' @param no_states The number of states of the ordinal MRF to be generated.
3030
#'
@@ -53,8 +53,8 @@
5353
#' ``blume-capel''. Binary variables are automatically treated as ``ordinal’’.
5454
#' Defaults to \code{variable_type = "ordinal"}.
5555
#'
56-
#' @param reference_category An integer vector of length \code{no_variables} specifying the
57-
#' reference_category category that is used for the Blume-Capel model (details below).
56+
#' @param baseline_category An integer vector of length \code{no_variables} specifying the
57+
#' baseline_category category that is used for the Blume-Capel model (details below).
5858
#' Can be any integer value between \code{0} and \code{no_categories} (or
5959
#' \code{no_categories[i]}).
6060
#'
@@ -103,7 +103,7 @@
103103
#' interactions = Interactions,
104104
#' thresholds = Thresholds,
105105
#' variable_type = c("b","b","o","b","o"),
106-
#' reference_category = 2)
106+
#' baseline_category = 2)
107107
#'
108108
#' @export
109109
mrfSampler = function(no_states,
@@ -112,7 +112,7 @@ mrfSampler = function(no_states,
112112
interactions,
113113
thresholds,
114114
variable_type = "ordinal",
115-
reference_category,
115+
baseline_category,
116116
iter = 1e3) {
117117
# Check no_states, no_variables, iter --------------------------------------------
118118
if(no_states <= 0 ||
@@ -168,20 +168,20 @@ mrfSampler = function(no_states,
168168
}
169169
}
170170

171-
# Check the reference_category for Blume-Capel variables ---------------------
171+
# Check the baseline_category for Blume-Capel variables ---------------------
172172
if(any(variable_type == "blume-capel")) {
173-
if(length(reference_category) == 1) {
174-
reference_category = rep(reference_category, no_variables)
173+
if(length(baseline_category) == 1) {
174+
baseline_category = rep(baseline_category, no_variables)
175175
}
176-
if(any(reference_category < 0) || any(abs(reference_category - round(reference_category)) > .Machine$double.eps)) {
176+
if(any(baseline_category < 0) || any(abs(baseline_category - round(baseline_category)) > .Machine$double.eps)) {
177177
stop(paste0("For variables ",
178-
which(reference_category < 0),
179-
" ``reference_category'' was either negative or not integer."))
178+
which(baseline_category < 0),
179+
" ``baseline_category'' was either negative or not integer."))
180180
}
181-
if(any(reference_category - no_categories > 0)) {
181+
if(any(baseline_category - no_categories > 0)) {
182182
stop(paste0("For variables ",
183-
which(reference_category - no_categories > 0),
184-
" the ``reference_category'' category was larger than the maximum category value."))
183+
which(baseline_category - no_categories > 0),
184+
" the ``baseline_category'' category was larger than the maximum category value."))
185185
}
186186
}
187187

@@ -300,7 +300,7 @@ mrfSampler = function(no_states,
300300
interactions = interactions,
301301
thresholds = thresholds,
302302
variable_type = variable_type,
303-
reference_category = reference_category,
303+
baseline_category = baseline_category,
304304
iter = iter)
305305
}
306306

src/RcppExports.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -148,8 +148,8 @@ BEGIN_RCPP
148148
END_RCPP
149149
}
150150
// sample_bcomrf_gibbs
151-
IntegerMatrix sample_bcomrf_gibbs(int no_states, int no_variables, IntegerVector no_categories, NumericMatrix interactions, NumericMatrix thresholds, StringVector variable_type, IntegerVector reference_category, int iter);
152-
RcppExport SEXP _bgms_sample_bcomrf_gibbs(SEXP no_statesSEXP, SEXP no_variablesSEXP, SEXP no_categoriesSEXP, SEXP interactionsSEXP, SEXP thresholdsSEXP, SEXP variable_typeSEXP, SEXP reference_categorySEXP, SEXP iterSEXP) {
151+
IntegerMatrix sample_bcomrf_gibbs(int no_states, int no_variables, IntegerVector no_categories, NumericMatrix interactions, NumericMatrix thresholds, StringVector variable_type, IntegerVector baseline_category, int iter);
152+
RcppExport SEXP _bgms_sample_bcomrf_gibbs(SEXP no_statesSEXP, SEXP no_variablesSEXP, SEXP no_categoriesSEXP, SEXP interactionsSEXP, SEXP thresholdsSEXP, SEXP variable_typeSEXP, SEXP baseline_categorySEXP, SEXP iterSEXP) {
153153
BEGIN_RCPP
154154
Rcpp::RObject rcpp_result_gen;
155155
Rcpp::RNGScope rcpp_rngScope_gen;
@@ -159,9 +159,9 @@ BEGIN_RCPP
159159
Rcpp::traits::input_parameter< NumericMatrix >::type interactions(interactionsSEXP);
160160
Rcpp::traits::input_parameter< NumericMatrix >::type thresholds(thresholdsSEXP);
161161
Rcpp::traits::input_parameter< StringVector >::type variable_type(variable_typeSEXP);
162-
Rcpp::traits::input_parameter< IntegerVector >::type reference_category(reference_categorySEXP);
162+
Rcpp::traits::input_parameter< IntegerVector >::type baseline_category(baseline_categorySEXP);
163163
Rcpp::traits::input_parameter< int >::type iter(iterSEXP);
164-
rcpp_result_gen = Rcpp::wrap(sample_bcomrf_gibbs(no_states, no_variables, no_categories, interactions, thresholds, variable_type, reference_category, iter));
164+
rcpp_result_gen = Rcpp::wrap(sample_bcomrf_gibbs(no_states, no_variables, no_categories, interactions, thresholds, variable_type, baseline_category, iter));
165165
return rcpp_result_gen;
166166
END_RCPP
167167
}

src/bgmCompare_logp_and_grad.cpp

Lines changed: 32 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -160,10 +160,10 @@ double log_pseudoposterior(
160160
const double quad_effect = main_group(v, 1);
161161
const int ref = baseline_category(v);
162162
for (int c = 0; c <= num_cats; ++c) {
163-
const int centered = c - ref;
164-
const double quad = quad_effect * centered * centered;
165-
const double lin = lin_effect * c;
166-
const arma::vec exponent = lin + quad + c * rest_score - bound;
163+
const int score = c - ref;
164+
const double lin = lin_effect * score;
165+
const double quad = quad_effect * score * score;
166+
const arma::vec exponent = lin + quad + score * rest_score - bound;
167167
denom += ARMA_MY_EXP(exponent);
168168
}
169169
}
@@ -566,10 +566,10 @@ arma::vec gradient(
566566
const double lin_effect = main_group(v, 0);
567567
const double quad_effect = main_group(v, 1);
568568
for (int s = 0; s <= K; ++s) {
569-
const int centered = s - ref;
570-
const double lin = lin_effect * s;
571-
const double quad = quad_effect * centered * centered;
572-
exponents.col(s) = lin + quad + s * rest_score - bound;
569+
const int score = s - ref;
570+
const double lin = lin_effect * score;
571+
const double quad = quad_effect * score * score;
572+
exponents.col(s) = lin + quad + score * rest_score - bound;
573573
}
574574
}
575575

@@ -594,7 +594,7 @@ arma::vec gradient(
594594
}
595595
}
596596
} else {
597-
arma::vec lin_score = arma::regspace<arma::vec>(0, K); // length K+1
597+
arma::vec lin_score = arma::regspace<arma::vec>(0 - ref, K - ref); // length K+1
598598
arma::vec quad_score = arma::square(lin_score - ref);
599599

600600
double sum_lin = arma::accu(probs * lin_score);
@@ -619,8 +619,15 @@ arma::vec gradient(
619619
if (v == v2) continue;
620620

621621
arma::vec expected_value(num_group_obs, arma::fill::zeros);
622-
for (int s = 1; s <= K; ++s) {
623-
expected_value += s * probs.col(s) % obs.col(v2);
622+
if (is_ordinal_variable(v)) {
623+
for (int s = 1; s <= K; ++s) {
624+
expected_value += s * probs.col(s) % obs.col(v2);
625+
}
626+
} else {
627+
for (int s = 0; s <= K; ++s) {
628+
int score = s - ref;
629+
expected_value += score * probs.col(s) % obs.col(v2);
630+
}
624631
}
625632
double sum_expectation = arma::accu(expected_value);
626633

@@ -860,10 +867,10 @@ double log_pseudoposterior_main_component(
860867
const double quad_effect = main_group(variable, 1);
861868
const int ref = baseline_category(variable);
862869
for (int cat = 0; cat <= num_cats; cat++) {
863-
const int centered = cat - ref;
864-
const double quad = quad_effect * centered * centered;
865-
const double lin = lin_effect * cat;
866-
const arma::vec exponent = lin + quad + cat * rest_score - bound;
870+
const int score = cat - ref;
871+
const double quad = quad_effect * score * score;
872+
const double lin = lin_effect * score;
873+
const arma::vec exponent = lin + quad + score * rest_score - bound;
867874
denom += ARMA_MY_EXP(exponent);
868875
}
869876
}
@@ -1044,10 +1051,10 @@ double log_pseudoposterior_pair_component(
10441051
const double quad_effect = main_group(v, 1);
10451052
const int ref = baseline_category(v);
10461053
for (int c = 0; c <= num_cats; ++c) {
1047-
const int centered = c - ref;
1048-
const double quad = quad_effect * centered * centered;
1049-
const double lin = lin_effect * c;
1050-
const arma::vec exponent = lin + quad + c * rest_score - bound;
1054+
const int score = c - ref;
1055+
const double lin = lin_effect * score;
1056+
const double quad = quad_effect * score * score;
1057+
const arma::vec exponent = lin + quad + score * rest_score - bound;
10511058
denom += ARMA_MY_EXP(exponent);
10521059
}
10531060
}
@@ -1192,9 +1199,9 @@ double log_ratio_pseudolikelihood_constant_variable(
11921199
arma::vec const_current(num_cats + 1, arma::fill::zeros);
11931200
arma::vec const_proposed(num_cats + 1, arma::fill::zeros);
11941201
for (int s = 0; s <= num_cats; ++s) {
1195-
const int centered = s - ref;
1196-
const_current(s) = main_current(0) * s + main_current(1) * centered * centered;
1197-
const_proposed(s) = main_proposed(0) * s + main_proposed(1) * centered * centered;
1202+
const int score = s - ref;
1203+
const_current(s) = main_current(0) * score + main_current(1) * score* score;
1204+
const_proposed(s) = main_proposed(0) * score + main_proposed(1) * score * score;
11981205
}
11991206

12001207
double lbound = std::max(const_current.max(), const_proposed.max());
@@ -1204,8 +1211,9 @@ double log_ratio_pseudolikelihood_constant_variable(
12041211
bound_proposed = lbound + num_cats * arma::clamp(rest_proposed, 0.0, arma::datum::inf);
12051212

12061213
for (int s = 0; s <= num_cats; ++s) {
1207-
denom_current += ARMA_MY_EXP(const_current(s) + s * rest_current - bound_current);
1208-
denom_proposed += ARMA_MY_EXP(const_proposed(s) + s * rest_proposed - bound_proposed);
1214+
const int score = s - ref;
1215+
denom_current += ARMA_MY_EXP(const_current(s) + score * rest_current - bound_current);
1216+
denom_proposed += ARMA_MY_EXP(const_proposed(s) + score * rest_proposed - bound_proposed);
12091217
}
12101218
}
12111219

src/bgmCompare_sampler.cpp

Lines changed: 21 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ void impute_missing_bgmcompare(
8989

9090
arma::vec category_response_probabilities(max_num_categories + 1);
9191
double exponent, cumsum, u;
92-
int score, person, variable, new_observation, old_observation, group;
92+
int score, person, variable, new_value, old_value, group;
9393

9494
//Impute missing data
9595
for(int missing = 0; missing < num_missings; missing++) {
@@ -132,12 +132,12 @@ void impute_missing_bgmcompare(
132132
} else {
133133
// For Blume-Capel variables
134134
cumsum = 0.0;
135+
const int ref = baseline_category[variable];
135136
for(int category = 0; category <= num_categories(variable); category++) {
136-
exponent = group_main_effects[0] * category;
137-
exponent += group_main_effects[1] *
138-
(category - baseline_category[variable]) *
139-
(category - baseline_category[variable]);
140-
exponent += category * rest_score;
137+
score = category - ref;
138+
exponent = group_main_effects[0] * score;
139+
exponent += group_main_effects[1] * score * score;
140+
exponent += rest_score * score;
141141
cumsum += MY_EXP(exponent);
142142
category_response_probabilities[category] = cumsum;
143143
}
@@ -149,31 +149,30 @@ void impute_missing_bgmcompare(
149149
while (u > category_response_probabilities[score]) {
150150
score++;
151151
}
152-
new_observation = score;
153-
old_observation = observations(person, variable);
154152

155-
if(old_observation != new_observation) {
153+
new_value = score;
154+
if(!is_ordinal_variable[variable])
155+
new_value -= baseline_category[variable];
156+
old_value = observations(person, variable);
157+
158+
if(old_value != new_value) {
156159
// Update raw observations
157-
observations(person, variable) = new_observation;
160+
observations(person, variable) = new_value;
158161

159162
// Update sufficient statistics for main effects
160163
if(is_ordinal_variable[variable] == true) {
161164
arma::imat counts_per_category_group = counts_per_category[group];
162-
if(old_observation > 0)
163-
counts_per_category_group(old_observation-1, variable)--;
164-
if(new_observation > 0)
165-
counts_per_category_group(new_observation-1, variable)++;
165+
if(old_value > 0)
166+
counts_per_category_group(old_value-1, variable)--;
167+
if(new_value > 0)
168+
counts_per_category_group(new_value-1, variable)++;
166169
counts_per_category[group] = counts_per_category_group;
167170
} else {
168171
arma::imat blume_capel_stats_group = blume_capel_stats[group];
169-
blume_capel_stats_group(0, variable) -= old_observation;
170-
blume_capel_stats_group(0, variable) += new_observation;
171-
blume_capel_stats_group(1, variable) -=
172-
(old_observation - baseline_category[variable]) *
173-
(old_observation - baseline_category[variable]);
174-
blume_capel_stats_group(1, variable) +=
175-
(new_observation - baseline_category[variable]) *
176-
(new_observation - baseline_category[variable]);
172+
blume_capel_stats_group(0, variable) -= old_value;
173+
blume_capel_stats_group(0, variable) += new_value;
174+
blume_capel_stats_group(1, variable) -= old_value * old_value;
175+
blume_capel_stats_group(1, variable) += new_value * new_value;
177176
blume_capel_stats[group] = blume_capel_stats_group;
178177
}
179178

0 commit comments

Comments
 (0)