Skip to content

Commit 72c22dd

Browse files
Fix issue with spurious error messages due to output handling after user interrupt.
1 parent f0a9288 commit 72c22dd

File tree

2 files changed

+74
-8
lines changed

2 files changed

+74
-8
lines changed

R/bgm.R

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -588,6 +588,41 @@ bgm = function(
588588
nThreads = cores, seed = seed, progress_type = progress_type
589589
)
590590

591+
592+
userInterrupt = any(vapply(out, FUN = `[[`, FUN.VALUE = logical(1L), "userInterrupt"))
593+
if (userInterrupt) {
594+
warning("Stopped sampling after user interrupt, results are likely uninterpretable.")
595+
# Try to prepare output, but catch any errors
596+
output <- tryCatch(
597+
prepare_output_bgm(
598+
out = out, x = x, num_categories = num_categories, iter = iter,
599+
data_columnnames = if (is.null(colnames(x))) paste0("Variable ", seq_len(ncol(x))) else colnames(x),
600+
is_ordinal_variable = variable_bool,
601+
warmup = warmup, pairwise_scale = pairwise_scale,
602+
main_alpha = main_alpha, main_beta = main_beta,
603+
na_action = na_action, na_impute = na_impute,
604+
edge_selection = edge_selection, edge_prior = edge_prior, inclusion_probability = inclusion_probability,
605+
beta_bernoulli_alpha = beta_bernoulli_alpha, beta_bernoulli_beta = beta_bernoulli_beta,
606+
dirichlet_alpha = dirichlet_alpha, lambda = lambda,
607+
variable_type = variable_type,
608+
update_method = update_method,
609+
target_accept = target_accept,
610+
hmc_num_leapfrogs = hmc_num_leapfrogs,
611+
nuts_max_depth = nuts_max_depth,
612+
learn_mass_matrix = learn_mass_matrix,
613+
num_chains = chains
614+
),
615+
error = function(e) {
616+
list(partial = out, error = conditionMessage(e))
617+
},
618+
warning = function(w) {
619+
# still salvage what we can
620+
list(partial = out, warning = conditionMessage(w))
621+
}
622+
)
623+
return(output)
624+
}
625+
591626
# Main output handler in the wrapper function
592627
output = prepare_output_bgm (
593628
out = out, x = x, num_categories = num_categories, iter = iter,
@@ -634,9 +669,5 @@ bgm = function(
634669
}
635670
}
636671

637-
userInterrupt = any(vapply(out, FUN = `[[`, FUN.VALUE = logical(1L), "userInterrupt"))
638-
if (userInterrupt)
639-
warning("Stopped sampling after user interrupt, results are likely uninterpretable.")
640-
641672
return(output)
642673
}

R/bgmCompare.R

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -514,6 +514,45 @@ bgmCompare = function(
514514
progress_type = progress_type
515515
)
516516

517+
userInterrupt = any(vapply(out, FUN = `[[`, FUN.VALUE = logical(1L), "userInterrupt"))
518+
if (userInterrupt) {
519+
warning("Stopped sampling after user interrupt, results are likely uninterpretable.")
520+
output <- tryCatch(
521+
prepare_output_bgmCompare(
522+
out = out,
523+
observations = observations,
524+
num_categories = num_categories,
525+
is_ordinal_variable = ordinal_variable,
526+
num_groups = num_groups,
527+
iter = iter,
528+
warmup = warmup,
529+
main_effect_indices = main_effect_indices,
530+
pairwise_effect_indices = pairwise_effect_indices,
531+
data_columnnames = if (is.null(colnames(x))) paste0("Variable ", seq_len(ncol(x))) else colnames(x),
532+
difference_selection = difference_selection,
533+
difference_prior = difference_prior,
534+
difference_selection_alpha = beta_bernoulli_alpha,
535+
difference_selection_beta = beta_bernoulli_beta,
536+
pairwise_scale = pairwise_scale,
537+
difference_scale = difference_scale,
538+
update_method = update_method,
539+
target_accept = target_accept,
540+
nuts_max_depth = nuts_max_depth,
541+
hmc_num_leapfrogs = hmc_num_leapfrogs,
542+
learn_mass_matrix = learn_mass_matrix,
543+
num_chains = chains,
544+
projection = projection
545+
),
546+
error = function(e) {
547+
list(partial = out, error = conditionMessage(e))
548+
},
549+
warning = function(w) {
550+
list(partial = out, warning = conditionMessage(w))
551+
}
552+
)
553+
return(output)
554+
}
555+
517556
# Main output handler in the wrapper function
518557
output = prepare_output_bgmCompare(
519558
out = out,
@@ -545,9 +584,5 @@ bgmCompare = function(
545584
output$nuts_diag = nuts_diag
546585
}
547586

548-
userInterrupt = any(vapply(out, FUN = `[[`, FUN.VALUE = logical(1L), "userInterrupt"))
549-
if (userInterrupt)
550-
warning("Stopped sampling after user interrupt, results are likely uninterpretable.")
551-
552587
return(output)
553588
}

0 commit comments

Comments
 (0)