Skip to content

Commit 42a92b4

Browse files
Change co-occurrence convergence analysis to analysis of binary Markov chain.
Update summary.bgm and print.summary.bgm accordingly.
1 parent 2a42154 commit 42a92b4

File tree

3 files changed

+39
-22
lines changed

3 files changed

+39
-22
lines changed

R/bgms-methods.R

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,9 @@ summary.bgms <- function(object, ...) {
7070
}
7171

7272
if (!is.null(object$posterior_summary_pairwise_allocations)) {
73-
out$pairwise_allocations <- object$posterior_summary_pairwise_allocations
74-
out$allocations_mean <- object$posterior_mean_allocations
73+
out$allocations <- object$posterior_summary_pairwise_allocations
74+
out$mean_allocations <- object$posterior_mean_allocations
75+
out$mode_allocations <- object$posterior_mode_allocations
7576
out$num_blocks <- object$posterior_num_blocks
7677
}
7778

@@ -118,23 +119,24 @@ print.summary.bgms <- function(x, digits = 3, ...) {
118119
ind <- head(x$indicator, 6)
119120
ind[] <- lapply(ind, function(col) ifelse(is.na(col), "", round(col, digits)))
120121
print(ind)
121-
if (nrow(x$indicator) > 6)
122-
cat("... (use `summary(fit)$indicator` to see full output)\n")
122+
if (nrow(x$indicator) > 6) cat("... (use `summary(fit)$indicator` to see full output)\n")
123123
cat("Note: NA values are suppressed in the print table. They occur when an indicator\n")
124124
cat("was constant (all 0 or all 1) across all iterations, so sd/mcse/n_eff/Rhat\n")
125125
cat("are undefined; `summary(fit)$indicator` still contains the NA values.\n\n")
126126
}
127127

128-
if (!is.null(x$pairwise_allocations)) {
128+
if (!is.null(x$allocations)) {
129129
cat("Pairwise node co-clustering proportion:\n")
130-
print(round(head(x$pairwise_allocations, 6), digits = digits))
131-
if (nrow(x$pairwise_allocations) > 6) cat("... (use `summary(fit)$allocations` to see full output)\n")
130+
print(round(head(x$allocations, 6), digits = digits))
131+
if (nrow(x$allocations) > 6) cat("... (use `summary(fit)$allocations` to see full output)\n")
132132
cat("\n")
133133
}
134134

135-
if (!is.null(x$allocations_mean)) {
136-
cat("Mean posterior node allocation vector :\n")
137-
print(round(head(x$allocations_mean, 6), digits = digits))
135+
if (!is.null(x$mean_allocations)) {
136+
cat("Mean posterior node allocation vector:\n")
137+
print(round(head(x$mean_allocations, 6), digits = digits))
138+
cat("Mode posterior node allocation vector:\n")
139+
print(round(head(x$mode_allocations, 6), digits = digits))
138140
cat("\n")
139141
}
140142

@@ -185,7 +187,6 @@ coef.bgms <- function(object, ...) {
185187
}
186188

187189

188-
189190
.warning_issued <- FALSE
190191
warning_once <- function(msg) {
191192
if (!.warning_issued) {

R/mcmc_summary.R

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -249,8 +249,8 @@ summarize_alloc_pairs = function(allocations, node_names = NULL) {
249249
Pairs = t(combn(seq_len(no_variables), 2))
250250
nparam = nrow(Pairs)
251251

252-
result = matrix(NA, nparam, 5)
253-
colnames(result) = c("mean", "sd", "mcse", "n_eff", "Rhat")
252+
result = matrix(NA, nparam, 9)
253+
colnames(result) = c("mean", "sd", "mcse", "n0->0", "n0->1", "n1->0", "n1->1", "n_eff", "Rhat")
254254

255255
# helper to construct a "time-series"
256256
get_draws_pair = function(i, j) {
@@ -265,16 +265,32 @@ summarize_alloc_pairs = function(allocations, node_names = NULL) {
265265
for (p in seq_len(nparam)) {
266266
i = Pairs[p, 1]; j = Pairs[p, 2]
267267
draws = get_draws_pair(i, j)
268-
vec = as.vector(draws)
269-
phat = mean(vec)
270-
sdev = sd(vec)
271268

272-
est = compute_rhat_ess(draws)
273-
n_eff = as.numeric(est$ess)
274-
Rhat = as.numeric(est$rhat)
269+
vec = as.vector(draws)
270+
T = length(vec)
271+
g_next = vec[-1]
272+
g_curr = vec[-T]
273+
274+
p_hat = mean(vec)
275+
sd = sqrt(p_hat * (1 - p_hat))
276+
n00 = sum(g_curr == 0 & g_next == 0)
277+
n01 = sum(g_curr == 0 & g_next == 1)
278+
n10 = sum(g_curr == 1 & g_next == 0)
279+
n11 = sum(g_curr == 1 & g_next == 1)
280+
281+
if (n01 + n10 == 0) {
282+
n_eff = mcse = R = NA_real_
283+
} else {
284+
a = n01 / (n00 + n01)
285+
b = n10 / (n10 + n11)
286+
tau_int = (2 - (a + b)) / (a + b)
287+
n_eff = T / tau_int
288+
mcse = sd / sqrt(n_eff)
289+
est = compute_rhat_ess(draws)
290+
R = est$rhat
291+
}
275292

276-
mcse = if (is.finite(n_eff) && n_eff > 0) sdev / sqrt(n_eff) else NA
277-
result[p, ] = c(phat, sdev, mcse, n_eff, Rhat)
293+
result[p, ] = c(p_hat, sd, mcse, n00, n01, n10, n11, n_eff, R)
278294
}
279295
if (is.null(node_names)) {
280296
rn = paste0(Pairs[,1], "-", Pairs[,2])

R/output_utils.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ prepare_output_bgm = function(
136136
allocations = lapply(out, `[[`, "allocations"),
137137
node_names = data_columnnames
138138
)
139-
results$posterior_coclustering_matrix = sbm_convergence$co_occur_matrix
139+
results$posterior_mean_coclustering_matrix = sbm_convergence$co_occur_matrix
140140
# calculate the estimated clustering and block probabilities
141141
sbm_summary = posterior_summary_SBM(allocations = lapply(out, `[[`, "allocations"),
142142
arguments = arguments) # check if only arguments would work

0 commit comments

Comments
 (0)