Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
258 changes: 258 additions & 0 deletions R/convergence_diagnostics.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,258 @@
#' MCMC Convergence Diagnostics for Stan NMA Models
#'
#' These functions provide convenient access to MCMC convergence diagnostics
#' for `stan_nma` objects, leveraging the `bayesplot` package for visualization
#' and Stan's built-in diagnostic capabilities.
#'
#' @name mcmc_diagnostics
#' @param x A `stan_nma` object created by [nma()]
#' @param pars Character vector of parameter names to include in diagnostics.
#' If `NULL` (default), automatically selects a reasonable subset based on
#' the model type. See Details for automatic parameter selection.
#' @param ... Additional arguments passed to the underlying `bayesplot` functions
#'
#' @details
#' **Automatic Parameter Selection:**
#'
#' When `pars = NULL`, the functions automatically select parameters based on
#' the model characteristics:
#' - Always included: treatment effects (`d`), intercept terms
#' - If random effects: heterogeneity parameters (`tau`)
#' - If regression: regression coefficients (`beta`)
#' - If auxiliary parameters: auxiliary effects (`aux`)
#' - The automatic selection is capped at 12 parameters for readability
#'
#' **Diagnostic Functions:**
#' - `mcmc_trace()`: Trace plots for visual assessment of chain mixing
#' - `mcmc_acf()`: Autocorrelation function plots to assess chain efficiency
#' - `mcmc_rhat()`: R-hat convergence diagnostics (should be < 1.1)
#' - `mcmc_neff()`: Effective sample size diagnostics
#' - `mcmc_diagnostics_plot()`: Combined diagnostic plot with multiple panels
#'
#' @return
#' - `mcmc_trace()`, `mcmc_acf()`: ggplot objects from bayesplot
#' - `mcmc_rhat()`, `mcmc_neff()`: ggplot objects showing diagnostic values
#' - `mcmc_diagnostics_plot()`: Combined ggplot with multiple diagnostic panels
#'
#' @examples
#' \dontrun{
#' # Fit a model
#' fit <- nma(network, trt_effects = "random", ...)
#'
#' # Generate trace plots
#' mcmc_trace(fit)
#'
#' # Check R-hat diagnostics
#' mcmc_rhat(fit)
#'
#' # Comprehensive diagnostic plot
#' mcmc_diagnostics_plot(fit, type = c("trace", "rhat"))
#' }
#'
NULL

#' @rdname mcmc_diagnostics
#' @export
mcmc_trace <- function(x, pars = NULL, ...) {
  # Reject anything that is not a stan_nma fit before touching x$stanfit.
  if (!inherits(x, "stan_nma")) {
    abort("x must be a stan_nma object")
  }

  # With no explicit request, fall back to the model-based default selection;
  # either way the names are checked against the fit before use.
  requested <- if (is.null(pars)) get_default_parameters(x) else pars
  checked <- validate_parameters(x, requested)

  # Unpermuted draws are requested from rstan and handed to bayesplot as-is.
  draws <- rstan::extract(x$stanfit, pars = checked, permuted = FALSE)

  bayesplot::mcmc_trace(draws, pars = checked, ...)
}

#' @rdname mcmc_diagnostics
#' @export
mcmc_acf <- function(x, pars = NULL, ...) {
  # Guard: this wrapper only understands stan_nma fits.
  if (!inherits(x, "stan_nma")) {
    abort("x must be a stan_nma object")
  }

  # Resolve the parameter selection (defaults when none was given), then
  # check it against the names available in the fit.
  selection <- pars
  if (is.null(selection)) {
    selection <- get_default_parameters(x)
  }
  selection <- validate_parameters(x, selection)

  # Autocorrelation plot of the unpermuted posterior draws; extra arguments
  # in `...` go straight to bayesplot.
  bayesplot::mcmc_acf(
    rstan::extract(x$stanfit, pars = selection, permuted = FALSE),
    pars = selection,
    ...
  )
}

#' @rdname mcmc_diagnostics
#' @export
mcmc_rhat <- function(x, pars = NULL, ...) {
  # Reject anything that is not a stan_nma fit before touching x$stanfit.
  if (!inherits(x, "stan_nma")) {
    abort("x must be a stan_nma object")
  }

  # Fall back to the model-based default selection when none is given.
  if (is.null(pars)) {
    pars <- get_default_parameters(x)
  }

  # Validate parameters exist
  pars <- validate_parameters(x, pars)

  # Posterior summary restricted to the selected parameters.
  fit_summary <- rstan::summary(x$stanfit, pars = pars)$summary

  # Attach the row names explicitly: plain `[, "Rhat"]` on a one-row summary
  # matrix returns an unnamed scalar, which would leave a single-parameter
  # plot without its parameter label.
  rhat_values <- as.vector(fit_summary[, "Rhat"])
  names(rhat_values) <- rownames(fit_summary)

  # Generate R-hat plot; `...` is forwarded to bayesplot.
  bayesplot::mcmc_rhat(rhat_values, ...)
}

#' @rdname mcmc_diagnostics
#' @export
mcmc_neff <- function(x, pars = NULL, ...) {
  # Reject anything that is not a stan_nma fit before touching x$stanfit.
  if (!inherits(x, "stan_nma")) {
    abort("x must be a stan_nma object")
  }

  # Fall back to the model-based default selection when none is given.
  if (is.null(pars)) {
    pars <- get_default_parameters(x)
  }

  # Validate parameters exist
  pars <- validate_parameters(x, pars)

  # Posterior summary restricted to the selected parameters.
  fit_summary <- rstan::summary(x$stanfit, pars = pars)$summary

  # Attach the row names explicitly: plain `[, "n_eff"]` on a one-row summary
  # matrix returns an unnamed scalar and the parameter label would be lost.
  neff_values <- as.vector(fit_summary[, "n_eff"])
  names(neff_values) <- rownames(fit_summary)

  # bayesplot::mcmc_neff() plots the *ratio* of effective sample size to the
  # total number of post-warmup draws, not the raw n_eff count. The total is
  # iterations x chains of the unpermuted draws array (warmup is excluded by
  # rstan::extract's default).
  draws <- rstan::extract(x$stanfit, pars = pars[1], permuted = FALSE)
  total_draws <- prod(dim(draws)[1:2])
  neff_ratios <- neff_values / total_draws

  # Generate n_eff ratio plot; `...` is forwarded to bayesplot.
  bayesplot::mcmc_neff(neff_ratios, ...)
}

#' @rdname mcmc_diagnostics
#' @param type Character vector specifying which diagnostic plots to include.
#' Options: "trace", "acf", "rhat", "neff". Default includes all.
#' @export
mcmc_diagnostics_plot <- function(x, pars = NULL,
                                  type = c("trace", "acf", "rhat", "neff"), ...) {
  # Guard: this wrapper only understands stan_nma fits.
  if (!inherits(x, "stan_nma")) {
    abort("x must be a stan_nma object")
  }

  # Any non-empty subset of the four diagnostic kinds is accepted.
  type <- match.arg(type, several.ok = TRUE)

  # Resolve and validate the parameter selection once up front.
  selection <- if (is.null(pars)) get_default_parameters(x) else pars
  selection <- validate_parameters(x, selection)

  # One plotting function per diagnostic kind. Panels are always produced in
  # this fixed order, whatever order `type` was given in.
  panel_makers <- list(
    trace = mcmc_trace,
    acf = mcmc_acf,
    rhat = mcmc_rhat,
    neff = mcmc_neff
  )

  panels <- list()
  for (kind in names(panel_makers)) {
    if (kind %in% type) {
      # `...` is forwarded unchanged to every requested panel.
      panels[[kind]] <- panel_makers[[kind]](x, pars = selection, ...)
    }
  }

  # A single panel is returned as-is; several are laid out in two columns.
  if (length(panels) == 1) {
    return(panels[[1]])
  }
  return(patchwork::wrap_plots(panels, ncol = 2))
}

# Internal helper functions

#' Get default parameters for diagnostics based on model type
#' @param x stan_nma object
#' @return character vector of parameter names
#' @keywords internal
get_default_parameters <- function(x) {
  # Work with element-level ("flat") names such as "mu[1]" throughout, since
  # that is what validate_parameters() checks against. Returning model-level
  # names like "mu" or "beta" caused vector-valued defaults to be discarded
  # with a "not found" warning. names() on the fit also avoids computing a
  # full posterior summary just to list the parameters.
  flat_names <- as.character(names(x$stanfit))

  # Name of the parameter each element belongs to ("mu[1]" -> "mu").
  base_names <- sub("\\[.*$", "", flat_names)

  # First `n` elements whose parent parameter name matches `pattern`.
  pick <- function(pattern, n) {
    head(flat_names[grepl(pattern, base_names)], n)
  }

  # Always include treatment effects (first 6) and intercept/baseline
  # parameters (first 3) when present.
  default_pars <- c(
    pick("^d$", 6),
    pick("^(alpha|mu|baseline)", 3)
  )

  # Heterogeneity parameters for random effects models. identical() is used
  # so a missing `trt_effects` field is treated as "not random" instead of
  # raising an error inside if().
  if (identical(x$trt_effects, "random")) {
    default_pars <- c(default_pars, pick("^(tau|sigma)", 2))
  }

  # Regression coefficients if this is a regression model.
  if (!is.null(x$regression)) {
    default_pars <- c(default_pars, pick("^beta", 3))
  }

  # Auxiliary parameters if present.
  default_pars <- c(default_pars, pick("^(aux|delta)", 2))

  # Remove duplicates and cap the selection at 12 parameters.
  head(unique(default_pars), 12)
}

#' Validate that requested parameters exist in the model
#' @param x stan_nma object
#' @param pars character vector of parameter names
#' @return validated character vector of parameter names
#' @keywords internal
validate_parameters <- function(x, pars) {
  if (length(pars) == 0) {
    abort("No parameters specified or available for diagnostics")
  }

  # Element-level parameter names stored in the fit. Reading them with
  # names() avoids computing a full posterior summary (quantiles, n_eff and
  # Rhat for every parameter) on each call just to obtain its row names.
  all_available <- names(x$stanfit)

  # Requested names that the fit does not contain.
  missing_pars <- setdiff(pars, all_available)

  if (length(missing_pars) > 0) {
    # Unknown names are reported and dropped rather than treated as fatal.
    warn(paste("The following parameters were not found in the model and will be ignored:",
               paste(missing_pars, collapse = ", ")))
    pars <- intersect(pars, all_available)
  }

  # Nothing left to plot once the unknown names are removed.
  if (length(pars) == 0) {
    abort("None of the specified parameters were found in the model")
  }

  pars
}
110 changes: 110 additions & 0 deletions R/integration.R
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,9 @@ add_integration.data.frame <- function(x, ...,

if (ncol(cor) != nx)
abort("Dimensions of correlation matrix `cor` and number of covariates specified in `...` do not match.")

# Detect multicollinearity in user-provided correlation matrix
detect_multicollinearity(cor, x_names)
} else {
abort("Specify a correlation matrix using the `cor` argument.")
}
Expand Down Expand Up @@ -208,6 +211,9 @@ add_integration.data.frame <- function(x, ...,
warn("Adjusted correlation matrix not positive definite; using Matrix::nearPD().")
copula_cor <- as.matrix(Matrix::nearPD(copula_cor, corr = TRUE)$mat)
}

# Detect multicollinearity in adjusted correlation matrix
detect_multicollinearity(copula_cor, x_names)
}

cop <- copula::normalCopula(copula::P2p(copula_cor), dim = nx, dispstr = "un")
Expand Down Expand Up @@ -369,6 +375,9 @@ add_integration.nma_data <- function(x, ...,

diag(ipd_cor) <- 1

# Detect multicollinearity in IPD-derived correlation matrix
detect_multicollinearity(ipd_cor, x_names)

cor <- ipd_cor
}

Expand Down Expand Up @@ -781,3 +790,104 @@ cor_adjust_pearson <- function(X, types) {

return(X)
}

# Internal function to detect multicollinearity in correlation matrices
detect_multicollinearity <- function(cor_matrix, var_names,
                                     high_cor_threshold = 0.95,
                                     cond_num_threshold = 30,
                                     det_threshold = 1e-8) {
  # Screens a correlation matrix for signs of multicollinearity and raises a
  # single combined warning when any are found.
  #
  # Arguments:
  #   cor_matrix          numeric symmetric matrix of correlations
  #   var_names           one label per column of `cor_matrix`, used in messages
  #   high_cor_threshold  |r| at or above which a pair is reported
  #   cond_num_threshold  condition number above which the matrix is reported
  #   det_threshold       determinant below which the matrix is reported
  #
  # Returns (invisibly) a list holding whichever of `high_correlations`,
  # `condition_number`, `determinant` and `eigenvalues` were flagged (empty
  # when nothing was found), or NULL when there is at most one variable.
  # Aborts on invalid input.

  # Input validation
  if (!is.matrix(cor_matrix) || !is.numeric(cor_matrix)) {
    abort("cor_matrix must be a numeric matrix")
  }

  # isSymmetric() also requires row and column names to be identical, so a
  # numerically symmetric matrix carrying only column names would be rejected.
  # Only the values matter here, hence unname().
  if (!isSymmetric(unname(cor_matrix))) {
    abort("cor_matrix must be symmetric")
  }

  if (missing(var_names) || length(var_names) != ncol(cor_matrix)) {
    abort("var_names must be provided and match the number of columns in cor_matrix")
  }

  n_vars <- ncol(cor_matrix)

  # Skip detection for single variable case
  if (n_vars <= 1) {
    return(invisible(NULL))
  }

  # eigen() and det() cannot handle missing or infinite entries; fail with a
  # clear message instead of an error raised from inside eigen().
  if (!all(is.finite(cor_matrix))) {
    abort("cor_matrix must not contain missing or non-finite values")
  }

  issues_found <- list()

  # 1. High pairwise correlations. Only the strict upper triangle is scanned
  #    so each pair is reported once and the diagonal is ignored.
  high_idx <- which(
    abs(cor_matrix) >= high_cor_threshold & upper.tri(cor_matrix, diag = FALSE),
    arr.ind = TRUE
  )
  if (nrow(high_idx) > 0) {
    issues_found$high_correlations <- paste0(
      var_names[high_idx[, 1]], " & ", var_names[high_idx[, 2]],
      " (r = ", round(cor_matrix[high_idx], 3), ")"
    )
  }

  # 2. Condition number: ratio of largest to smallest eigenvalue magnitude.
  #    A numerically zero smallest eigenvalue means the matrix is singular,
  #    so the condition number is reported as Inf. Leaving such eigenvalues
  #    out of the ratio would understate the problem for the worst case.
  eigenvals <- eigen(cor_matrix, symmetric = TRUE, only.values = TRUE)$values
  eig_magnitudes <- abs(eigenvals)
  smallest_eig <- min(eig_magnitudes)
  condition_number <- if (smallest_eig > 1e-12) {
    max(eig_magnitudes) / smallest_eig
  } else {
    Inf
  }

  if (condition_number > cond_num_threshold) {
    issues_found$condition_number <- condition_number
  }

  # 3. Determinant (near-singularity)
  det_val <- det(cor_matrix)
  if (det_val < det_threshold) {
    issues_found$determinant <- det_val
  }

  # 4. Near-zero (or negative) eigenvalues
  near_zero_eigs <- eigenvals[eigenvals < 1e-8]
  if (length(near_zero_eigs) > 0) {
    issues_found$eigenvalues <- near_zero_eigs
  }

  # Issue one combined warning if any problem was detected
  if (length(issues_found) > 0) {
    warning_msgs <- character()

    if (!is.null(issues_found$high_correlations)) {
      # paste0() keeps the threshold flush against the brackets and gives
      # every listed pair the same indentation.
      warning_msgs <- c(
        warning_msgs,
        paste0(
          "High pairwise correlations detected (|r| >= ", high_cor_threshold, "):\n  ",
          paste(issues_found$high_correlations, collapse = "\n  ")
        )
      )
    }

    if (!is.null(issues_found$condition_number)) {
      warning_msgs <- c(
        warning_msgs,
        paste0(
          "High condition number detected: ", round(issues_found$condition_number, 2),
          " (threshold: ", cond_num_threshold, ")"
        )
      )
    }

    if (!is.null(issues_found$determinant)) {
      warning_msgs <- c(
        warning_msgs,
        paste0(
          "Near-singular matrix detected, determinant = ",
          format(issues_found$determinant, scientific = TRUE),
          " (threshold: ", format(det_threshold, scientific = TRUE), ")"
        )
      )
    }

    if (!is.null(issues_found$eigenvalues)) {
      warning_msgs <- c(
        warning_msgs,
        paste("Matrix has", length(issues_found$eigenvalues),
              "near-zero eigenvalue(s), smallest =",
              format(min(issues_found$eigenvalues), scientific = TRUE))
      )
    }

    # Add remediation advice
    warning_msgs <- c(
      warning_msgs, "",
      "Multicollinearity may lead to numerical instability in ML-NMR models.",
      "Consider: (1) removing highly correlated variables, (2) using PCA or",
      "factor analysis, (3) regularization techniques, or (4) domain knowledge",
      "to select the most clinically relevant variables."
    )

    warn(paste(warning_msgs, collapse = "\n"))
  }

  invisible(issues_found)
}
Loading