Skip to content

Commit fc1dcdc

Browse files
Fixes bgmCompare output handling.
Fixes summary, print, and coef functions.
1 parent 72da0f6 commit fc1dcdc

File tree

4 files changed

+360
-204
lines changed

4 files changed

+360
-204
lines changed

R/bgmcompare-methods.r

Lines changed: 95 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -145,26 +145,54 @@ print.summary.bgmCompare = function(x, digits = 3, ...) {
145145
print(ind, row.names = FALSE)
146146
if (nrow(x$indicator) > 6)
147147
cat("... (use `summary(fit)$indicator` to see full output)\n")
148-
cat("\n")
149148
cat("Note: NA values are suppressed in the print table. They occur when an indicator\n")
150149
cat("was constant (all 0 or all 1) across all iterations, so sd/mcse/n_eff/Rhat\n")
151150
cat("are undefined; `summary(fit)$indicator` still contains the NA values.\n\n")
152151
}
153152

154-
155153
if (!is.null(x$main_diff)) {
156154
cat("Group differences (main effects):\n")
157-
print_df(x$main_diff, digits)
155+
156+
maind <- head(x$main_diff, 6)
157+
158+
# Only round numeric columns
159+
is_num <- vapply(maind, is.numeric, logical(1))
160+
maind[is_num] <- lapply(maind[is_num],
161+
function(col) ifelse(is.na(col), "", round(col, digits)))
162+
163+
print(maind, row.names = FALSE)
164+
158165
if (nrow(x$main_diff) > 6)
159166
cat("... (use `summary(fit)$main_diff` to see full output)\n")
167+
168+
if (!is.null(x$indicator)) {
169+
cat("Note: NA values are suppressed in the print table. They occur here when an\n")
170+
cat("indicator was zero across all iterations, so mcse/n_eff/Rhat are undefined;\n")
171+
cat("`summary(fit)$main_diff` still contains the NA values.\n")
172+
}
160173
cat("\n")
161174
}
162175

163176
if (!is.null(x$pairwise_diff)) {
164177
cat("Group differences (pairwise effects):\n")
165-
print_df(x$pairwise_diff, digits)
178+
179+
pairwised <- head(x$pairwise_diff, 6)
180+
181+
# Only round numeric columns
182+
is_num <- vapply(pairwised, is.numeric, logical(1))
183+
pairwised[is_num] <- lapply(pairwised[is_num],
184+
function(col) ifelse(is.na(col), "", round(col, digits)))
185+
186+
print(pairwised, row.names = FALSE)
187+
166188
if (nrow(x$pairwise_diff) > 6)
167189
cat("... (use `summary(fit)$pairwise_diff` to see full output)\n")
190+
191+
if (!is.null(x$indicator)) {
192+
cat("Note: NA values are suppressed in the print table. They occur here when an\n")
193+
cat("indicator was zero across all iterations, so mcse/n_eff/Rhat are undefined;\n")
194+
cat("`summary(fit)$pairwise_diff` still contains the NA values.\n")
195+
}
168196
cat("\n")
169197
}
170198

@@ -204,29 +232,36 @@ coef.bgmCompare <- function(object, ...) {
204232
is_ordinal <- as.logical(args$is_ordinal_variable)
205233
num_groups <- as.integer(args$num_groups)
206234
num_variables <- as.integer(args$num_variables)
207-
projection <- args$projection # matrix [num_groups x (num_groups-1)]
235+
projection <- args$projection # [num_groups x (num_groups-1)]
208236

209-
# helper: combine chains into array3d [iter, chain, param]
237+
# ---- helper: combine chains into [iter, chain, param], robust to vectors/1-col
210238
to_array3d <- function(xlist) {
211-
nchains <- length(xlist)
212-
niter <- nrow(xlist[[1]])
213-
nparam <- ncol(xlist[[1]])
214-
arr <- array(NA_real_, dim = c(niter, nchains, nparam))
215-
for (c in seq_len(nchains)) arr[, c, ] <- xlist[[c]]
239+
if (is.null(xlist)) return(NULL)
240+
stopifnot(length(xlist) >= 1)
241+
mats <- lapply(xlist, function(x) {
242+
m <- as.matrix(x)
243+
if (is.null(dim(m))) m <- matrix(m, ncol = 1L)
244+
m
245+
})
246+
niter <- nrow(mats[[1]])
247+
nparam <- ncol(mats[[1]])
248+
arr <- array(NA_real_, dim = c(niter, length(mats), nparam))
249+
for (c in seq_along(mats)) arr[, c, ] <- mats[[c]]
216250
arr
217251
}
218252

219253
# ============================================================
220254
# ---- main effects ----
221255
array3d_main <- to_array3d(object$raw_samples$main)
222-
mean_main <- apply(array3d_main, 3, mean)
256+
stopifnot(!is.null(array3d_main))
257+
mean_main <- apply(array3d_main, 3, mean)
223258

224-
num_main <- length(mean_main) / num_groups
225-
main_mat <- matrix(mean_main,
226-
nrow = num_main, ncol = num_groups,
227-
byrow = FALSE)
259+
stopifnot(length(mean_main) %% num_groups == 0L)
260+
num_main <- as.integer(length(mean_main) / num_groups)
228261

229-
# row names
262+
main_mat <- matrix(mean_main, nrow = num_main, ncol = num_groups, byrow = FALSE)
263+
264+
# row names in sampler row order
230265
rownames(main_mat) <- unlist(lapply(seq_len(num_variables), function(v) {
231266
if (is_ordinal[v]) {
232267
paste0(var_names[v], "(c", seq_len(num_categories[v]), ")")
@@ -235,15 +270,13 @@ coef.bgmCompare <- function(object, ...) {
235270
paste0(var_names[v], "(quadratic)"))
236271
}
237272
}))
273+
colnames(main_mat) <- c("baseline", paste0("diff", seq_len(num_groups - 1L)))
238274

239-
# column names: baseline + diffs
240-
colnames(main_mat) <- c("baseline", paste0("diff", seq_len(num_groups - 1)))
241-
242-
# compute group effects
275+
# group-specific main effects: baseline + P %*% diffs
243276
main_effects_groups <- matrix(NA_real_, nrow = num_main, ncol = num_groups)
244277
for (r in seq_len(num_main)) {
245278
baseline <- main_mat[r, 1]
246-
diffs <- main_mat[r, -1]
279+
diffs <- main_mat[r, -1, drop = TRUE]
247280
main_effects_groups[r, ] <- baseline + as.vector(projection %*% diffs)
248281
}
249282
rownames(main_effects_groups) <- rownames(main_mat)
@@ -252,51 +285,63 @@ coef.bgmCompare <- function(object, ...) {
252285
# ============================================================
253286
# ---- pairwise effects ----
254287
array3d_pair <- to_array3d(object$raw_samples$pairwise)
255-
mean_pair <- apply(array3d_pair, 3, mean)
288+
stopifnot(!is.null(array3d_pair))
289+
mean_pair <- apply(array3d_pair, 3, mean)
290+
291+
stopifnot(length(mean_pair) %% num_groups == 0L)
292+
num_pair <- as.integer(length(mean_pair) / num_groups)
256293

257-
num_pair <- length(mean_pair) / num_groups
258-
pairwise_mat <- matrix(mean_pair,
259-
nrow = num_pair, ncol = num_groups,
260-
byrow = FALSE)
294+
pairwise_mat <- matrix(mean_pair, nrow = num_pair, ncol = num_groups, byrow = FALSE)
261295

262-
# row names
296+
# row names in sampler row order (upper-tri i<j)
263297
pair_names <- character()
264-
for (i in 1:(num_variables - 1)) {
265-
for (j in (i + 1):num_variables) {
266-
pair_names <- c(pair_names, paste0(var_names[i], "-", var_names[j]))
298+
if (num_variables >= 2L) {
299+
for (i in 1L:(num_variables - 1L)) {
300+
for (j in (i + 1L):num_variables) {
301+
pair_names <- c(pair_names, paste0(var_names[i], "-", var_names[j]))
302+
}
267303
}
268304
}
269305
rownames(pairwise_mat) <- pair_names
270-
colnames(pairwise_mat) <- c("baseline", paste0("diff", seq_len(num_groups - 1)))
306+
colnames(pairwise_mat) <- c("baseline", paste0("diff", seq_len(num_groups - 1L)))
271307

272-
# compute group effects
308+
# group-specific pairwise effects
273309
pairwise_effects_groups <- matrix(NA_real_, nrow = num_pair, ncol = num_groups)
274310
for (r in seq_len(num_pair)) {
275311
baseline <- pairwise_mat[r, 1]
276-
diffs <- pairwise_mat[r, -1]
312+
diffs <- pairwise_mat[r, -1, drop = TRUE]
277313
pairwise_effects_groups[r, ] <- baseline + as.vector(projection %*% diffs)
278314
}
279315
rownames(pairwise_effects_groups) <- rownames(pairwise_mat)
280316
colnames(pairwise_effects_groups) <- paste0("group", seq_len(num_groups))
281317

282318
# ============================================================
283-
# ---- indicators ----
319+
# ---- indicators (present only if selection was used) ----
320+
indicators <- NULL
284321
array3d_ind <- to_array3d(object$raw_samples$indicator)
285-
mean_ind <- apply(array3d_ind, 3, mean)
286-
287-
indicators <- matrix(0, nrow = num_variables, ncol = num_variables,
288-
dimnames = list(var_names, var_names))
289-
290-
diag(indicators) <- mean_ind[seq_len(num_variables)]
291-
292-
counter <- num_variables + 1
293-
for (i in 1:(num_variables - 1)) {
294-
for (j in (i + 1):num_variables) {
295-
val <- mean_ind[counter]
296-
indicators[i, j] <- val
297-
indicators[j, i] <- val
298-
counter <- counter + 1
322+
if (!is.null(array3d_ind)) {
323+
mean_ind <- apply(array3d_ind, 3, mean)
324+
325+
# reconstruct VxV matrix using the sampler’s interleaved order:
326+
# (1,1),(1,2),...,(1,V),(2,2),...,(2,V),...,(V,V)
327+
V <- num_variables
328+
stopifnot(length(mean_ind) == V * (V + 1L) / 2L)
329+
330+
ind_mat <- matrix(0, nrow = V, ncol = V,
331+
dimnames = list(var_names, var_names))
332+
pos <- 1L
333+
for (i in seq_len(V)) {
334+
# diagonal (main indicator)
335+
ind_mat[i, i] <- mean_ind[pos]; pos <- pos + 1L
336+
if (i < V) {
337+
for (j in (i + 1L):V) {
338+
val <- mean_ind[pos]; pos <- pos + 1L
339+
ind_mat[i, j] <- val
340+
ind_mat[j, i] <- val
341+
}
342+
}
299343
}
344+
indicators <- ind_mat
300345
}
301346

302347
# ============================================================
@@ -308,4 +353,4 @@ coef.bgmCompare <- function(object, ...) {
308353
pairwise_effects_groups = pairwise_effects_groups,
309354
indicators = indicators
310355
)
311-
}
356+
}

R/bgms-methods.R

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,16 @@ print.summary.bgms <- function(x, digits = 3, ...) {
100100

101101
if (!is.null(x$pairwise)) {
102102
cat("Pairwise interactions:\n")
103-
print(round(head(x$pairwise, 6), digits = digits))
103+
pair <- head(x$pairwise, 6)
104+
pair[] <- lapply(pair, function(col) ifelse(is.na(col), "", round(col, digits)))
105+
print(pair)
106+
#print(round(head(x$pairwise, 6), digits = digits))
104107
if (nrow(x$pairwise) > 6) cat("... (use `summary(fit)$pairwise` to see full output)\n")
108+
if (!is.null(x$indicator)) {
109+
cat("Note: NA values are suppressed in the print table. They occur here when an \n")
110+
cat("indicator was zero across all iterations, so mcse/n_eff/Rhat are undefined;\n")
111+
cat("`summary(fit)$pairwise` still contains the NA values.\n")
112+
}
105113
cat("\n")
106114
}
107115

@@ -112,7 +120,6 @@ print.summary.bgms <- function(x, digits = 3, ...) {
112120
print(ind, row.names = FALSE)
113121
if (nrow(x$indicator) > 6)
114122
cat("... (use `summary(fit)$indicator` to see full output)\n")
115-
cat("\n")
116123
cat("Note: NA values are suppressed in the print table. They occur when an indicator\n")
117124
cat("was constant (all 0 or all 1) across all iterations, so sd/mcse/n_eff/Rhat\n")
118125
cat("are undefined; `summary(fit)$indicator` still contains the NA values.\n\n")

0 commit comments

Comments
 (0)