Skip to content

Commit 2354706

Browse files
Exp be gone (#67)
* print mutex for printing for error handling * Init exp be gone for bgm() regular variable. * Add numerical analysis scripts to GH * Init bc model update. * Does not use ExpBeGone trick. * Metropolis works but NUTS not yet. * Fix mrfSampler documentation * Update dev/numerical_analysis files * Fix buildignore. Clean up BC-normalization code. Clean up dev/num_analysis for bc variables * Fix the denominator computations in the pseudolikelihood for Blume-Capel variables * Remove debugging tools for HMC components * Clean up c++ code for Blume-Capel * Update divergent transition warning * Update to prob computation of Blume-Capel variables in bgm() * Exp Be Gone for bgmCompare Bug fix to bgmCompare for BC variables * Cleanup * Update news.md
1 parent e68452b commit 2354706

24 files changed

+3318
-450
lines changed

.Rbuildignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,4 @@
1212
^doc$
1313
^Meta$
1414
^\.vscode$
15+
^dev/

NEWS.md

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,11 @@
77
## Other changes
88

99
* reparameterized the Blume-capel model to use (score-baseline) instead of score.
10+
* implemented a new way to compute the denominators and probabilities. This made their computation both faster and more stable.
1011

1112
## Bug fixes
1213

13-
* Fixed numerical problems with Blume-Capel variables using HMC and NUTS for bgm().
14+
* fixed numerical problems with Blume-Capel variables using HMC and NUTS.
1415

1516
# bgms 0.1.6.1
1617

@@ -22,9 +23,9 @@
2223

2324
## Bug fixes
2425

25-
* Fixed a problem with warmup scheduling for adaptive-metropolis in bgmCompare()
26-
* Fixed stability problems with parallel sampling for bgm()
27-
* Fixed spurious output errors printing to console after user interrupt.
26+
* fixed a problem with warmup scheduling for adaptive-metropolis in bgmCompare()
27+
* fixed stability problems with parallel sampling for bgm()
28+
* fixed spurious output errors printing to console after user interrupt.
2829

2930
# bgms 0.1.6.0
3031

R/RcppExports.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@ sample_omrf_gibbs <- function(no_states, no_variables, no_categories, interactio
2525
.Call(`_bgms_sample_omrf_gibbs`, no_states, no_variables, no_categories, interactions, thresholds, iter)
2626
}
2727

28-
sample_bcomrf_gibbs <- function(no_states, no_variables, no_categories, interactions, thresholds, variable_type, reference_category, iter) {
29-
.Call(`_bgms_sample_bcomrf_gibbs`, no_states, no_variables, no_categories, interactions, thresholds, variable_type, reference_category, iter)
28+
sample_bcomrf_gibbs <- function(no_states, no_variables, no_categories, interactions, thresholds, variable_type, baseline_category, iter) {
29+
.Call(`_bgms_sample_bcomrf_gibbs`, no_states, no_variables, no_categories, interactions, thresholds, variable_type, baseline_category, iter)
3030
}
3131

3232
compute_Vn_mfm_sbm <- function(no_variables, dirichlet_alpha, t_max, lambda) {

R/bgm.R

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -560,8 +560,9 @@ bgm = function(
560560
# Ordinal (variable_bool == TRUE) or Blume-Capel (variable_bool == FALSE)
561561
bc_vars = which(!variable_bool)
562562
for(i in bc_vars) {
563-
blume_capel_stats[1, i] = sum(x[, i])
564-
blume_capel_stats[2, i] = sum((x[, i] - baseline_category[i])^2)
563+
blume_capel_stats[1, i] = sum(x[, i] - baseline_category[i])
564+
blume_capel_stats[2, i] = sum((x[, i] - baseline_category[i]) ^ 2)
565+
x[, i] = x[, i] - baseline_category[i]
565566
}
566567
}
567568
pairwise_stats = t(x) %*% x
@@ -627,7 +628,6 @@ bgm = function(
627628
nThreads = cores, seed = seed, progress_type = progress_type
628629
)
629630

630-
631631
userInterrupt = any(vapply(out, FUN = `[[`, FUN.VALUE = logical(1L), "userInterrupt"))
632632
if(userInterrupt) {
633633
warning("Stopped sampling after user interrupt, results are likely uninterpretable.")

R/bgmCompare.R

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -321,7 +321,7 @@ bgmCompare = function(
321321
} else if(update_method == "hamiltonian-mc") {
322322
target_accept = 0.65
323323
} else if(update_method == "nuts") {
324-
target_accept = 0.80
324+
target_accept = 0.65
325325
}
326326
}
327327

@@ -414,13 +414,15 @@ bgmCompare = function(
414414
blume_capel_stats = compute_blume_capel_stats(
415415
x, baseline_category, ordinal_variable, group
416416
)
417+
for (i in which(!ordinal_variable)) {
418+
x[, i] = x[, i] - baseline_category[i]
419+
}
417420

418421
# Compute sufficient statistics for pairwise interactions
419422
pairwise_stats = compute_pairwise_stats(
420423
x, group
421424
)
422425

423-
424426
# Index vector used to sample interactions in a random order -----------------
425427
Index = matrix(0, nrow = num_interactions, ncol = 3)
426428
counter = 0
@@ -490,7 +492,6 @@ bgmCompare = function(
490492

491493
seed <- as.integer(seed)
492494

493-
494495
# Call the Rcpp function
495496
out = run_bgmCompare_parallel(
496497
observations = observations,

R/data_utils.R

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ compute_counts_per_category = function(x, num_categories, group = NULL) {
243243
counts_per_category_gr[category, variable] = sum(x[group == g, variable] == category)
244244
}
245245
}
246-
counts_per_category[[g]] = counts_per_category_gr
246+
counts_per_category[[length(counts_per_category) + 1]] = counts_per_category_gr
247247
}
248248
return(counts_per_category)
249249
}
@@ -253,34 +253,34 @@ compute_blume_capel_stats = function(x, baseline_category, ordinal_variable, gro
253253
if(is.null(group)) { # One-group design
254254
sufficient_stats = matrix(0, nrow = 2, ncol = ncol(x))
255255
bc_vars = which(!ordinal_variable)
256-
for(i in bc_vars) {
257-
sufficient_stats[1, i] = sum(x[, i])
258-
sufficient_stats[2, i] = sum((x[, i] - baseline_category[i])^2)
256+
for (i in bc_vars) {
257+
sufficient_stats[1, i] = sum(x[, i] - baseline_category[i])
258+
sufficient_stats[2, i] = sum((x[, i] - baseline_category[i]) ^ 2)
259259
}
260260
return(sufficient_stats)
261261
} else { # Multi-group design
262262
sufficient_stats = list()
263263
for(g in unique(group)) {
264264
sufficient_stats_gr = matrix(0, nrow = 2, ncol = ncol(x))
265265
bc_vars = which(!ordinal_variable)
266-
for(i in bc_vars) {
267-
sufficient_stats_gr[1, i] = sum(x[group == g, i])
268-
sufficient_stats_gr[2, i] = sum((x[group == g, i] - baseline_category[i])^2)
266+
for (i in bc_vars) {
267+
sufficient_stats_gr[1, i] = sum(x[group == g, i] - baseline_category[i])
268+
sufficient_stats_gr[2, i] = sum((x[group == g, i] - baseline_category[i]) ^ 2)
269269
}
270-
sufficient_stats[[g]] = sufficient_stats_gr
270+
sufficient_stats[[length(sufficient_stats) + 1]] = sufficient_stats_gr
271271
}
272272
return(sufficient_stats)
273273
}
274274
}
275275

276276
# Helper function for computing sufficient statistics for pairwise interactions
277277
compute_pairwise_stats <- function(x, group) {
278-
result <- vector("list", length(unique(group)))
278+
result <- list()
279279

280280
for(g in unique(group)) {
281281
obs <- x[group == g, , drop = FALSE]
282282
# cross-product: gives number of co-occurrences of categories
283-
result[[g]] <- t(obs) %*% obs
283+
result[[length(result) + 1]] <- t(obs) %*% obs
284284
}
285285

286286
result

R/nuts_diagnostics.R

Lines changed: 18 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -42,18 +42,16 @@ summarize_nuts_diagnostics <- function(out, nuts_max_depth = 10, verbose = TRUE)
4242
100 * divergence_rate,
4343
total_divergences,
4444
nrow(divergent_mat) * ncol(divergent_mat)
45-
), "Consider increasing the target acceptance rate.")
46-
} else if(divergence_rate > 0) {
47-
message(
48-
sprintf(
49-
"Note: %.3f%% of transitions ended with a divergence (%d of %d).\n",
50-
100 * divergence_rate,
51-
total_divergences,
52-
nrow(divergent_mat) * ncol(divergent_mat)
53-
),
54-
"Check R-hat and effective sample size (ESS) to ensure the chains are\n",
55-
"mixing well."
56-
)
45+
), "Consider increasing the target acceptance rate or change to update_method = ``adaptive-metropolis''.")
46+
} else if (divergence_rate > 0) {
47+
message(sprintf(
48+
"Note: %.3f%% of transitions ended with a divergence (%d of %d).\n",
49+
100 * divergence_rate,
50+
total_divergences,
51+
nrow(divergent_mat) * ncol(divergent_mat)
52+
),
53+
"Check R-hat and effective sample size (ESS) to ensure the chains are\n",
54+
"mixing well.")
5755
}
5856

5957
depth_hit_rate <- max_tree_depth_hits / (nrow(treedepth_mat) * ncol(treedepth_mat))
@@ -84,16 +82,14 @@ summarize_nuts_diagnostics <- function(out, nuts_max_depth = 10, verbose = TRUE)
8482
low_ebfmi_chains <- which(ebfmi_per_chain < 0.3)
8583
min_ebfmi <- min(ebfmi_per_chain)
8684

87-
if(length(low_ebfmi_chains) > 0) {
88-
warning(
89-
sprintf(
90-
"E-BFMI below 0.3 detected in %d chain(s): %s.\n",
91-
length(low_ebfmi_chains),
92-
paste(low_ebfmi_chains, collapse = ", ")
93-
),
94-
"This suggests inefficient momentum resampling in those chains.\n",
95-
"Sampling efficiency may be reduced. Consider longer chains or checking convergence diagnostics."
96-
)
85+
if (length(low_ebfmi_chains) > 0) {
86+
warning(sprintf(
87+
"E-BFMI below 0.3 detected in %d chain(s): %s.\n",
88+
length(low_ebfmi_chains),
89+
paste(low_ebfmi_chains, collapse = ", ")
90+
),
91+
"This suggests inefficient momentum resampling in those chains.\n",
92+
"Sampling efficiency may be reduced. Consider longer chains and check convergence diagnostics.")
9793
}
9894
}
9995

R/sampleMRF.R

Lines changed: 19 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
#' in specifying their model.
1414
#'
1515
#' The Blume-Capel option is specifically designed for ordinal variables that
16-
#' have a special type of reference_category category, such as the neutral
16+
#' have a special type of baseline_category category, such as the neutral
1717
#' category in a Likert scale. The Blume-Capel model specifies the following
1818
#' quadratic model for the threshold parameters:
1919
#' \deqn{\mu_{\text{c}} = \alpha \times \text{c} + \beta \times (\text{c} - \text{r})^2,}{{\mu_{\text{c}} = \alpha \times \text{c} + \beta \times (\text{c} - \text{r})^2,}}
@@ -23,8 +23,8 @@
2323
#' \eqn{\alpha > 0}{\alpha > 0} and decreasing threshold values if
2424
#' \eqn{\alpha <0}{\alpha <0}), if \eqn{\beta < 0}{\beta < 0}, it offers an
2525
#' increasing penalty for responding in a category further away from the
26-
#' reference_category category r, while \eqn{\beta > 0}{\beta > 0} suggests a
27-
#' preference for responding in the reference_category category.
26+
#' baseline_category category r, while \eqn{\beta > 0}{\beta > 0} suggests a
27+
#' preference for responding in the baseline_category category.
2828
#'
2929
#' @param no_states The number of states of the ordinal MRF to be generated.
3030
#'
@@ -53,8 +53,8 @@
5353
#' ``blume-capel''. Binary variables are automatically treated as ``ordinal’’.
5454
#' Defaults to \code{variable_type = "ordinal"}.
5555
#'
56-
#' @param reference_category An integer vector of length \code{no_variables} specifying the
57-
#' reference_category category that is used for the Blume-Capel model (details below).
56+
#' @param baseline_category An integer vector of length \code{no_variables} specifying the
57+
#' baseline_category category that is used for the Blume-Capel model (details below).
5858
#' Can be any integer value between \code{0} and \code{no_categories} (or
5959
#' \code{no_categories[i]}).
6060
#'
@@ -106,7 +106,7 @@
106106
#' interactions = Interactions,
107107
#' thresholds = Thresholds,
108108
#' variable_type = c("b", "b", "o", "b", "o"),
109-
#' reference_category = 2
109+
#' baseline_category = 2
110110
#' )
111111
#'
112112
#' @export
@@ -116,7 +116,7 @@ mrfSampler = function(no_states,
116116
interactions,
117117
thresholds,
118118
variable_type = "ordinal",
119-
reference_category,
119+
baseline_category,
120120
iter = 1e3) {
121121
# Check no_states, no_variables, iter --------------------------------------------
122122
if(no_states <= 0 ||
@@ -187,24 +187,20 @@ mrfSampler = function(no_states,
187187
}
188188
}
189189

190-
# Check the reference_category for Blume-Capel variables ---------------------
190+
# Check the baseline_category for Blume-Capel variables ---------------------
191191
if(any(variable_type == "blume-capel")) {
192-
if(length(reference_category) == 1) {
193-
reference_category = rep(reference_category, no_variables)
192+
if(length(baseline_category) == 1) {
193+
baseline_category = rep(baseline_category, no_variables)
194194
}
195-
if(any(reference_category < 0) || any(abs(reference_category - round(reference_category)) > .Machine$double.eps)) {
196-
stop(paste0(
197-
"For variables ",
198-
which(reference_category < 0),
199-
" ``reference_category'' was either negative or not integer."
200-
))
195+
if(any(baseline_category < 0) || any(abs(baseline_category - round(baseline_category)) > .Machine$double.eps)) {
196+
stop(paste0("For variables ",
197+
which(baseline_category < 0),
198+
" ``baseline_category'' was either negative or not integer."))
201199
}
202-
if(any(reference_category - no_categories > 0)) {
203-
stop(paste0(
204-
"For variables ",
205-
which(reference_category - no_categories > 0),
206-
" the ``reference_category'' category was larger than the maximum category value."
207-
))
200+
if(any(baseline_category - no_categories > 0)) {
201+
stop(paste0("For variables ",
202+
which(baseline_category - no_categories > 0),
203+
" the ``baseline_category'' category was larger than the maximum category value."))
208204
}
209205
}
210206

@@ -347,7 +343,7 @@ mrfSampler = function(no_states,
347343
interactions = interactions,
348344
thresholds = thresholds,
349345
variable_type = variable_type,
350-
reference_category = reference_category,
346+
baseline_category = baseline_category,
351347
iter = iter
352348
)
353349
}

0 commit comments

Comments
 (0)