diff --git a/.Rbuildignore b/.Rbuildignore index 3b443361..692b5609 100644 --- a/.Rbuildignore +++ b/.Rbuildignore @@ -1,14 +1,35 @@ -^renv$ -^renv\.lock$ +# RStudio / IDE ^.*\.Rproj$ ^\.Rproj\.user$ +^\.vscode$ -^Readme.Rmd$ -^\.github$ -^_pkgdown\.yml$ +# renv +^renv$ +^renv\.lock$ + +# pkgdown / docs ^docs$ ^pkgdown$ -^vignettes/introduction_cache +^_pkgdown\.yml$ +^Readme\.Rmd$ +^vignettes/introduction_cache$ + +# GitHub / CI +^\.github$ + +# R CMD build artifacts ^doc$ ^Meta$ -^\.vscode$ + +# Development helpers +^dev$ + +# ---- C/C++ build artifacts (REQUIRED) ---- +^src/.*\.o$ +^src/.*\.so$ +^src/.*\.dll$ + +# ---- Generated build files ---- +^src/Makevars$ +^src/Makevars\.win$ +^src/sources\.mk$ diff --git a/.gitignore b/.gitignore index 65196573..2ca937ce 100644 --- a/.gitignore +++ b/.gitignore @@ -5,9 +5,13 @@ src/*.o src/*.so src/*.dll +src/**/*.o +src/**/*.so +src/**/*.dll .DS_Store /doc/ /Meta/ src/Makevars src/Makevars.win +src/sources.mk docs/* diff --git a/DESCRIPTION b/DESCRIPTION index ac544206..5f8cb2a0 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -2,7 +2,7 @@ Package: bgms Type: Package Title: Bayesian Analysis of Networks of Binary and/or Ordinal Variables Version: 0.1.6.2 -Date: 2025-12-02 +Date: 2025-12-29 Authors@R: c( person("Maarten", "Marsman", , "m.marsman@uva.nl", role = c("aut", "cre"), comment = c(ORCID = "0000-0001-5309-7502")), diff --git a/NEWS.md b/NEWS.md index 3ffa9fa5..69bfde65 100644 --- a/NEWS.md +++ b/NEWS.md @@ -6,11 +6,14 @@ ## Other changes -* reparameterized the Blume-capel model to use (score-baseline) instead of score for mrfSampler() and bgm(). +* reparameterized the Blume-capel model to use (score-baseline) instead of score. +* implemented a new way to compute the denominators and probabilities. This made their computation both faster and more stable. +* refactored c++ code for better maintainability. +* removed the prepared_data field from bgm objects. ## Bug fixes -* Fixed numerical problems with Blume-Capel variables using HMC and NUTS for bgm(). +* fixed numerical problems with Blume-Capel variables using HMC and NUTS. # bgms 0.1.6.1 @@ -22,9 +25,9 @@ ## Bug fixes -* Fixed a problem with warmup scheduling for adaptive-metropolis in bgmCompare() -* Fixed stability problems with parallel sampling for bgm() -* Fixed spurious output errors printing to console after user interrupt. +* fixed a problem with warmup scheduling for adaptive-metropolis in bgmCompare() +* fixed stability problems with parallel sampling for bgm() +* fixed spurious output errors printing to console after user interrupt. # bgms 0.1.6.0 diff --git a/R/RcppExports.R b/R/RcppExports.R index 6e04e6f1..e312fdbe 100644 --- a/R/RcppExports.R +++ b/R/RcppExports.R @@ -9,24 +9,16 @@ run_bgm_parallel <- function(observations, num_categories, pairwise_scale, edge_ .Call(`_bgms_run_bgm_parallel`, observations, num_categories, pairwise_scale, edge_prior, inclusion_probability, beta_bernoulli_alpha, beta_bernoulli_beta, beta_bernoulli_alpha_between, beta_bernoulli_beta_between, dirichlet_alpha, lambda, interaction_index_matrix, iter, warmup, counts_per_category, blume_capel_stats, main_alpha, main_beta, na_impute, missing_index, is_ordinal_variable, baseline_category, edge_selection, update_method, pairwise_effect_indices, target_accept, pairwise_stats, hmc_num_leapfrogs, nuts_max_depth, learn_mass_matrix, num_chains, nThreads, seed, progress_type) } -get_explog_switch <- function() { - .Call(`_bgms_get_explog_switch`) -} - -rcpp_ieee754_exp <- function(x) { - .Call(`_bgms_rcpp_ieee754_exp`, x) -} - -rcpp_ieee754_log <- function(x) { - .Call(`_bgms_rcpp_ieee754_log`, x) -} - sample_omrf_gibbs <- function(no_states, no_variables, no_categories, interactions, thresholds, iter) { .Call(`_bgms_sample_omrf_gibbs`, no_states, no_variables, no_categories, interactions, thresholds, iter) } -sample_bcomrf_gibbs <- function(no_states, no_variables, no_categories, interactions, thresholds, variable_type, reference_category, iter) { - .Call(`_bgms_sample_bcomrf_gibbs`, no_states, no_variables, no_categories, interactions, thresholds, variable_type, reference_category, iter) +sample_bcomrf_gibbs <- function(no_states, no_variables, no_categories, interactions, thresholds, variable_type, baseline_category, iter) { + .Call(`_bgms_sample_bcomrf_gibbs`, no_states, no_variables, no_categories, interactions, thresholds, variable_type, baseline_category, iter) +} + +sample_ggm <- function(inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, seed, no_threads, progress_type) { + .Call(`_bgms_sample_ggm`, inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, seed, no_threads, progress_type) } compute_Vn_mfm_sbm <- function(no_variables, dirichlet_alpha, t_max, lambda) { diff --git a/R/bgm.R b/R/bgm.R index cd421546..ad4c3a01 100644 --- a/R/bgm.R +++ b/R/bgm.R @@ -33,7 +33,7 @@ #' Assume a baseline category (e.g., a “neutral” response) and score responses #' by distance from this baseline. Category thresholds are modeled as: #' -#' \deqn{\mu_{c} = \alpha \cdot c + \beta \cdot (c - b)^2} +#' \deqn{\mu_{c} = \alpha \cdot (c-b) + \beta \cdot (c - b)^2} #' #' where: #' \itemize{ @@ -48,7 +48,8 @@ #' } #' \item \eqn{b}: baseline category #' } -#' +#' Accordingly, pairwise interactions between Blume-Capel variables are modeled +#' in terms of \eqn{c-b} scores. #' #' @section Edge Selection: #' When \code{edge_selection = TRUE}, the function performs Bayesian variable @@ -559,8 +560,9 @@ bgm = function( # Ordinal (variable_bool == TRUE) or Blume-Capel (variable_bool == FALSE) bc_vars = which(!variable_bool) for(i in bc_vars) { - blume_capel_stats[1, i] = sum(x[, i]) - blume_capel_stats[2, i] = sum((x[, i] - baseline_category[i])^2) + blume_capel_stats[1, i] = sum(x[, i] - baseline_category[i]) + blume_capel_stats[2, i] = sum((x[, i] - baseline_category[i]) ^ 2) + x[, i] = x[, i] - baseline_category[i] } } pairwise_stats = t(x) %*% x @@ -626,7 +628,6 @@ bgm = function( nThreads = cores, seed = seed, progress_type = progress_type ) - userInterrupt = any(vapply(out, FUN = `[[`, FUN.VALUE = logical(1L), "userInterrupt")) if(userInterrupt) { warning("Stopped sampling after user interrupt, results are likely uninterpretable.") diff --git a/R/bgmCompare.R b/R/bgmCompare.R index 6478ff91..ecd39680 100644 --- a/R/bgmCompare.R +++ b/R/bgmCompare.R @@ -321,7 +321,7 @@ bgmCompare = function( } else if(update_method == "hamiltonian-mc") { target_accept = 0.65 } else if(update_method == "nuts") { - target_accept = 0.80 + target_accept = 0.65 } } @@ -414,13 +414,15 @@ bgmCompare = function( blume_capel_stats = compute_blume_capel_stats( x, baseline_category, ordinal_variable, group ) + for (i in which(!ordinal_variable)) { + x[, i] = x[, i] - baseline_category[i] + } # Compute sufficient statistics for pairwise interactions pairwise_stats = compute_pairwise_stats( x, group ) - # Index vector used to sample interactions in a random order ----------------- Index = matrix(0, nrow = num_interactions, ncol = 3) counter = 0 @@ -490,7 +492,6 @@ bgmCompare = function( seed <- as.integer(seed) - # Call the Rcpp function out = run_bgmCompare_parallel( observations = observations, diff --git a/R/data_utils.R b/R/data_utils.R index 6559360e..6756704a 100644 --- a/R/data_utils.R +++ b/R/data_utils.R @@ -243,7 +243,7 @@ compute_counts_per_category = function(x, num_categories, group = NULL) { counts_per_category_gr[category, variable] = sum(x[group == g, variable] == category) } } - counts_per_category[[g]] = counts_per_category_gr + counts_per_category[[length(counts_per_category) + 1]] = counts_per_category_gr } return(counts_per_category) } @@ -253,9 +253,9 @@ compute_blume_capel_stats = function(x, baseline_category, ordinal_variable, gro if(is.null(group)) { # One-group design sufficient_stats = matrix(0, nrow = 2, ncol = ncol(x)) bc_vars = which(!ordinal_variable) - for(i in bc_vars) { - sufficient_stats[1, i] = sum(x[, i]) - sufficient_stats[2, i] = sum((x[, i] - baseline_category[i])^2) + for (i in bc_vars) { + sufficient_stats[1, i] = sum(x[, i] - baseline_category[i]) + sufficient_stats[2, i] = sum((x[, i] - baseline_category[i]) ^ 2) } return(sufficient_stats) } else { # Multi-group design @@ -263,11 +263,11 @@ compute_blume_capel_stats = function(x, baseline_category, ordinal_variable, gro for(g in unique(group)) { sufficient_stats_gr = matrix(0, nrow = 2, ncol = ncol(x)) bc_vars = which(!ordinal_variable) - for(i in bc_vars) { - sufficient_stats_gr[1, i] = sum(x[group == g, i]) - sufficient_stats_gr[2, i] = sum((x[group == g, i] - baseline_category[i])^2) + for (i in bc_vars) { + sufficient_stats_gr[1, i] = sum(x[group == g, i] - baseline_category[i]) + sufficient_stats_gr[2, i] = sum((x[group == g, i] - baseline_category[i]) ^ 2) } - sufficient_stats[[g]] = sufficient_stats_gr + sufficient_stats[[length(sufficient_stats) + 1]] = sufficient_stats_gr } return(sufficient_stats) } @@ -275,12 +275,12 @@ compute_blume_capel_stats = function(x, baseline_category, ordinal_variable, gro # Helper function for computing sufficient statistics for pairwise interactions compute_pairwise_stats <- function(x, group) { - result <- vector("list", length(unique(group))) + result <- list() for(g in unique(group)) { obs <- x[group == g, , drop = FALSE] # cross-product: gives number of co-occurrences of categories - result[[g]] <- t(obs) %*% obs + result[[length(result) + 1]] <- t(obs) %*% obs } result diff --git a/R/nuts_diagnostics.R b/R/nuts_diagnostics.R index dfb9e61a..ba477460 100644 --- a/R/nuts_diagnostics.R +++ b/R/nuts_diagnostics.R @@ -42,18 +42,16 @@ summarize_nuts_diagnostics <- function(out, nuts_max_depth = 10, verbose = TRUE) 100 * divergence_rate, total_divergences, nrow(divergent_mat) * ncol(divergent_mat) - ), "Consider increasing the target acceptance rate.") - } else if(divergence_rate > 0) { - message( - sprintf( - "Note: %.3f%% of transitions ended with a divergence (%d of %d).\n", - 100 * divergence_rate, - total_divergences, - nrow(divergent_mat) * ncol(divergent_mat) - ), - "Check R-hat and effective sample size (ESS) to ensure the chains are\n", - "mixing well." - ) + ), "Consider increasing the target acceptance rate or change to update_method = ``adaptive-metropolis''.") + } else if (divergence_rate > 0) { + message(sprintf( + "Note: %.3f%% of transitions ended with a divergence (%d of %d).\n", + 100 * divergence_rate, + total_divergences, + nrow(divergent_mat) * ncol(divergent_mat) + ), + "Check R-hat and effective sample size (ESS) to ensure the chains are\n", + "mixing well.") } depth_hit_rate <- max_tree_depth_hits / (nrow(treedepth_mat) * ncol(treedepth_mat)) @@ -84,16 +82,14 @@ summarize_nuts_diagnostics <- function(out, nuts_max_depth = 10, verbose = TRUE) low_ebfmi_chains <- which(ebfmi_per_chain < 0.3) min_ebfmi <- min(ebfmi_per_chain) - if(length(low_ebfmi_chains) > 0) { - warning( - sprintf( - "E-BFMI below 0.3 detected in %d chain(s): %s.\n", - length(low_ebfmi_chains), - paste(low_ebfmi_chains, collapse = ", ") - ), - "This suggests inefficient momentum resampling in those chains.\n", - "Sampling efficiency may be reduced. Consider longer chains or checking convergence diagnostics." - ) + if (length(low_ebfmi_chains) > 0) { + warning(sprintf( + "E-BFMI below 0.3 detected in %d chain(s): %s.\n", + length(low_ebfmi_chains), + paste(low_ebfmi_chains, collapse = ", ") + ), + "This suggests inefficient momentum resampling in those chains.\n", + "Sampling efficiency may be reduced. Consider longer chains and check convergence diagnostics.") } } diff --git a/R/output_utils.R b/R/output_utils.R index 264b7e2d..a959984c 100644 --- a/R/output_utils.R +++ b/R/output_utils.R @@ -8,7 +8,6 @@ prepare_output_bgm = function( nuts_max_depth, learn_mass_matrix, num_chains ) { arguments = list( - prepared_data = x, num_variables = ncol(x), num_cases = nrow(x), na_impute = na_impute, @@ -291,7 +290,6 @@ prepare_output_bgmCompare = function( num_variables = ncol(observations) arguments = list( - prepared_data = observations, num_variables = num_variables, num_cases = nrow(observations), iter = iter, diff --git a/R/sampleMRF.R b/R/sampleMRF.R index 487c627e..caacdf37 100644 --- a/R/sampleMRF.R +++ b/R/sampleMRF.R @@ -13,7 +13,7 @@ #' in specifying their model. #' #' The Blume-Capel option is specifically designed for ordinal variables that -#' have a special type of reference_category category, such as the neutral +#' have a special type of baseline_category category, such as the neutral #' category in a Likert scale. The Blume-Capel model specifies the following #' quadratic model for the threshold parameters: #' \deqn{\mu_{\text{c}} = \alpha \times \text{c} + \beta \times (\text{c} - \text{r})^2,}{{\mu_{\text{c}} = \alpha \times \text{c} + \beta \times (\text{c} - \text{r})^2,}} @@ -23,8 +23,8 @@ #' \eqn{\alpha > 0}{\alpha > 0} and decreasing threshold values if #' \eqn{\alpha <0}{\alpha <0}), if \eqn{\beta < 0}{\beta < 0}, it offers an #' increasing penalty for responding in a category further away from the -#' reference_category category r, while \eqn{\beta > 0}{\beta > 0} suggests a -#' preference for responding in the reference_category category. +#' baseline_category category r, while \eqn{\beta > 0}{\beta > 0} suggests a +#' preference for responding in the baseline_category category. #' #' @param no_states The number of states of the ordinal MRF to be generated. #' @@ -53,8 +53,8 @@ #' ``blume-capel''. Binary variables are automatically treated as ``ordinal’’. #' Defaults to \code{variable_type = "ordinal"}. #' -#' @param reference_category An integer vector of length \code{no_variables} specifying the -#' reference_category category that is used for the Blume-Capel model (details below). +#' @param baseline_category An integer vector of length \code{no_variables} specifying the +#' baseline_category category that is used for the Blume-Capel model (details below). #' Can be any integer value between \code{0} and \code{no_categories} (or #' \code{no_categories[i]}). #' @@ -106,7 +106,7 @@ #' interactions = Interactions, #' thresholds = Thresholds, #' variable_type = c("b", "b", "o", "b", "o"), -#' reference_category = 2 +#' baseline_category = 2 #' ) #' #' @export @@ -116,7 +116,7 @@ mrfSampler = function(no_states, interactions, thresholds, variable_type = "ordinal", - reference_category, + baseline_category, iter = 1e3) { # Check no_states, no_variables, iter -------------------------------------------- if(no_states <= 0 || @@ -187,24 +187,20 @@ mrfSampler = function(no_states, } } - # Check the reference_category for Blume-Capel variables --------------------- + # Check the baseline_category for Blume-Capel variables --------------------- if(any(variable_type == "blume-capel")) { - if(length(reference_category) == 1) { - reference_category = rep(reference_category, no_variables) + if(length(baseline_category) == 1) { + baseline_category = rep(baseline_category, no_variables) } - if(any(reference_category < 0) || any(abs(reference_category - round(reference_category)) > .Machine$double.eps)) { - stop(paste0( - "For variables ", - which(reference_category < 0), - " ``reference_category'' was either negative or not integer." - )) + if(any(baseline_category < 0) || any(abs(baseline_category - round(baseline_category)) > .Machine$double.eps)) { + stop(paste0("For variables ", + which(baseline_category < 0), + " ``baseline_category'' was either negative or not integer.")) } - if(any(reference_category - no_categories > 0)) { - stop(paste0( - "For variables ", - which(reference_category - no_categories > 0), - " the ``reference_category'' category was larger than the maximum category value." - )) + if(any(baseline_category - no_categories > 0)) { + stop(paste0("For variables ", + which(baseline_category - no_categories > 0), + " the ``baseline_category'' category was larger than the maximum category value.")) } } @@ -347,7 +343,7 @@ mrfSampler = function(no_states, interactions = interactions, thresholds = thresholds, variable_type = variable_type, - reference_category = reference_category, + baseline_category = baseline_category, iter = iter ) } diff --git a/cleanup b/cleanup index 4a1a609a..57975cc5 100755 --- a/cleanup +++ b/cleanup @@ -2,3 +2,4 @@ # Remove generated Makevars files rm -f src/Makevars rm -f src/Makevars.win +rm -f src/sources.mk diff --git a/configure b/configure index 2a547070..76bbc64e 100755 --- a/configure +++ b/configure @@ -1,9 +1,12 @@ #!/bin/sh -# Get flags from RcppParallel +# RcppParallel flags RCPP_PARALLEL_CPPFLAGS=`"${R_HOME}/bin/Rscript" -e "cat(RcppParallel::CxxFlags())"` RCPP_PARALLEL_LIBS=`"${R_HOME}/bin/Rscript" -e "cat(RcppParallel::LdFlags())"` +# Generate sources.mk using R +"${R_HOME}/bin/Rscript" inst/generate_makevars_sources.R > src/sources.mk + # Substitute into Makevars sed -e "s|@RCPP_PARALLEL_CPPFLAGS@|${RCPP_PARALLEL_CPPFLAGS}|" \ -e "s|@RCPP_PARALLEL_LIBS@|${RCPP_PARALLEL_LIBS}|" \ diff --git a/configure.win b/configure.win index 080f1688..794af7a0 100755 --- a/configure.win +++ b/configure.win @@ -1,8 +1,12 @@ #!/bin/sh +# RcppParallel flags RCPP_PARALLEL_CPPFLAGS=`"${R_HOME}/bin${R_ARCH_BIN}/Rscript.exe" -e "cat(RcppParallel::CxxFlags())"` RCPP_PARALLEL_LIBS=`"${R_HOME}/bin${R_ARCH_BIN}/Rscript.exe" -e "cat(RcppParallel::LdFlags())"` +# Generate sources.mk using R +"${R_HOME}/bin/Rscript" inst/generate_makevars_sources.R > src/sources.mk + # Substitute into Makevars.win sed -e "s|@RCPP_PARALLEL_CPPFLAGS@|${RCPP_PARALLEL_CPPFLAGS}|" \ -e "s|@RCPP_PARALLEL_LIBS@|${RCPP_PARALLEL_LIBS}|" \ diff --git a/dev/ggm-hmc/README.md b/dev/ggm-hmc/README.md new file mode 100644 index 00000000..414e86a6 --- /dev/null +++ b/dev/ggm-hmc/README.md @@ -0,0 +1,81 @@ +# GGM Constrained HMC Sampling + +This folder contains scripts for sampling precision matrices for Gaussian Graphical Models (GGMs) using constrained Hamiltonian Monte Carlo (HMC). The constrained HMC approach enforces exact zeros in the precision matrix (corresponding to missing edges in the graph) as hard constraints. + +## Overview + +- `run_constrained_hmc.py` - Python script that performs the constrained HMC sampling using JAX and [mici](https://github.com/matt-graham/mici) +- `run_constrained_hmc_subprocess.R` - R script that generates data, calls the Python sampler via subprocess, and processes results + +## Prerequisites + +### 1. Clone the mici example repository + +The Python environment is defined in a separate repository. Clone it as a sibling folder: + +```bash +cd /path/to/bgms/dev +git clone https://github.com/matt-graham/ggm-precision-constrained-hmc +``` + +Your folder structure should look like: +``` +bgms/dev/ +├── ggm-hmc/ # This folder +│ ├── README.md +│ ├── run_constrained_hmc.py +│ └── run_constrained_hmc_subprocess.R +└── ggm-precision-constrained-hmc/ # Cloned repo with uv environment + ├── .venv/ + ├── pyproject.toml + └── ... +``` + +**Alternative:** If you prefer a different location, adjust the `python_dir` path in `run_constrained_hmc_subprocess.R`. + +### 2. Set up the Python environment with uv + +Install [uv](https://docs.astral.sh/uv/) if you don't have it: + +```bash +curl -LsSf https://astral.sh/uv/install.sh | sh +``` + +Then create and sync the Python environment: + +```bash +cd /path/to/bgms/dev/ggm-precision-constrained-hmc +uv sync +``` + +This will create a `.venv` folder with Python 3.14 (free-threaded) and all required dependencies. + +## R Package Dependencies + +The R script requires: +- `bgms` (this package) +- `mvtnorm` +- `jsonlite` +- `RcppCNPy` (for reading numpy arrays) +- `BDgraph` (for generating G-Wishart samples) + +```r +install.packages(c("mvtnorm", "jsonlite", "RcppCNPy", "BDgraph")) +``` + +## Usage + +```r +# Set working directory to bgms root +setwd("/path/to/bgms") + +# Run the script +source("dev/ggm-hmc/run_constrained_hmc_subprocess.R") +``` + +The script will: +1. Generate simulated data from a sparse GGM +2. Run the bgms MH sampler for comparison +3. Call the Python constrained HMC sampler +4. Load and summarize the results + diff --git a/dev/ggm-hmc/run_constrained_hmc.py b/dev/ggm-hmc/run_constrained_hmc.py new file mode 100644 index 00000000..4c9f371b --- /dev/null +++ b/dev/ggm-hmc/run_constrained_hmc.py @@ -0,0 +1,188 @@ +""" +Constrained HMC sampler for GGM precision matrices. + +This script is called from R via subprocess to sample precision matrices +with specified zero constraints using the mici library. + +Usage: python run_constrained_hmc.py +""" + +import json +import sys +import time +import os +import numpy as np +import mici +import jax +from numpyro.distributions.transforms import LowerCholeskyTransform +import arviz # For diagnostics + +# Force CPU backend to avoid CUDA autotuner issues with small matrices +# Comment out this line to use GPU (may require CUDA driver updates) +jax.config.update("jax_default_device", jax.devices("cpu")[0]) + +jax.config.update("jax_enable_x64", True) + + +def main(): + # Load input data from R + with open(sys.argv[1], "r") as f: + data = json.load(f) + + n_variable = data["n_variable"] + n_obs = data["n_obs"] + S = np.array(data["scatter_matrix"], dtype=np.float64) + zero_indices = np.array(data["zero_indices"], dtype=np.int64) + n_warm_up_iter = data["n_warm_up_iter"] + n_main_iter = data["n_main_iter"] + n_chain = data["n_chain"] + seed = data["seed"] + output_file = data["output_file"] + samples_file = data["samples_file"] + + print(f"Loaded data: n_obs={n_obs}, n_variable={n_variable}, n_zero_pairs={len(zero_indices)}") + print(f"Chains: {n_chain}, Warmup: {n_warm_up_iter}, Samples: {n_main_iter}") + sys.stdout.flush() + + # Precompute Cholesky factor of scatter matrix for efficient trace computation + # tr(Omega S) = tr(L L^T S) = ||S_chol^T @ L||_F^2 + S_chol = np.linalg.cholesky(S) + + # Set up transformations + vector_to_cholesky = LowerCholeskyTransform() + cholesky_to_vector = vector_to_cholesky.inv + + def constr(u, zero_indices): + L = vector_to_cholesky(u) + return jax.vmap(lambda i, j: L[i] @ L[j])(*zero_indices.T) + + def neg_log_dens(u, n_obs, S_chol): + """ + Negative log posterior for GGM precision matrix. + + log p(Omega | X) ∝ (n/2) * log|Omega| - (1/2) * tr(Omega * S) + log p(Omega) + + where S = X'X is the scatter matrix. + We use a standard normal prior on the unconstrained parameters u. + + Optimizations: + - Avoid forming Omega = L @ L.T explicitly + - Use tr(L L^T S) = ||S_chol^T @ L||_F^2 where S = S_chol @ S_chol^T + - This reduces from 2 O(p³) ops to 1 O(p³) op + """ + L = vector_to_cholesky(u) + + # Log determinant: log|Omega| = 2 * sum(log(diag(L))) + log_det_Omega = 2 * jax.numpy.log(jax.numpy.diag(L)).sum() + + # Trace term: tr(Omega * S) = tr(L L^T S) = ||S_chol^T @ L||_F^2 + # This avoids explicitly forming Omega = L @ L.T + S_chol_T_L = S_chol.T @ L + trace_term = jax.numpy.sum(S_chol_T_L ** 2) + + # Gaussian likelihood: (n/2) * log|Omega| - (1/2) * tr(Omega * S) + log_likelihood = (n_obs / 2) * log_det_Omega - 0.5 * trace_term + + # Prior on unconstrained parameters (standard normal) + log_prior = -0.5 * (u**2).sum() + + # Return negative log posterior + return -(log_likelihood + log_prior) + + def trace_func(state): + L = vector_to_cholesky(state.pos) + return {"u": state.pos, "P": L @ L.T} + + # Create initial states + rng = np.random.default_rng(seed) + scale = 0.01 + + init_states = [] + for c in range(n_chain): + random_matrix = rng.standard_normal((n_variable, n_variable)) + P_init = np.identity(n_variable) + scale * random_matrix @ random_matrix.T + P_init[zero_indices[:, 0], zero_indices[:, 1]] = 0.0 + P_init[zero_indices[:, 1], zero_indices[:, 0]] = 0.0 + L_init = np.linalg.cholesky(P_init) + u_init = np.asarray(cholesky_to_vector(L_init)) + assert not np.any(np.isnan(u_init)), "NaN in initial state" + assert abs(constr(u_init, zero_indices).max()) < 1e-8, "Constraint violation" + init_states.append(u_init) + + print(f"Created {len(init_states)} initial states") + print("Running constrained HMC sampling...") + sys.stdout.flush() + + # Time the sampling + start_time = time.perf_counter() + + # Run sampling + results = mici.sample_constrained_hmc_chains( + n_warm_up_iter=n_warm_up_iter, + n_main_iter=n_main_iter, + init_states=init_states, + neg_log_dens=lambda u: neg_log_dens(u, n_obs, S_chol), + constr=lambda u: constr(u, zero_indices), + backend="jax", + seed=rng, + monitor_stats=("accept_stat", "n_step", "step_size"), + trace_funcs=[trace_func], + n_worker=1, + use_thread_pool=False, + ) + + end_time = time.perf_counter() + sampling_duration = end_time - start_time + + print(f"\nSampling complete! Duration: {sampling_duration:.2f} seconds") + sys.stdout.flush() + + + ess = arviz.ess(results.traces, var_names=["u"]) + r_hat = arviz.rhat(results.traces, var_names=["u"]) + + min_ess = float(ess.min().u.data) + max_rhat = float(r_hat.max().u.data) + + # Extract posterior samples of precision matrix + # Shape: (n_chain, n_main_iter, n_variable, n_variable) + P_samples = np.array(results.traces["P"]) + P_mean = P_samples.mean(axis=(0, 1)) + + print(f"Min ESS: {min_ess:.1f}, Max R-hat: {max_rhat:.3f}") + print(f"P_samples shape: {P_samples.shape}") + sys.stdout.flush() + + # Save full posterior samples as numpy array + # Reshape to 2D for RcppCNPy compatibility (doesn't support 4D arrays) + # Original shape: (n_chain, n_main_iter, n_variable, n_variable) + # Saved shape: (n_chain * n_main_iter, n_variable * n_variable) + P_samples_flat = P_samples.reshape(-1, n_variable * n_variable) + np.save(samples_file, P_samples_flat) + print(f"Posterior samples saved to: {samples_file} (flattened shape: {P_samples_flat.shape})") + + # Helper to convert NaN/Inf to None for JSON compatibility + def sanitize_for_json(val): + if isinstance(val, float) and (np.isnan(val) or np.isinf(val)): + return None + return val + + # Save summary results as JSON + output_data = { + "min_ess": sanitize_for_json(min_ess), + "max_rhat": sanitize_for_json(max_rhat), + "P_mean": P_mean.tolist(), + "n_chain": n_chain, + "n_iter": n_main_iter, + "samples_file": samples_file, + "sampling_duration_seconds": sampling_duration, + } + + with open(output_file, "w") as f: + json.dump(output_data, f) + + print(f"Results saved to: {output_file}") + + +if __name__ == "__main__": + main() diff --git a/dev/ggm-hmc/run_constrained_hmc_subprocess.R b/dev/ggm-hmc/run_constrained_hmc_subprocess.R new file mode 100644 index 00000000..feb86300 --- /dev/null +++ b/dev/ggm-hmc/run_constrained_hmc_subprocess.R @@ -0,0 +1,345 @@ +# ============================================================================= +# R script to run constrained HMC for GGM precision matrix sampling via Python +# Uses subprocess approach to work with uv environments (no --enable-shared needed) +# ============================================================================= + +library(bgms) +library(mvtnorm) +library(jsonlite) +library(RcppCNPy) # For reading numpy arrays +library(posterior) # For ESS calculation + +# ----------------------------------------------------------------------------- +# Helper function to convert vectorized upper triangle to matrix +# ----------------------------------------------------------------------------- +upper_tri_to_matrix <- function(vec, p) { + mat <- matrix(0, p, p) + idx <- 1 + for (j in 1:p) { + for (i in 1:j) { + mat[i, j] <- vec[idx] + mat[j, i] <- vec[idx] + idx <- idx + 1 + } + } + mat +} + +# ----------------------------------------------------------------------------- +# 1. Configuration +# ----------------------------------------------------------------------------- + +# Path to the folder with the python file +python_dir <- file.path(getwd(), "dev/ggm-hmc") +# Path to the uv virtual environment +python_bin <- "dev/ggm-precision-constrained-hmc/.venv/bin/python" + +# Temporary files for data exchange +tmp_dir <- tempdir() +input_file <- file.path(tmp_dir, "r_to_python_data.json") +output_file <- file.path(tmp_dir, "python_to_r_results.json") +samples_file <- file.path(tmp_dir, "python_P_samples.npy") + +# ----------------------------------------------------------------------------- +# 2. Generate data in R using mvtnorm +# ----------------------------------------------------------------------------- + +# Dimension and true precision +p <- 20 + +density <- 0.15 + +set.seed(123) +adj <- matrix(0, nrow = p, ncol = p) +adj[lower.tri(adj)] <- rbinom(p * (p - 1) / 2, size = 1, prob = density) +adj <- adj + t(adj) +# qgraph::qgraph(adj) +# sample true precision matrix from G-Wishart +Omega <- BDgraph::rgwish(1, adj = adj, b = p + sample(0:p, 1), D = diag(p)) +Sigma <- solve(Omega) +zapsmall(Omega) + +# Data +n <- 200 +x <- mvtnorm::rmvnorm(n = n, mean = rep(0, p), sigma = Sigma) + +# Identify zero elements in precision matrix (adj that don't exist) +zero_indices <- which(adj == 0 & row(adj) < col(Omega), arr.ind = TRUE) +zero_indices <- zero_indices - 1L # Convert to 0-based indexing for Python +cat("\nNumber of zero index pairs:", nrow(zero_indices), "\n") + + +mh_timing <- system.time({ +sampling_results <- bgms:::sample_ggm( + inputFromR = list(X = x), + prior_inclusion_prob = matrix(.5, p, p), + initial_edge_indicators = adj, + no_iter = 10000, + no_warmup = 1000, + no_chains = 4, + edge_selection = FALSE, + no_threads = 1, + seed = 123, + progress_type = 1 +)}) +cat("\nTime taken for bgms sampling:", mh_timing[3], "seconds\n") + + +# ----------------------------------------------------------------------------- +# 3. Save data to JSON for Python +# ----------------------------------------------------------------------------- + +# Compute sufficient statistics for GGM +# S = X'X (scatter matrix) +S <- crossprod(x) +stopifnot(all(dim(S) == c(p, p))) + +input_data <- list( + n_variable = p, + n_obs = as.integer(n), + scatter_matrix = as.matrix(S), + zero_indices = as.matrix(zero_indices), + n_warm_up_iter = 100L, + n_main_iter = 500L, + n_chain = 1L, + seed = 1234L, + output_file = output_file, + samples_file = samples_file +) + +write_json(input_data, input_file, auto_unbox = TRUE, matrix = "rowmajor") +cat("\nData saved to:", input_file, "\n") + +# ----------------------------------------------------------------------------- +# 4. Python script location +# ----------------------------------------------------------------------------- + +python_script <- file.path(python_dir, "run_constrained_hmc.py") + +# ----------------------------------------------------------------------------- +# 5. Run Python script +# ----------------------------------------------------------------------------- + +cat("\nRunning constrained HMC sampling...\n") +cat("This may take a few minutes.\n\n") + +# Run Python with the uv environment +# Use system() instead of system2() for real-time progress bar output +exit_code <- system( + paste(shQuote(python_bin), shQuote(python_script), shQuote(input_file)), + ignore.stdout = FALSE, + ignore.stderr = FALSE +) + +# ----------------------------------------------------------------------------- +# 6. Load results back into R +# ----------------------------------------------------------------------------- + +if (exit_code == 0 && file.exists(output_file)) { + results <- fromJSON(output_file) + min_ess <- results$min_ess + max_rhat <- results$max_rhat + sampling_duration <- results$sampling_duration_seconds + P_mean <- matrix(unlist(results$P_mean), nrow = p, ncol = p, byrow = TRUE) + + # Load full posterior samples from numpy file + # Saved as 2D: (n_chain * n_iter, p * p) due to RcppCNPy limitations + # Reshape back to 4D: (n_chain, n_iter, p, p) + P_samples_flat <- npyLoad(samples_file) + n_chain_loaded <- results$n_chain + n_iter_loaded <- results$n_iter + P_samples <- array( + t(P_samples_flat), # Transpose because R is column-major + dim = c(p, p, n_iter_loaded, n_chain_loaded) + ) + # Reorder dimensions to match expected (n_chain, n_iter, p, p) + P_samples <- aperm(P_samples, c(4, 3, 1, 2)) + cat("\nLoaded posterior samples with shape:", dim(P_samples), "\n") + + cat("\n=== Results ===\n") + cat(sprintf("Sampling duration: %.2f seconds\n", sampling_duration)) + + cat("\nTrue precision matrix (first 5x5):\n") + print(round(Omega[1:5, 1:5], 3)) + + # ----------------------------------------------------------------------------- + # 7. Compare estimated vs true precision + # ----------------------------------------------------------------------------- + + # Check if zeros are recovered + estimated_zeros <- which(abs(P_mean) < 0.1 & row(P_mean) != col(P_mean), arr.ind = TRUE) + true_zeros <- which(adj == 0 & row(adj) != col(adj), arr.ind = TRUE) + + cat("\nNumber of near-zero elements in estimate:", nrow(estimated_zeros), "\n") + cat("Number of true zero elements:", nrow(true_zeros), "\n") + + # Correlation between true and estimated + upper_tri <- upper.tri(Omega) + cor_precision <- cor(Omega[upper_tri], P_mean[upper_tri]) + cat(sprintf("\nCorrelation between true and estimated precision (upper triangle): %.3f\n", cor_precision)) + + # ----------------------------------------------------------------------------- + # 8. Compare with bgms MH sampler results + # ----------------------------------------------------------------------------- + + cat("\n=== Comparison: HMC vs MH (bgms) ===\n") + + # Extract bgms samples - stored as vectorized upper triangle + # samples matrix has dimensions: (p*(p+1)/2, n_iter) per chain + bgms_samples_list <- lapply(sampling_results, function(chain) { + if (!is.null(chain$samples)) { + t(chain$samples) # Transpose to (n_iter, p*(p+1)/2) + } else { + NULL + } + }) + bgms_samples_list <- Filter(Negate(is.null), bgms_samples_list) + + # Combine chains: (n_chains * n_iter, p*(p+1)/2) + bgms_samples_combined <- do.call(rbind, bgms_samples_list) + n_bgms_samples <- nrow(bgms_samples_combined) + n_bgms_chains <- length(bgms_samples_list) + n_bgms_iter <- n_bgms_samples / n_bgms_chains + + cat(sprintf("bgms: %d chains x %d iterations = %d samples\n", + n_bgms_chains, n_bgms_iter, n_bgms_samples)) + cat(sprintf("HMC: %d chains x %d iterations = %d samples\n", + n_chain_loaded, n_iter_loaded, n_chain_loaded * n_iter_loaded)) + + # Compute posterior means for bgms + bgms_mean_vec <- colMeans(bgms_samples_combined) + bgms_P_mean <- upper_tri_to_matrix(bgms_mean_vec, p) + + cat("\nbgms posterior mean precision (first 5x5):\n") + print(round(bgms_P_mean[1:5, 1:5], 3)) + + cat("\nHMC posterior mean precision (first 5x5):\n") + print(round(P_mean[1:5, 1:5], 3)) + + # ----------------------------------------------------------------------------- + # 9. Scatter plot: Compare edge estimates between methods + # ----------------------------------------------------------------------------- + + # Get upper triangle indices (off-diagonal only for edges) + edge_idx <- which(upper.tri(Omega), arr.ind = TRUE) + n_edges <- nrow(edge_idx) + + # Extract edge values from both methods + hmc_edges <- sapply(1:n_edges, function(k) P_mean[edge_idx[k, 1], edge_idx[k, 2]]) + bgms_edges <- sapply(1:n_edges, function(k) bgms_P_mean[edge_idx[k, 1], edge_idx[k, 2]]) + true_edges <- sapply(1:n_edges, function(k) Omega[edge_idx[k, 1], edge_idx[k, 2]]) + true_edges2 <- sapply(1:n_edges, function(k) adj[edge_idx[k, 1], edge_idx[k, 2]]) + true_zero <- true_edges2 == 0 + + # Create scatter plot comparing posterior means + par(mfrow = c(1, 3)) + + # Plot 1: HMC vs bgms posterior means + plot(bgms_edges, hmc_edges, + xlab = "bgms (MH) posterior mean", + ylab = "HMC posterior mean", + main = "Edge Estimates: HMC vs MH", + pch = ifelse(true_zero, 1, 19), + col = ifelse(true_zero, "red", "blue")) + abline(0, 1, lty = 2, col = "gray") + legend("topleft", + legend = c("True zero", "True non-zero"), + pch = c(1, 19), col = c("red", "blue"), bty = "n") + + # Plot 2: Both methods vs truth + plot(true_edges, hmc_edges, + xlab = "True precision", + ylab = "Estimated precision", + main = "Estimates vs Truth", + pch = 19, col = "blue") + points(true_edges, bgms_edges, pch = 17, col = "red") + abline(0, 1, lty = 2, col = "gray") + legend("topleft", + legend = c("HMC", "bgms (MH)"), + pch = c(19, 17), col = c("blue", "red"), bty = "n") + + bgms_incl_prob_mat <- upper_tri_to_matrix(colMeans(zapsmall(bgms_samples_combined) != 0), p) + bgms_incl_prob <- bgms_incl_prob_mat[lower.tri(bgms_incl_prob_mat)] + + # Compute HMC inclusion probabilities from P_samples + # P_samples has shape (n_chain, n_iter, p, p) + hmc_incl_prob <- numeric(n_edges) + for (k in 1:n_edges) { + i <- edge_idx[k, 1] + j <- edge_idx[k, 2] + edge_samples <- as.vector(P_samples[, , i, j]) # All chains, all iters + hmc_incl_prob[k] <- mean(zapsmall(edge_samples) != 0) + } + + # Plot 3: Posterior mean vs inclusion probability for both methods + xlim_range <- range(c(hmc_edges, bgms_edges)) + ylim_range <- c(0, 1) + + plot(hmc_edges, hmc_incl_prob, + xlab = "Posterior mean (edge weight)", + ylab = "Inclusion probability", + main = "Posterior Mean vs Inclusion Probability", + pch = 19, col = "blue", + xlim = xlim_range, ylim = ylim_range) + points(bgms_edges, bgms_incl_prob, pch = 17, col = "red") + abline(v = 0, lty = 2, col = "gray") + legend("topright", + legend = c("HMC", "bgms (MH)"), + pch = c(19, 17), col = c("blue", "red"), bty = "n") + + par(mfrow = c(1, 1)) + + # Correlation between methods + cat(sprintf("\nCorrelation between HMC and bgms edge estimates: %.4f\n", + cor(hmc_edges, bgms_edges))) + + # ----------------------------------------------------------------------------- + # 10. ESS per second comparison + # ----------------------------------------------------------------------------- + + cat("\n=== ESS/second Comparison ===\n") + + # HMC ESS (from Python, computed by arviz) + hmc_ess <- min_ess + hmc_time <- sampling_duration + hmc_ess_per_sec <- hmc_ess / hmc_time + + # bgms ESS - compute for each parameter using posterior package + # Reshape bgms samples to draws_array format: (n_iter, n_chains, n_params) + bgms_samples_array <- array( + NA, + dim = c(n_bgms_iter, n_bgms_chains, ncol(bgms_samples_combined)) + ) + for (c in 1:n_bgms_chains) { + start_idx <- (c - 1) * n_bgms_iter + 1 + end_idx <- c * n_bgms_iter + bgms_samples_array[, c, ] <- bgms_samples_combined[start_idx:end_idx, ] + } + + # Compute bulk ESS for each parameter + bgms_ess_bulk <- apply(bgms_samples_array, 3, function(x) { + posterior::ess_bulk(x) + }) + bgms_min_ess <- min(bgms_ess_bulk, na.rm = TRUE) + bgms_time <- mh_timing[3] + bgms_ess_per_sec <- bgms_min_ess / bgms_time + + print_efficiency <- function(hmc_time, hmc_ess, hmc_ess_per_sec, bgms_time, bgms_min_ess, bgms_ess_per_sec) { + cat(sprintf("\nHMC:\n")) + cat(sprintf(" Time: %.2f seconds\n", hmc_time)) + cat(sprintf(" Min ESS: %.1f\n", hmc_ess)) + cat(sprintf(" ESS/sec: %.2f\n", hmc_ess_per_sec)) + + cat(sprintf("\nbgms (MH):\n")) + cat(sprintf(" Time: %.2f seconds\n", bgms_time)) + cat(sprintf(" Min ESS: %.1f\n", bgms_min_ess)) + cat(sprintf(" ESS/sec: %.2f\n", bgms_ess_per_sec)) + + cat(sprintf("\nEfficiency ratio (HMC / MH): %.5fx\n", hmc_ess_per_sec / bgms_ess_per_sec)) + } + print_efficiency(hmc_time, hmc_ess, hmc_ess_per_sec, bgms_time, bgms_min_ess, bgms_ess_per_sec) + +} else { + cat("ERROR: Python script failed (exit code:", exit_code, ")\n") + cat("Check the Python output above for errors.\n") +} diff --git a/dev/numerical_analyses/bgm_blumecapel_normalization_PL.R b/dev/numerical_analyses/bgm_blumecapel_normalization_PL.R new file mode 100644 index 00000000..26359859 --- /dev/null +++ b/dev/numerical_analyses/bgm_blumecapel_normalization_PL.R @@ -0,0 +1,647 @@ +# ============================================================================== +# Blume–Capel Numerical Stability Study (reparametrized) +# File: dev/numerical_analyses/BCvar_normalization_PL.r +# +# Goal +# ---- +# Compare numerical stability of four ways to compute the Blume–Capel +# normalizing constant across a range of residual scores r, using the +# reparametrized form +# +# Z(r) = sum_{s=0}^C exp( θ_part(s) + s * r ), +# +# where +# +# θ_part(s) = θ_lin * (s - ref) + θ_quad * (s - ref)^2. +# +# This corresponds to the reformulated denominator where: +# - scores s are in {0, 1, ..., C}, +# - the quadratic/linear θ-part is in terms of the centered score (s - ref), +# - the “residual” r enters only through s * r. +# +# Methods (exactly four): +# 1) Direct +# Unbounded sum of exp(θ_part(s) + s * r). +# +# 2) Preexp +# Unbounded “power-chain” over s, precomputing exp(θ_part(s)) and +# reusing exp(r): +# Z(r) = sum_s exp(θ_part(s)) * (exp(r))^s . +# +# 3) Direct + max-bound +# Per-r max-term bound M(r) = max_s (θ_part(s) + s * r), +# computing +# Z(r) = exp(M(r)) * sum_s exp(θ_part(s) + s * r - M(r)), +# but returning only the *scaled* sum: +# sum_s exp(θ_part(s) + s * r - M(r)). +# +# 4) Preexp + max-bound +# Same max-term bound M(r) as in (3), but using the power-chain: +# sum_s exp(θ_part(s)) * exp(s * r - M(r)). +# +# References (for error calculation): +# - ref_unscaled = MPFR sum_s exp(θ_part(s) + s * r) +# - ref_scaled = MPFR sum_s exp(θ_part(s) + s * r - M(r)), +# where M(r) = max_s (θ_part(s) + s * r) in MPFR. +# +# Dependencies +# ------------ +# - Rmpfr +# +# Outputs +# ------- +# compare_bc_all_methods(...) returns a data.frame with: +# r : grid of residual scores +# direct : numeric, Σ_s exp(θ_part(s) + s * r) +# preexp : numeric, Σ_s via power-chain (unbounded) +# direct_bound : numeric, Σ_s exp(θ_part(s) + s * r - M(r)) +# preexp_bound : numeric, Σ_s via power-chain with max-term bound +# err_direct : |(direct - ref_unscaled)/ref_unscaled| +# err_preexp : |(preexp - ref_unscaled)/ref_unscaled| +# err_direct_bound : |(direct_bound - ref_scaled )/ref_scaled | +# err_preexp_bound : |(preexp_bound - ref_scaled )/ref_scaled | +# ref_unscaled : numeric MPFR reference (unbounded) +# ref_scaled : numeric MPFR reference (max-term scaled) +# +# Plotting helpers (unchanged interface): +# - plot_bc_four(res, ...) +# - summarize_bc_four(res) +# +# ============================================================================== + +library(Rmpfr) + +# ------------------------------------------------------------------------------ +# compare_bc_all_methods +# ------------------------------------------------------------------------------ +# Compute all four methods and MPFR references over a vector of r-values +# for the reparametrized Blume–Capel normalizing constant +# +# Z(r) = sum_{s=0}^C exp( θ_lin * (s - ref) + θ_quad * (s - ref)^2 + s * r ). +# +# Args: +# max_cat : integer, max category C (scores are s = 0..C) +# ref : integer, baseline category index for centering (s - ref) +# r_vals : numeric vector of r values to scan +# theta_lin : numeric, linear θ parameter +# theta_quad : numeric, quadratic θ parameter +# mpfr_prec : integer, MPFR precision (bits) for reference calculations +# +# Returns: +# data.frame with columns described in the file header (see “Outputs”). +# ------------------------------------------------------------------------------ + +compare_bc_all_methods <- function(max_cat = 10, + ref = 3, + r_vals = seq(-70, 70, length.out = 2000), + theta_lin = 0.12, + theta_quad = -0.02, + mpfr_prec = 256) { + + # --- score grid and θ-part --------------------------------------------------- + scores <- 0:max_cat # s = 0..C + centered <- scores - ref # (s - ref) + + # θ_part(s) = θ_lin*(s - ref) + θ_quad*(s - ref)^2 + theta_part <- theta_lin * centered + theta_quad * centered^2 + + # For the unbounded power-chain: exp(θ_part(s)) + exp_m <- exp(theta_part) + + # Output container ------------------------------------------------------------ + res <- data.frame( + r = r_vals, + direct = NA_real_, + preexp = NA_real_, + direct_bound = NA_real_, + preexp_bound = NA_real_, + err_direct = NA_real_, + err_preexp = NA_real_, + err_direct_bound = NA_real_, + err_preexp_bound = NA_real_, + ref_unscaled = NA_real_, + ref_scaled = NA_real_, + bound = NA_real_, # term_max = M(r), puur ter inspectie + theta_lin = theta_lin, + theta_quad = theta_quad, + max_cat = max_cat, + ref = ref + ) + + # --- MPFR constants independent of r ---------------------------------------- + tl_mpfr <- mpfr(theta_lin, mpfr_prec) + tq_mpfr <- mpfr(theta_quad, mpfr_prec) + sc_center_mpfr <- mpfr(centered, mpfr_prec) # (s - ref) + sc_raw_mpfr <- mpfr(scores, mpfr_prec) # s + + # --- Main loop over r -------------------------------------------------------- + for (i in seq_along(r_vals)) { + r <- r_vals[i] + + # Standard double-precision exponents + term <- theta_part + scores * r + + # ---------- MPFR references ---------- + r_mpfr <- mpfr(r, mpfr_prec) + term_mpfr <- tl_mpfr * sc_center_mpfr + + tq_mpfr * sc_center_mpfr * sc_center_mpfr + + sc_raw_mpfr * r_mpfr + + term_max_mpfr <- mpfr(max(asNumeric(term_mpfr)), mpfr_prec) + ref_unscaled_mpfr <- sum(exp(term_mpfr)) + ref_scaled_mpfr <- sum(exp(term_mpfr - term_max_mpfr)) + + # Store numeric references + res$ref_unscaled[i] <- asNumeric(ref_unscaled_mpfr) + res$ref_scaled[i] <- asNumeric(ref_scaled_mpfr) + + # ---------- (1) Direct (unbounded) ---------- + v_direct <- sum(exp(term)) + res$direct[i] <- v_direct + + # ---------- (2) Preexp (unbounded) ---------- + # Power-chain on exp(r): s = 0..max_cat, so start at s=0 with pow = 1 + eR <- exp(r) + pow <- 1.0 + S_pre <- 0.0 + for (j in seq_along(scores)) { + S_pre <- S_pre + exp_m[j] * pow + pow <- pow * eR + } + res$preexp[i] <- S_pre + + # ---------- (3) Direct + max-bound ---------- + term_max <- max(term) # M(r) + res$bound[i] <- term_max + + sum_direct_bound <- 0.0 + for (j in seq_along(scores)) { + sum_direct_bound <- sum_direct_bound + + exp(theta_part[j] + scores[j] * r - term_max) + } + res$direct_bound[i] <- sum_direct_bound + + # ---------- (4) Preexp + max-bound ---------- + pow_b <- exp(-term_max) # s = 0 → exp(0*r - term_max) + S_pre_b <- 0.0 + for (j in seq_along(scores)) { + S_pre_b <- S_pre_b + exp_m[j] * pow_b + pow_b <- pow_b * eR + } + res$preexp_bound[i] <- S_pre_b + + # ---------- Errors (vs MPFR) ---------- + res$err_direct[i] <- + asNumeric(abs((mpfr(v_direct, mpfr_prec) - ref_unscaled_mpfr) / ref_unscaled_mpfr)) + res$err_preexp[i] <- + asNumeric(abs((mpfr(S_pre, mpfr_prec) - ref_unscaled_mpfr) / ref_unscaled_mpfr)) + res$err_direct_bound[i] <- + asNumeric(abs((mpfr(sum_direct_bound, mpfr_prec) - ref_scaled_mpfr) / ref_scaled_mpfr)) + res$err_preexp_bound[i] <- + asNumeric(abs((mpfr(S_pre_b, mpfr_prec) - ref_scaled_mpfr) / ref_scaled_mpfr)) + } + + res +} + + + +# ------------------------------------------------------------------------------ +# plot_bc_four +# ------------------------------------------------------------------------------ +# Plot the four relative error curves on a log y-axis. +# +# Args: +# res : data.frame produced by compare_bc_all_methods() +# draw_order : character vector with any ordering of: +# c("err_direct","err_direct_bound","err_preexp_bound","err_preexp") +# alpha : named numeric vector (0..1) alphas for the same names +# lwd : line width +# +# Returns: (invisible) NULL. Draws a plot. +# +plot_bc_four = function(res, + draw_order = c("err_direct","err_direct_bound", + "err_preexp_bound","err_preexp"), + alpha = c(err_direct = 0.00, + err_direct_bound = 0.00, + err_preexp_bound = 0.40, + err_preexp = 0.40), + lwd = 2) { + + base_cols = c(err_direct = "#000000", + err_preexp = "#D62728", + err_direct_bound = "#1F77B4", + err_preexp_bound = "#9467BD") + + to_rgba = function(hex, a) rgb(t(col2rgb(hex))/255, alpha = a) + + cols = mapply(to_rgba, base_cols[draw_order], alpha[draw_order], + SIMPLIFY = TRUE, USE.NAMES = TRUE) + + vals = unlist(res[draw_order]) + vals = vals[is.finite(vals)] + ylim = if (length(vals)) { + q = stats::quantile(vals, c(.01, .99), na.rm = TRUE) + c(q[1] / 10, q[2] * 10) + } else c(1e-20, 1e-12) + + first = draw_order[1] + plot(res$r, res[[first]], type = "l", log = "y", + col = cols[[1]], lwd = lwd, ylim = ylim, + xlab = "r", ylab = "Relative error (vs MPFR)", + main = "Blume–Capel: Direct / Preexp / (Split) Bound") + + if (length(draw_order) > 1) { + for (k in 2:length(draw_order)) { + lines(res$r, res[[draw_order[k]]], col = cols[[k]], lwd = lwd) + } + } + + abline(h = .Machine$double.eps, col = "gray70", lty = 2) + + ## --- Theoretical bound where max term hits exp(709) + scores <- 0:res$max_cat[1] + centered <- scores - res$ref[1] + + # θ_part(s) = θ_lin*(s-ref) + θ_quad*(s-ref)^2 + theta_part <- res$theta_lin[1] * centered + + res$theta_quad[1] * centered * centered + + U <- 709 + pos <- scores > 0 + + if (any(pos)) { + r_up_vec <- (U - theta_part[pos]) / scores[pos] + r_up <- min(r_up_vec) + } else { + r_up <- Inf + } + + # Geen zinvolle beneden-grens voor overflow met s >= 0 + r_low <- -Inf + + if (is.finite(r_up)) { + abline(v = r_up, col = "darkgreen", lty = 2, lwd = 2) + } + + print(r_low) + print(r_up) + + legend("top", + legend = c("Direct", + "Direct + bound (split)", + "Preexp + bound (split)", + "Preexp") + [match(draw_order, + c("err_direct","err_direct_bound", + "err_preexp_bound","err_preexp"))], + col = cols, lwd = lwd, bty = "n") + + invisible(NULL) +} + + +# ------------------------------------------------------------------------------ +# summarize_bc_four +# ------------------------------------------------------------------------------ +# Summarize accuracy per method. +# +# Args: +# res : data.frame from compare_bc_all_methods() +# +# Returns: +# data.frame with columns: Method, Mean, Median, Max, Finite +# +summarize_bc_four = function(res) { + cols = c("err_direct","err_direct_bound","err_preexp_bound","err_preexp") + labs = c("Direct","Direct+Bound(split)","Preexp+Bound(split)","Preexp") + mk = function(v){ + f = is.finite(v) & v > 0 + c(Mean=mean(v[f]), Median=median(v[f]), Max=max(v[f]), Finite=mean(f)) + } + out = t(sapply(cols, function(nm) mk(res[[nm]]))) + data.frame(Method=labs, out, row.names=NULL, check.names=FALSE) +} + +# ============================================================================== +# Example usage (uncomment to run locally) +# ------------------------------------------------------------------------------ +# res = compare_bc_all_methods( +# max_cat = 4, +# ref = 0, +# r_vals = seq(170, 175, length.out = 1000), +# theta_lin = 0, +# theta_quad = 1.00, +# mpfr_prec = 256 +# ) +# plot_bc_four(res, +# draw_order = c("err_direct","err_direct_bound","err_preexp_bound","err_preexp"), +# alpha = c(err_direct = 0.00, +# err_direct_bound = 1.00, +# err_preexp_bound = 1.00, +# err_preexp = 0.00), +# lwd = 1) +# print(summarize_bc_four(res), digits = 3) +# ============================================================================== + +scan_bc_configs <- function(max_cat_vec = c(4, 10), + ref_vec = c(0, 2), + theta_lin_vec = c(0.0, 0.12), + theta_quad_vec = c(-0.02, 0.0, 0.02), + r_vals = seq(-80, 80, length.out = 2000), + mpfr_prec = 256, + tol = 1e-12) { + + cfg_grid <- expand.grid( + max_cat = max_cat_vec, + ref = ref_vec, + theta_lin = theta_lin_vec, + theta_quad = theta_quad_vec, + KEEP.OUT.ATTRS = FALSE, + stringsAsFactors = FALSE + ) + + all_summaries <- vector("list", nrow(cfg_grid)) + + for (i in seq_len(nrow(cfg_grid))) { + cfg <- cfg_grid[i, ] + cat("Config", i, "of", nrow(cfg_grid), ":", + "max_cat =", cfg$max_cat, + "ref =", cfg$ref, + "theta_lin =", cfg$theta_lin, + "theta_quad =", cfg$theta_quad, "\n") + + res_i <- compare_bc_all_methods( + max_cat = cfg$max_cat, + ref = cfg$ref, + r_vals = r_vals, + theta_lin = cfg$theta_lin, + theta_quad = cfg$theta_quad, + mpfr_prec = mpfr_prec + ) + + summ_i <- summarize_bc_methods(res_i, tol = tol) + all_summaries[[i]] <- summ_i + } + + do.call(rbind, all_summaries) +} + +classify_bc_bound_methods <- function(res, tol = 1e-12, + eps_better = 1e-3) { + # tol : threshold for "good enough" relative error + # eps_better : multiplicative margin to call one method "better" when both good + + r <- res$r + eD <- res$err_direct_bound + eP <- res$err_preexp_bound + + finiteD <- is.finite(eD) & eD > 0 + finiteP <- is.finite(eP) & eP > 0 + + goodD <- finiteD & (eD < tol) + goodP <- finiteP & (eP < tol) + + state <- character(length(r)) + + for (i in seq_along(r)) { + if (!goodD[i] && !goodP[i]) { + state[i] <- "neither_good" + } else if (goodD[i] && !goodP[i]) { + state[i] <- "only_direct_good" + } else if (!goodD[i] && goodP[i]) { + state[i] <- "only_preexp_good" + } else { + # both good: compare which is better + # e.g. if preexp_bound error is at least eps_better times smaller than direct_bound + if (eP[i] <= eD[i] * (1 - eps_better)) { + state[i] <- "both_good_preexp_better" + } else if (eD[i] <= eP[i] * (1 - eps_better)) { + state[i] <- "both_good_direct_better" + } else { + # both good and within eps_better fraction: treat as "tie" + state[i] <- "both_good_similar" + } + } + } + + data.frame( + r = r, + err_direct_bound = eD, + err_preexp_bound = eP, + state = factor(state), + bound = res$bound, + max_cat = res$max_cat[1], + ref = res$ref[1], + theta_lin = res$theta_lin[1], + theta_quad = res$theta_quad[1], + stringsAsFactors = FALSE + ) +} + +summarize_bc_bound_classification <- function(class_df) { + # class_df is the output of classify_bc_bound_methods() + + r <- class_df$r + state <- as.character(class_df$state) + + if (length(r) == 0) { + return(class_df[FALSE, ]) # empty + } + + # Identify run boundaries where state changes + blocks <- list() + start_idx <- 1 + current_state <- state[1] + + for (i in 2:length(r)) { + if (state[i] != current_state) { + # close previous block + blocks[[length(blocks) + 1]] <- list( + state = current_state, + i_start = start_idx, + i_end = i - 1 + ) + # start new block + start_idx <- i + current_state <- state[i] + } + } + # close last block + blocks[[length(blocks) + 1]] <- list( + state = current_state, + i_start = start_idx, + i_end = length(r) + ) + + # Turn into a data.frame with r-intervals and some diagnostics + out_list <- vector("list", length(blocks)) + for (k in seq_along(blocks)) { + b <- blocks[[k]] + idx <- b$i_start:b$i_end + out_list[[k]] <- data.frame( + state = b$state, + r_min = min(r[idx]), + r_max = max(r[idx]), + # a few handy diagnostics per block: + max_err_direct_bound = max(class_df$err_direct_bound[idx], na.rm = TRUE), + max_err_preexp_bound = max(class_df$err_preexp_bound[idx], na.rm = TRUE), + min_bound = min(class_df$bound[idx], na.rm = TRUE), + max_bound = max(class_df$bound[idx], na.rm = TRUE), + n_points = length(idx), + max_cat = class_df$max_cat[1], + ref = class_df$ref[1], + theta_lin = class_df$theta_lin[1], + theta_quad = class_df$theta_quad[1], + stringsAsFactors = FALSE + ) + } + + do.call(rbind, out_list) +} + +# 1. Run the basic comparison +r_vals <- seq(0, 100, length.out = 2000) + +res4 <- compare_bc_all_methods( + max_cat = 4, + ref = 0, + r_vals = r_vals, + theta_lin = 0.12, + theta_quad = -0.02, + mpfr_prec = 256 +) + +# 2. Classify per-r which bound-method wins +class4 <- classify_bc_bound_methods(res4, tol = 1e-12, eps_better = 1e-3) + +# 3. Compress into r-intervals +summary4 <- summarize_bc_bound_classification(class4) +print(summary4, digits = 3) + + + + +simulate_bc_fast_safe <- function(param_grid, + r_vals = seq(-80, 80, length.out = 2000), + mpfr_prec = 256, + tol = 1e-12) { + # param_grid: data.frame with columns + # max_cat, ref, theta_lin, theta_quad + # r_vals : vector of residual r values + # tol : tolerance for "ok" numerics (relative error) + # + # Returns one big data.frame with columns: + # config_id, max_cat, ref, theta_lin, theta_quad, + # r, bound, fast_val, safe_val, + # err_fast, err_safe, ok_fast, ok_safe, + # ref_scaled (MPFR reference) + + if (!all(c("max_cat", "ref", "theta_lin", "theta_quad") %in% names(param_grid))) { + stop("param_grid must have columns: max_cat, ref, theta_lin, theta_quad") + } + + out_list <- vector("list", nrow(param_grid)) + + for (cfg_idx in seq_len(nrow(param_grid))) { + cfg <- param_grid[cfg_idx, ] + max_cat <- as.integer(cfg$max_cat) + ref <- as.integer(cfg$ref) + theta_lin <- as.numeric(cfg$theta_lin) + theta_quad <- as.numeric(cfg$theta_quad) + + # --- score grid and θ-part for this config -------------------------------- + scores <- 0:max_cat + centered <- scores - ref + + theta_part <- theta_lin * centered + theta_quad * centered^2 + exp_m <- exp(theta_part) # for fast method + + # MPFR constants + tl_mpfr <- mpfr(theta_lin, mpfr_prec) + tq_mpfr <- mpfr(theta_quad, mpfr_prec) + sc_center_mpfr <- mpfr(centered, mpfr_prec) + sc_raw_mpfr <- mpfr(scores, mpfr_prec) + + # Storage for this config + n_r <- length(r_vals) + res_cfg <- data.frame( + config_id = rep(cfg_idx, n_r), + max_cat = rep(max_cat, n_r), + ref = rep(ref, n_r), + theta_lin = rep(theta_lin, n_r), + theta_quad = rep(theta_quad, n_r), + r = r_vals, + bound = NA_real_, + fast_val = NA_real_, + safe_val = NA_real_, + err_fast = NA_real_, + err_safe = NA_real_, + ok_fast = NA, + ok_safe = NA, + ref_scaled = NA_real_, + stringsAsFactors = FALSE + ) + + # --- main loop over r for this config ------------------------------------- + for (i in seq_along(r_vals)) { + r <- r_vals[i] + + ## Double-precision exponents: + term <- theta_part + scores * r # θ_part(s) + s*r + term_max <- max(term) # M(r) = bound + res_cfg$bound[i] <- term_max + + ## MPFR reference (scaled with max-term): + r_mpfr <- mpfr(r, mpfr_prec) + term_mpfr <- tl_mpfr * sc_center_mpfr + + tq_mpfr * sc_center_mpfr * sc_center_mpfr + + sc_raw_mpfr * r_mpfr + term_max_mpfr <- mpfr(max(asNumeric(term_mpfr)), mpfr_prec) + ref_scaled_mpfr <- sum(exp(term_mpfr - term_max_mpfr)) + ref_scaled_num <- asNumeric(ref_scaled_mpfr) + res_cfg$ref_scaled[i] <- ref_scaled_num + + # --- SAFE: Direct + max-bound ------------------------------------------ + # Z_safe = sum_s exp(θ_part(s) + s*r - term_max) + safe_sum <- 0.0 + for (j in seq_along(scores)) { + safe_sum <- safe_sum + exp(theta_part[j] + scores[j] * r - term_max) + } + res_cfg$safe_val[i] <- safe_sum + + # --- FAST: Preexp + max-bound (power-chain) ---------------------------- + # Z_fast = sum_s exp(θ_part(s)) * exp(s*r - term_max) + eR <- exp(r) + pow_b <- exp(-term_max) # s = 0 → exp(0*r - term_max) + fast_sum <- 0.0 + for (j in seq_along(scores)) { + fast_sum <- fast_sum + exp_m[j] * pow_b + pow_b <- pow_b * eR + } + res_cfg$fast_val[i] <- fast_sum + + # --- Relative errors vs MPFR (scaled) ---------------------------------- + if (is.finite(ref_scaled_num) && ref_scaled_num > 0) { + res_cfg$err_safe[i] <- abs(safe_sum - ref_scaled_num) / ref_scaled_num + res_cfg$err_fast[i] <- abs(fast_sum - ref_scaled_num) / ref_scaled_num + } else { + res_cfg$err_safe[i] <- NA_real_ + res_cfg$err_fast[i] <- NA_real_ + } + + res_cfg$ok_safe[i] <- !is.na(res_cfg$err_safe[i]) && + is.finite(res_cfg$err_safe[i]) && + (res_cfg$err_safe[i] < tol) + + res_cfg$ok_fast[i] <- !is.na(res_cfg$err_fast[i]) && + is.finite(res_cfg$err_fast[i]) && + (res_cfg$err_fast[i] < tol) + } + + out_list[[cfg_idx]] <- res_cfg + } + + do.call(rbind, out_list) +} \ No newline at end of file diff --git a/dev/numerical_analyses/bgm_blumecapel_normalization_PL_extra.R b/dev/numerical_analyses/bgm_blumecapel_normalization_PL_extra.R new file mode 100644 index 00000000..43ebeb04 --- /dev/null +++ b/dev/numerical_analyses/bgm_blumecapel_normalization_PL_extra.R @@ -0,0 +1,900 @@ +############################################################ +# Blume–Capel normalization analysis: +# Numerical comparison of FAST vs SAFE exponentiation methods +# +# Objective +# --------- +# This script provides a full numerical investigation of two methods +# to compute the *scaled* Blume–Capel partition sum: +# +# Z_scaled(r) = sum_{s=0}^C exp( θ_part(s) + s*r - M(r) ) +# +# where +# θ_part(s) = θ_lin * (s - ref) + θ_quad * (s - ref)^2 +# and +# M(r) = max_s ( θ_part(s) + s*r ). +# +# We compare two computational approaches: +# +# SAFE = Direct computation : sum_s exp(θ_part + s*r - M(r)) +# FAST = Power-chain precompute : sum_s exp(θ_part(s)) * exp(s*r - M(r)) +# +# MPFR (256-bit) is used as the ground-truth reference. +# +# The goals are: +# +# 1. Determine each method's numerical stability across a wide range +# of (max_cat, ref, θ_lin, θ_quad, r). +# +# 2. Map all cases where FAST becomes inaccurate or produces NaN. +# +# 3. Identify the correct switching rule for the C++ implementation: +# if (bound <= ~709) use FAST +# else use SAFE +# +# where `bound = M(r)` is the maximum exponent before rescaling. +# +# 4. Produce plots and summary statistics to permanently document the +# reasoning behind this rule. +# +# Key numerical fact +# ------------------ +# exp(x) in IEEE double precision overflows at x ≈ 709.782712893. +# Therefore any exponent near ±709 is dangerous. +# +# Outcome summary +# --------------- +# - SAFE is stable across the entire tested range. +# - FAST is perfectly accurate **as long as bound ≤ ~709** +# - All FAST failures (NaN or large error) occur only when bound > ~709 +# - No FAST failures were observed below this threshold. +# +# This provides strong empirical justification for the C++ switching rule. +############################################################ + +library(Rmpfr) +library(dplyr) +library(ggplot2) + +############################################################ +# 1. Simulation function +# +# Simulates FAST vs SAFE across: +# - parameter grid (max_cat, ref, θ_lin, θ_quad) +# - range of r values +# +# Returns one large data.frame containing: +# - the computed bound M(r) +# - FAST and SAFE values +# - MPFR reference +# - relative errors +# - logical OK flags (err < tol) +############################################################ + +simulate_bc_fast_safe <- function(param_grid, + r_vals = seq(-80, 80, length.out = 2000), + mpfr_prec = 256, + tol = 1e-12) { + + if (!all(c("max_cat", "ref", "theta_lin", "theta_quad") %in% names(param_grid))) { + stop("param_grid must have columns: max_cat, ref, theta_lin, theta_quad") + } + + out_list <- vector("list", nrow(param_grid)) + + for (cfg_idx in seq_len(nrow(param_grid))) { + cfg <- param_grid[cfg_idx, ] + max_cat <- as.integer(cfg$max_cat) + ref <- as.integer(cfg$ref) + theta_lin <- as.numeric(cfg$theta_lin) + theta_quad <- as.numeric(cfg$theta_quad) + + # Score grid and θ(s) + scores <- 0:max_cat + centered <- scores - ref + theta_part <- theta_lin * centered + theta_quad * centered^2 + exp_m <- exp(theta_part) # used by FAST + + # Build MPFR constants + tl_mpfr <- mpfr(theta_lin, mpfr_prec) + tq_mpfr <- mpfr(theta_quad, mpfr_prec) + sc_center_mpfr <- mpfr(centered, mpfr_prec) + sc_raw_mpfr <- mpfr(scores, mpfr_prec) + + n_r <- length(r_vals) + + res_cfg <- data.frame( + config_id = rep(cfg_idx, n_r), + max_cat = rep(max_cat, n_r), + ref = rep(ref, n_r), + theta_lin = rep(theta_lin, n_r), + theta_quad = rep(theta_quad, n_r), + r = r_vals, + bound = NA_real_, + fast_val = NA_real_, + safe_val = NA_real_, + err_fast = NA_real_, + err_safe = NA_real_, + ok_fast = NA, + ok_safe = NA, + ref_scaled = NA_real_, + stringsAsFactors = FALSE + ) + + # Compute for all r + for (i in seq_along(r_vals)) { + r <- r_vals[i] + + term <- theta_part + scores * r # θ(s) + s*r + term_max <- max(term) # numerical bound + res_cfg$bound[i] <- term_max + + # MPFR scaled reference + r_mpfr <- mpfr(r, mpfr_prec) + term_mpfr <- tl_mpfr * sc_center_mpfr + + tq_mpfr * sc_center_mpfr * sc_center_mpfr + + sc_raw_mpfr * r_mpfr + term_max_mpfr <- mpfr(max(asNumeric(term_mpfr)), mpfr_prec) + ref_scaled_mpfr <- sum(exp(term_mpfr - term_max_mpfr)) + ref_scaled_num <- asNumeric(ref_scaled_mpfr) + res_cfg$ref_scaled[i] <- ref_scaled_num + + # SAFE method: direct evaluation + safe_sum <- 0.0 + for (j in seq_along(scores)) { + safe_sum <- safe_sum + exp(theta_part[j] + scores[j] * r - term_max) + } + res_cfg$safe_val[i] <- safe_sum + + # FAST method: preexp power-chain + eR <- exp(r) + pow_b <- exp(-term_max) + fast_sum <- 0.0 + for (j in seq_along(scores)) { + fast_sum <- fast_sum + exp_m[j] * pow_b + pow_b <- pow_b * eR + } + res_cfg$fast_val[i] <- fast_sum + + # Relative errors + if (is.finite(ref_scaled_num) && ref_scaled_num > 0) { + res_cfg$err_safe[i] <- abs(safe_sum - ref_scaled_num) / ref_scaled_num + res_cfg$err_fast[i] <- abs(fast_sum - ref_scaled_num) / ref_scaled_num + } + + res_cfg$ok_safe[i] <- !is.na(res_cfg$err_safe[i]) && + is.finite(res_cfg$err_safe[i]) && + (res_cfg$err_safe[i] < tol) + + res_cfg$ok_fast[i] <- !is.na(res_cfg$err_fast[i]) && + is.finite(res_cfg$err_fast[i]) && + (res_cfg$err_fast[i] < tol) + } + + out_list[[cfg_idx]] <- res_cfg + } + + do.call(rbind, out_list) +} + +############################################################ +# 2. Parameter grid and simulation +############################################################ + +param_grid <- expand.grid( + max_cat = c(10), + ref = c(0, 5, 10), + theta_lin = c(-0.5, 0.0, 0.5), + theta_quad = c(-0.2, 0.0, 0.2), + KEEP.OUT.ATTRS = FALSE, + stringsAsFactors = FALSE +) + +# Very wide r-range so that bound covers deep negative and deep positive +r_vals <- seq(-100, 100, length.out = 5001) + +tol <- 1e-12 + +sim_res <- simulate_bc_fast_safe( + param_grid = param_grid, + r_vals = r_vals, + mpfr_prec = 256, + tol = tol +) + +############################################################ +# 3. Post-processing: classify regions, log-errors, abs(bound) +############################################################ + +df <- sim_res %>% + mutate( + err_fast_clipped = pmax(err_fast, 1e-300), + err_safe_clipped = pmax(err_safe, 1e-300), + + log_err_fast = log10(err_fast_clipped), + log_err_safe = log10(err_safe_clipped), + + abs_bound = abs(bound), + + region = case_when( + ok_fast & ok_safe ~ "both_ok", + !ok_fast & ok_safe ~ "only_safe_ok", + ok_fast & !ok_safe ~ "only_fast_ok", + TRUE ~ "neither_ok" + ) + ) + +############################################################ +# 4. NaN analysis for FAST +# +# We explicitly check: +# +# Are there *any* NaN occurrences for FAST with |bound| < 709 ? +# +# This is essential: if NaN occurs for FAST even when |bound| is small, +# then the switching rule would fail. +############################################################ + +df_nan <- sim_res %>% filter(is.nan(err_fast)) + +nan_summary <- df_nan %>% + summarise( + n_nan = n(), + min_bound = min(bound, na.rm = TRUE), + max_bound = max(bound, na.rm = TRUE) + ) + +print(nan_summary) + +df_nan_inside <- df_nan %>% filter(abs(bound) < 709) + +cat("\nNumber of FAST NaN cases with |bound| < 709: ", + nrow(df_nan_inside), "\n\n") + +############################################################ +# 5. FAST and SAFE plots vs bound +# +# We also explicitly count how many cases fail (ok_* == FALSE) +# while |bound| < 709. If the switching rule is correct, this +# number should be zero for FAST in the region where we intend +# to use it. +############################################################ + +# Count failures for FAST and SAFE when |bound| < 709 +fast_fail_inside <- df %>% + filter(abs(bound) < 709, !ok_fast) %>% + nrow() + +safe_fail_inside <- df %>% + filter(abs(bound) < 709, !ok_safe) %>% + nrow() + +cat("\nFAST failures with |bound| < 709:", fast_fail_inside, "\n") +cat("SAFE failures with |bound| < 709:", safe_fail_inside, "\n\n") + +# FAST +ggplot(df, aes(x = bound, y = log_err_fast, colour = region)) + + geom_point(alpha = 0.3, size = 0.6, na.rm = TRUE) + + geom_hline(yintercept = log10(tol), linetype = 2) + + geom_vline(xintercept = 709, linetype = 2) + + geom_vline(xintercept = -709, linetype = 2) + + scale_color_manual(values = c( + both_ok = "darkgreen", + only_safe_ok = "orange", + only_fast_ok = "blue", + neither_ok = "red" + )) + + labs( + x = "bound = max_s (theta_part(s) + s*r)", + y = "log10(relative error) of FAST", + colour = "region", + subtitle = paste( + "FAST failures with |bound| < 709:", fast_fail_inside + ) + ) + + ggtitle("FAST method vs bound") + + theme_minimal() + +# SAFE +ggplot(df, aes(x = bound, y = log_err_safe, colour = region)) + + geom_point(alpha = 0.3, size = 0.6, na.rm = TRUE) + + geom_hline(yintercept = log10(tol), linetype = 2) + + geom_vline(xintercept = 709, linetype = 2) + + geom_vline(xintercept = -709, linetype = 2) + + scale_color_manual(values = c( + both_ok = "darkgreen", + only_safe_ok = "orange", + only_fast_ok = "blue", + neither_ok = "red" + )) + + labs( + x = "bound = max_s (theta_part(s) + s*r)", + y = "log10(relative error) of SAFE", + colour = "region", + subtitle = paste( + "SAFE failures with |bound| < 709:", safe_fail_inside + ) + ) + + ggtitle("SAFE method vs bound") + + theme_minimal() + + +############################################################ +# 6. Fraction of configurations per |bound|-bin +############################################################ + +df_bins <- df %>% + filter(is.finite(bound)) %>% + mutate( + abs_bound = abs(bound), + bound_bin = cut( + abs_bound, + breaks = seq(0, max(abs_bound, na.rm = TRUE) + 10, by = 10), + include_lowest = TRUE + ) + ) %>% + group_by(bound_bin) %>% + summarise( + mid_abs_bound = mean(abs_bound, na.rm = TRUE), + frac_fast_ok = mean(ok_fast, na.rm = TRUE), + frac_safe_ok = mean(ok_safe, na.rm = TRUE), + n = n(), + .groups = "drop" + ) + +ggplot(df_bins, aes(x = mid_abs_bound)) + + geom_line(aes(y = frac_fast_ok, colour = "FAST ok")) + + geom_line(aes(y = frac_safe_ok, colour = "SAFE ok")) + + geom_vline(xintercept = 709, linetype = 2) + + scale_colour_manual(values = c("FAST ok" = "blue", "SAFE ok" = "darkgreen")) + + labs( + x = "|bound| bin center", + y = "fraction of configurations with err < tol", + colour = "" + ) + + ggtitle("FAST vs SAFE numerical stability by |bound|") + + theme_minimal() + +############################################################ +# 7. Summary printed to console +############################################################ + +cat("\n================ SUMMARY =================\n") +print(nan_summary) + +cat("\nFAST NaN cases with |bound| < 709: ", + nrow(df_nan_inside), "\n\n") + +cat(" +Interpretation: +-------------- +- The SAFE method (direct + bound) remains stable and accurate across the + entire tested parameter and residual range. + +- The FAST method (preexp + bound) is extremely accurate when the maximum + exponent before rescaling, `bound = M(r)`, satisfies: + + |bound| ≤ ~709 + +- As soon as bound exceeds approximately +709, FAST becomes unstable: + * large numerical error + * or NaN (observed systematically) + * No such failures appear below this threshold. + +C++ Implementation Rule (recommended): +-------------------------------------- +if (bound <= 709.0) { + // FAST: preexp + bound (power-chain) +} else { + // SAFE: direct + bound +} + +This script constitutes the full reproducible analysis supporting the choice +of this switching threshold in the C++ Blume–Capel normalization code. +") + +############################################################ +# End of script +############################################################ + + + + + + + + +############################################################ +# Blume–Capel probability analysis: +# Numerical comparison of FAST vs SAFE probability evaluation +# +# Objective +# --------- +# This script provides a numerical investigation of two methods +# to compute the *probabilities* under the Blume–Capel +# pseudolikelihood: +# +# p_s(r) = exp( θ_part(s) + s*r - M(r) ) / Z_scaled(r) +# +# where +# θ_part(s) = θ_lin * (s - ref) + θ_quad * (s - ref)^2 +# M(r) = max_s ( θ_part(s) + s*r ) +# Z_scaled = sum_s exp( θ_part(s) + s*r - M(r) ) +# +# We compare two implementations: +# +# SAFE = direct exponentials with numerical bound M(r) +# FAST = preexp + power-chain for exp(s*r - M(r)) +# +# MPFR (256-bit) is used as the ground-truth reference. +# +# Goals +# ----- +# 1. Check numerical stability of SAFE vs FAST for probabilities +# across wide ranges of (max_cat, ref, θ_lin, θ_quad, r). +# 2. Confirm that the same switching rule used for the +# normalization carries over safely to probabilities: +# +# FAST is used only if +# +# |M(r)| <= EXP_BOUND AND +# pow_bound = max_cat * r - M(r) <= EXP_BOUND +# +# where EXP_BOUND ≈ 709. +# +# 3. Document the error behaviour in terms of: +# - max absolute difference per probability vector +# - max relative difference +# - KL divergence to MPFR reference. +# +# Outcome (to be checked empirically) +# ----------------------------------- +# - SAFE should be stable across the tested ranges. +# - FAST should exhibit negligible error whenever the +# switching bounds are satisfied. +# +############################################################ + +library(Rmpfr) +library(dplyr) +library(ggplot2) + +EXP_BOUND <- 709 # double overflow limit for exp() + + +############################################################ +# 1. Helper: MPFR probability reference for a single config +############################################################ + +bc_prob_ref_mpfr <- function(max_cat, ref, theta_lin, theta_quad, + r_vals, + mpfr_prec = 256) { + # Categories and centered scores + scores <- 0:max_cat + centered <- scores - ref + + # MPFR parameters + tl <- mpfr(theta_lin, mpfr_prec) + tq <- mpfr(theta_quad, mpfr_prec) + sc <- mpfr(scores, mpfr_prec) + s0 <- mpfr(centered, mpfr_prec) + + n_r <- length(r_vals) + n_s <- length(scores) + + # reference probability matrix (rows = r, cols = s) + P_ref <- matrix(NA_real_, nrow = n_r, ncol = n_s) + + for (i in seq_len(n_r)) { + r_mp <- mpfr(r_vals[i], mpfr_prec) + + # exponent(s) = θ_part(s) + s*r + term <- tl * s0 + tq * s0 * s0 + sc * r_mp + + # numeric bound M(r) + term_max <- max(asNumeric(term)) + term_max_mp <- mpfr(term_max, mpfr_prec) + + num <- exp(term - term_max_mp) # scaled numerators + Z <- sum(num) + p <- num / Z + + P_ref[i, ] <- asNumeric(p) + } + + P_ref +} + + +############################################################ +# 2. SAFE probabilities (double) for a single config +############################################################ + +bc_prob_safe <- function(max_cat, ref, theta_lin, theta_quad, + r_vals) { + scores <- 0:max_cat + centered <- scores - ref + theta_part <- theta_lin * centered + theta_quad * centered^2 + + n_r <- length(r_vals) + n_s <- length(scores) + + P_safe <- matrix(NA_real_, nrow = n_r, ncol = n_s) + + for (i in seq_len(n_r)) { + r <- r_vals[i] + + # exponents before scaling + exps <- theta_part + scores * r + b <- max(exps) + + numer <- exp(exps - b) + denom <- sum(numer) + + # NO fallback here; let denom=0 or non-finite propagate + p <- numer / denom + + P_safe[i, ] <- p + } + + P_safe +} + + + +############################################################ +# 3. FAST probabilities (double) for a single config +# +# This mirrors what a C++ compute_probs_blume_capel(FAST) +# implementation would do: precompute exp(theta_part), +# then use a power chain for exp(s*r - b). +############################################################ + +bc_prob_fast <- function(max_cat, ref, theta_lin, theta_quad, + r_vals) { + scores <- 0:max_cat + centered <- scores - ref + theta_part <- theta_lin * centered + theta_quad * centered^2 + exp_theta <- exp(theta_part) + + n_r <- length(r_vals) + n_s <- length(scores) + + P_fast <- matrix(NA_real_, nrow = n_r, ncol = n_s) + bounds <- numeric(n_r) + pow_bounds <- numeric(n_r) + + for (i in seq_len(n_r)) { + r <- r_vals[i] + + # exponents before scaling + exps <- theta_part + scores * r + b <- max(exps) + bounds[i] <- b + + # pow_bound = max_s (s*r - b) is attained at s = max_cat + pow_bounds[i] <- max_cat * r - b + + eR <- exp(r) + pow <- exp(-b) + + numer <- numeric(n_s) + denom <- 0.0 + + for (j in seq_along(scores)) { + numer[j] <- exp_theta[j] * pow + denom <- denom + numer[j] + pow <- pow * eR + } + + # Again: NO fallback, just divide and let problems show + p <- numer / denom + + P_fast[i, ] <- p + } + + list( + probs = P_fast, + bound = bounds, + pow_bound = pow_bounds + ) +} + + + +############################################################ +# 4. Main simulation: +# Explore param_grid × r_vals and compare: +# - P_ref (MPFR) +# - P_safe +# - P_fast +############################################################ + +simulate_bc_prob_fast_safe <- function(param_grid, + r_vals, + mpfr_prec = 256, + tol_prob = 1e-12) { + + if (!all(c("max_cat", "ref", "theta_lin", "theta_quad") %in% names(param_grid))) { + stop("param_grid must have columns: max_cat, ref, theta_lin, theta_quad") + } + + out_list <- vector("list", nrow(param_grid)) + + for (cfg_idx in seq_len(nrow(param_grid))) { + cfg <- param_grid[cfg_idx, ] + max_cat <- as.integer(cfg$max_cat) + ref <- as.integer(cfg$ref) + theta_lin <- as.numeric(cfg$theta_lin) + theta_quad <- as.numeric(cfg$theta_quad) + + n_r <- length(r_vals) + + # Reference + P_ref <- bc_prob_ref_mpfr(max_cat, ref, theta_lin, theta_quad, + r_vals, mpfr_prec = mpfr_prec) + # SAFE + P_safe <- bc_prob_safe(max_cat, ref, theta_lin, theta_quad, + r_vals) + # FAST (+ bounds) + fast_res <- bc_prob_fast(max_cat, ref, theta_lin, theta_quad, + r_vals) + P_fast <- fast_res$probs + bound <- fast_res$bound + pow_bound <- fast_res$pow_bound + + # Error metrics per r + max_abs_fast <- numeric(n_r) + max_rel_fast <- numeric(n_r) + kl_fast <- numeric(n_r) + + max_abs_safe <- numeric(n_r) + max_rel_safe <- numeric(n_r) + kl_safe <- numeric(n_r) + + # Helper: KL divergence D(p || q) + kl_div <- function(p, q) { + # If either vector has non-finite entries, KL is undefined → NA + if (!all(is.finite(p)) || !all(is.finite(q))) { + return(NA_real_) + } + + # Valid domain for KL: where both p and q are strictly positive + mask <- (p > 0) & (q > 0) + + # mask may contain NA → remove NA via na.rm=TRUE + if (!any(mask, na.rm = TRUE)) { + return(NA_real_) + } + + sum(p[mask] * (log(p[mask]) - log(q[mask]))) + } + + + for (i in seq_len(n_r)) { + p_ref <- P_ref[i, ] + p_safe <- P_safe[i, ] + p_fast <- P_fast[i, ] + + # max abs diff + max_abs_fast[i] <- max(abs(p_fast - p_ref)) + max_abs_safe[i] <- max(abs(p_safe - p_ref)) + + # max relative diff (avoid divide-by-zero) + rel_fast <- abs(p_fast - p_ref) + rel_safe <- abs(p_safe - p_ref) + + rel_fast[p_ref > 0] <- rel_fast[p_ref > 0] / p_ref[p_ref > 0] + rel_safe[p_ref > 0] <- rel_safe[p_ref > 0] / p_ref[p_ref > 0] + + rel_fast[p_ref == 0] <- 0 + rel_safe[p_ref == 0] <- 0 + + max_rel_fast[i] <- max(rel_fast) + max_rel_safe[i] <- max(rel_safe) + + # KL + kl_fast[i] <- kl_div(p_ref, p_fast) + kl_safe[i] <- kl_div(p_ref, p_safe) + } + + # "ok" flags using tol_prob on max_abs + ok_fast <- is.finite(max_abs_fast) & (max_abs_fast < tol_prob) + ok_safe <- is.finite(max_abs_safe) & (max_abs_safe < tol_prob) + + # FAST switching condition as in C++: + # use FAST only if |bound| <= EXP_BOUND and pow_bound <= EXP_BOUND + use_fast <- (abs(bound) <= EXP_BOUND) & (pow_bound <= EXP_BOUND) + + res_cfg <- data.frame( + config_id = rep(cfg_idx, n_r), + max_cat = rep(max_cat, n_r), + ref = rep(ref, n_r), + theta_lin = rep(theta_lin, n_r), + theta_quad = rep(theta_quad, n_r), + r = r_vals, + bound = bound, + pow_bound = pow_bound, + use_fast = use_fast, + max_abs_fast = max_abs_fast, + max_rel_fast = max_rel_fast, + kl_fast = kl_fast, + max_abs_safe = max_abs_safe, + max_rel_safe = max_rel_safe, + kl_safe = kl_safe, + ok_fast = ok_fast, + ok_safe = ok_safe, + stringsAsFactors = FALSE + ) + + out_list[[cfg_idx]] <- res_cfg + } + + do.call(rbind, out_list) +} + + +############################################################ +# 5. Example simulation setup +############################################################ + +# Parameter grid similar in spirit to the BC normalization script +param_grid <- expand.grid( + max_cat = c(4, 10), # Blume–Capel max categories (example) + ref = c(0, 2, 4, 5, 10), # include both interior & boundary refs + theta_lin = c(-0.5, 0.0, 0.5), + theta_quad = c(-0.2, 0.0, 0.2), + KEEP.OUT.ATTRS = FALSE, + stringsAsFactors = FALSE +) + +# Wide r-range; adjust as needed to match your empirical residuals +r_vals <- seq(-80, 80, length.out = 2001) + +tol_prob <- 1e-12 + +sim_probs <- simulate_bc_prob_fast_safe( + param_grid = param_grid, + r_vals = r_vals, + mpfr_prec = 256, + tol_prob = tol_prob +) + + +############################################################ +# 6. Post-processing and diagnostics +############################################################ + +df <- sim_probs %>% + mutate( + abs_bound = abs(bound), + region = case_when( + use_fast & ok_fast ~ "fast_ok_when_used", + use_fast & !ok_fast ~ "fast_bad_when_used", + !use_fast & ok_safe ~ "safe_ok_when_used", + !use_fast & !ok_safe ~ "safe_bad_when_used" + ) + ) + +# Check: any bad FAST cases *within* the intended FAST region? +fast_bad_inside <- df %>% + filter(use_fast, !ok_fast) + +cat("\nNumber of FAST probability failures where use_fast == TRUE: ", + nrow(fast_bad_inside), "\n\n") + +# Also track purely based on bounds (even if not marked use_fast) +fast_bad_bound_region <- df %>% + filter(abs(bound) <= EXP_BOUND, + pow_bound <= EXP_BOUND, + !ok_fast) + +cat("Number of FAST probability failures with |bound| <= 709 & pow_bound <= 709: ", + nrow(fast_bad_bound_region), "\n\n") + + +############################################################ +# 7. Plots: error vs bound (FAST only) +############################################################ + +df_fast <- df %>% + filter(use_fast) %>% + mutate( + log10_max_abs_fast = log10(pmax(max_abs_fast, 1e-300)) + ) + +ggplot(df_fast, aes(x = bound, y = log10_max_abs_fast)) + + geom_point(alpha = 0.3, size = 0.6) + + geom_hline(yintercept = log10(tol_prob), linetype = 2, colour = "darkgreen") + + geom_vline(xintercept = EXP_BOUND, linetype = 2, colour = "red") + + geom_vline(xintercept = -EXP_BOUND, linetype = 2, colour = "red") + + labs( + x = "bound = max_s (θ_part(s) + s*r)", + y = "log10(max absolute error) of FAST p_s(r)", + title = "FAST Blume–Capel probabilities vs bound (used region only)", + subtitle = paste( + "FAST failures in use_fast region:", nrow(fast_bad_inside) + ) + ) + + theme_minimal() + + +############################################################ +# 8. Binned summary by |bound| +############################################################ + +df_bins <- df %>% + mutate( + abs_bound = abs(bound), + bound_bin = cut( + abs_bound, + breaks = seq(0, max(abs_bound, na.rm = TRUE) + 10, by = 10), + include_lowest = TRUE + ) + ) %>% + group_by(bound_bin) %>% + summarise( + mid_abs_bound = mean(abs_bound, na.rm = TRUE), + frac_fast_ok = mean(ok_fast[use_fast], na.rm = TRUE), + frac_safe_ok = mean(ok_safe[!use_fast], na.rm = TRUE), + max_abs_fast_99 = quantile(max_abs_fast[use_fast], 0.99, na.rm = TRUE), + max_abs_safe_99 = quantile(max_abs_safe[!use_fast], 0.99, na.rm = TRUE), + n = n(), + .groups = "drop" + ) + +ggplot(df_bins, aes(x = mid_abs_bound)) + + geom_line(aes(y = frac_fast_ok, colour = "FAST ok (used)"), na.rm = TRUE) + + geom_line(aes(y = frac_safe_ok, colour = "SAFE ok (used)"), na.rm = TRUE) + + geom_vline(xintercept = EXP_BOUND, linetype = 2) + + scale_colour_manual(values = c( + "FAST ok (used)" = "blue", + "SAFE ok (used)" = "darkgreen" + )) + + labs( + x = "|bound| bin center", + y = "fraction of configurations with max_abs_error < tol_prob", + colour = "", + title = "Numerical stability of Blume–Capel probabilities by |bound|" + ) + + theme_minimal() + + +############################################################ +# 9. Console summary +############################################################ + +cat("\n================ PROBABILITY SUMMARY =================\n") + +cat("Total rows in simulation:", nrow(df), "\n\n") + +cat("FAST probability failures where use_fast == TRUE: ", + nrow(fast_bad_inside), "\n") +cat("FAST probability failures with |bound| <= 709 & pow_bound <= 709: ", + nrow(fast_bad_bound_region), "\n\n") + +cat("Typical 99th percentile max_abs_error per |bound|-bin (FAST used):\n") +print( + df_bins %>% + select(bound_bin, mid_abs_bound, max_abs_fast_99) %>% + arrange(mid_abs_bound), + digits = 4 +) + +cat(" +Interpretation guide +-------------------- +- `ok_fast`/`ok_safe` are defined by max absolute error vs MPFR reference + being below tol_prob (default 1e-12). + +- `use_fast` encodes the **intended** C++ switching rule: + use_fast = (|bound| <= 709) & (pow_bound <= 709) + +- Ideally: + * `fast_bad_inside` should be empty or extremely rare, + showing that FAST is safe whenever used. + * errors for SAFE should be negligible everywhere. + +You can tighten the switching margin if needed (e.g. require +`pow_bound <= 700`) by adjusting `use_fast` in the code above. +") \ No newline at end of file diff --git a/dev/numerical_analyses/bgm_blumecapel_probs_PL.R b/dev/numerical_analyses/bgm_blumecapel_probs_PL.R new file mode 100644 index 00000000..6d887ba4 --- /dev/null +++ b/dev/numerical_analyses/bgm_blumecapel_probs_PL.R @@ -0,0 +1,248 @@ +############################################################ +# Blume–Capel probabilities: +# Numerical comparison of 4 methods vs MPFR reference +# +# Methods: +# - direct_unscaled : naive softmax +# - direct_bound : softmax with subtraction of M(r) +# - preexp_unscaled : preexp(theta_part) + power chain (no bound) +# - preexp_bound : preexp(theta_part) + power chain (with bound) +# +# Reference: +# - MPFR softmax with scaling by M(r) +############################################################ + +library(Rmpfr) +library(dplyr) +library(ggplot2) + +EXP_BOUND <- 709 + +############################################################ +# 1. Compare 4 methods for one BC configuration +############################################################ + +compare_bc_prob_4methods_one <- function(max_cat, + ref, + theta_lin, + theta_quad, + r_vals, + mpfr_prec = 256) { + + s_vals <- 0:max_cat + c_vals <- s_vals - ref + n_s <- length(s_vals) + n_r <- length(r_vals) + + # theta_part(s) + theta_part_num <- theta_lin * c_vals + theta_quad * c_vals^2 + + # MPFR parameters + tl_mp <- mpfr(theta_lin, mpfr_prec) + tq_mp <- mpfr(theta_quad, mpfr_prec) + s_mp <- mpfr(s_vals, mpfr_prec) + c_mp <- mpfr(c_vals, mpfr_prec) + + # Precompute for preexp methods + exp_theta <- exp(theta_part_num) + + res <- data.frame( + r = r_vals, + bound = NA_real_, + pow_bound = NA_real_, + err_direct = NA_real_, + err_bound = NA_real_, + err_preexp = NA_real_, + err_preexp_bound= NA_real_ + ) + + for (i in seq_len(n_r)) { + r <- r_vals[i] + r_mp <- mpfr(r, mpfr_prec) + + ## MPFR reference probabilities (softmax with scaling) + term_mp <- tl_mp * c_mp + + tq_mp * c_mp * c_mp + + s_mp * r_mp + + M_num <- max(asNumeric(term_mp)) + M_mp <- mpfr(M_num, mpfr_prec) + + num_ref_mp <- exp(term_mp - M_mp) + Z_ref_mp <- sum(num_ref_mp) + p_ref_mp <- num_ref_mp / Z_ref_mp + p_ref <- asNumeric(p_ref_mp) + + ## Double: exponents + term_num <- theta_part_num + s_vals * r + M <- max(term_num) + res$bound[i] <- M + res$pow_bound[i] <- max_cat * r - M + + ## (1) direct_unscaled + num_dir <- exp(term_num) + den_dir <- sum(num_dir) + p_dir <- num_dir / den_dir + + ## (2) direct_bound + num_b <- exp(term_num - M) + den_b <- sum(num_b) + p_b <- num_b / den_b + + ## (3) preexp_unscaled + eR <- exp(r) + pow <- eR + num_pre <- numeric(n_s) + den_pre <- 0.0 + + # s = 0 term + num_pre[1] <- exp_theta[1] * 1.0 + den_pre <- den_pre + num_pre[1] + + if (max_cat >= 1) { + for (s in 1:max_cat) { + num_pre[s + 1] <- exp_theta[s + 1] * pow + den_pre <- den_pre + num_pre[s + 1] + pow <- pow * eR + } + } + p_pre <- num_pre / den_pre + + ## (4) preexp_bound + eR2 <- exp(r) + pow_b <- exp(-M) + num_preB <- numeric(n_s) + den_preB <- 0.0 + + for (s in 0:max_cat) { + idx <- s + 1 + num_preB[idx] <- exp_theta[idx] * pow_b + den_preB <- den_preB + num_preB[idx] + pow_b <- pow_b * eR2 + } + p_preB <- num_preB / den_preB + + ## Relative errors vs MPFR reference on non-negligible support + tau <- 1e-15 # <-- tweak this + + support_mask <- p_ref >= tau + if (!any(support_mask)) { + support_mask <- p_ref == max(p_ref) # degenerate case: all tiny, pick the max + } + + rel_direct <- abs(p_dir - p_ref)[support_mask] / p_ref[support_mask] + rel_bound <- abs(p_b - p_ref)[support_mask] / p_ref[support_mask] + rel_preexp <- abs(p_pre - p_ref)[support_mask] / p_ref[support_mask] + rel_preB <- abs(p_preB - p_ref)[support_mask] / p_ref[support_mask] + + res$err_direct[i] <- max(rel_direct) + res$err_bound[i] <- max(rel_bound) + res$err_preexp[i] <- max(rel_preexp) + res$err_preexp_bound[i] <- max(rel_preB) + + + + } + + res +} + +############################################################ +# 2. Sweep across param_grid × r_vals +############################################################ + +simulate_bc_prob_4methods <- function(param_grid, + r_vals, + mpfr_prec = 256, + tol = 1e-12) { + + if (!all(c("max_cat", "ref", "theta_lin", "theta_quad") %in% names(param_grid))) { + stop("param_grid must have columns: max_cat, ref, theta_lin, theta_quad") + } + + out_list <- vector("list", nrow(param_grid)) + + for (cfg_idx in seq_len(nrow(param_grid))) { + cfg <- param_grid[cfg_idx, ] + + res_cfg <- compare_bc_prob_4methods_one( + max_cat = cfg$max_cat, + ref = cfg$ref, + theta_lin = cfg$theta_lin, + theta_quad = cfg$theta_quad, + r_vals = r_vals, + mpfr_prec = mpfr_prec + ) + + res_cfg$config_id <- cfg_idx + res_cfg$max_cat <- cfg$max_cat + res_cfg$ref <- cfg$ref + res_cfg$theta_lin <- cfg$theta_lin + res_cfg$theta_quad <- cfg$theta_quad + + # simple ok flags + res_cfg$ok_direct <- is.finite(res_cfg$err_direct) & (res_cfg$err_direct < tol) + res_cfg$ok_bound <- is.finite(res_cfg$err_bound) & (res_cfg$err_bound < tol) + res_cfg$ok_preexp <- is.finite(res_cfg$err_preexp) & (res_cfg$err_preexp < tol) + res_cfg$ok_preexp_bound <- is.finite(res_cfg$err_preexp_bound) & (res_cfg$err_preexp_bound < tol) + + out_list[[cfg_idx]] <- res_cfg + } + + do.call(rbind, out_list) +} + +############################################################ +# 3. Example broad analysis (you can adjust this) +############################################################ + +param_grid <- expand.grid( + max_cat = c(4, 10), + ref = c(0, 2, 4, 5, 10), + theta_lin = c(-0.5, 0.0, 0.5), + theta_quad = c(-0.2, 0.0, 0.2), + KEEP.OUT.ATTRS = FALSE, + stringsAsFactors = FALSE +) + +r_vals <- seq(-80, 80, length.out = 2001) +tol <- 1e-12 + +sim4 <- simulate_bc_prob_4methods( + param_grid = param_grid, + r_vals = r_vals, + mpfr_prec = 256, + tol = tol +) + +############################################################ +# 4. Summaries: where each method fails, as a function of bound/pow_bound +############################################################ + +df4 <- sim4 %>% + mutate( + abs_bound = abs(bound), + err_direct_cl = pmax(err_direct, 1e-300), + err_bound_cl = pmax(err_bound, 1e-300), + err_preexp_cl = pmax(err_preexp, 1e-300), + err_preexp_bound_cl = pmax(err_preexp_bound, 1e-300), + log_err_direct = log10(err_direct_cl), + log_err_bound = log10(err_bound_cl), + log_err_preexp = log10(err_preexp_cl), + log_err_preexp_bound= log10(err_preexp_bound_cl) + ) + +# Example: failures for each method inside |bound| <= 709 & pow_bound <= 709 +inside <- df4 %>% + filter(abs(bound) <= EXP_BOUND, pow_bound <= EXP_BOUND) + +n_direct_fail <- sum(!inside$ok_direct) +n_bound_fail <- sum(!inside$ok_bound) +n_preexp_fail <- sum(!inside$ok_preexp) +n_preexp_bound_fail <- sum(!inside$ok_preexp_bound) + +cat("\nFailures inside fast region (|bound| <= 709 & pow_bound <= 709):\n") +cat(" direct_unscaled :", n_direct_fail, "\n") +cat(" direct_bound :", n_bound_fail, "\n") +cat(" preexp_unscaled :", n_preexp_fail, "\n") +cat(" preexp_bound (FAST) :", n_preexp_bound_fail, "\n\n") diff --git a/dev/numerical_analyses/bgm_regularordinal_normalization_PL.R b/dev/numerical_analyses/bgm_regularordinal_normalization_PL.R new file mode 100644 index 00000000..8a6219b1 --- /dev/null +++ b/dev/numerical_analyses/bgm_regularordinal_normalization_PL.R @@ -0,0 +1,695 @@ +################################################################################ +# Reference: Numerical stability study for bounded vs. unbounded exponential sums +# Author: [Your Name] +# Date: [YYYY-MM-DD] +# +# Purpose: +# Evaluate and compare four ways to compute the sum +# +# S = 1 + Σ_{c=1..K} exp( m_c + (c+1)*r ) +# +# where r may vary widely. The goal is to identify numerically stable and +# computationally efficient formulations for use in gradient calculations. +# +# Methods compared: +# (1) direct – naive computation using raw exp() +# (2) bounded – stabilized by subtracting a "bound" (i.e., scaled domain) +# (3) preexp – precomputes exp(m_c) and exp(r) to replace repeated calls +# (4) preexp_bound – preexp variant with the same "bound" scaling +# +# For each method, we compute both unscaled and scaled variants where relevant, +# and compare them against a high-precision MPFR reference. +# +# Key insight: +# - For large negative r, preexp can lose precision (tiny multiplicative updates). +# - For large positive r, bounded scaling avoids overflow. +# - The combination (preexp + bound) gives the best general stability. +# +# Output: +# - res: data frame with per-r results and relative errors +# - Diagnostic plots and summary tables for numerical accuracy +################################################################################ + +library(Rmpfr) # for arbitrary precision reference computations + + +################################################################################ +# 1. Core comparison function +################################################################################ +compare_all_methods <- function(K = 5, + r_vals = seq(-10, 10, length.out = 200), + m_vals = NULL, + mpfr_prec = 256) { + # --------------------------------------------------------------------------- + # Parameters: + # K – number of categories (terms in the sum) + # r_vals – vector of r values to evaluate over + # m_vals – optional vector of m_c values; random if NULL + # mpfr_prec – bits of precision for the high-precision reference + # + # Returns: + # A data.frame containing per-r computed values, reference values, + # relative errors, and failure flags. + # --------------------------------------------------------------------------- + + if (is.null(m_vals)) m_vals <- runif(K, -1, 1) + + results <- data.frame( + r = r_vals, + direct = NA_real_, + bounded = NA_real_, # scaled-domain computation (exp(-bound) factor) + preexp = NA_real_, + preexp_bound = NA_real_, # scaled-domain computation + ref = NA_real_, # unscaled MPFR reference + ref_scaled = NA_real_, # scaled reference + err_direct = NA_real_, + err_bounded = NA_real_, + err_preexp = NA_real_, + err_preexp_bound = NA_real_, + ref_failed_unscaled = FALSE, + ref_failed_scaled = FALSE + ) + + # Loop over all r-values + for (i in seq_along(r_vals)) { + r <- r_vals[i] + bound <- K * r # can be unclipped; use max(0, K*r) for the clipped version + + # --- (0) High-precision MPFR reference ----------------------------------- + r_mp <- mpfr(r, precBits = mpfr_prec) + m_mp <- mpfr(m_vals, precBits = mpfr_prec) + b_mp <- mpfr(bound, precBits = mpfr_prec) + + ref_unscaled_mp <- 1 + sum(exp(m_mp + (1:K) * r_mp)) + ref_scaled_mp <- exp(-b_mp) * ref_unscaled_mp + + # Convert to doubles for inspection + ref_unscaled_num <- asNumeric(ref_unscaled_mp) + ref_scaled_num <- asNumeric(ref_scaled_mp) + results$ref_failed_unscaled[i] <- !is.finite(ref_unscaled_num) + results$ref_failed_scaled[i] <- !is.finite(ref_scaled_num) + results$ref[i] <- if (is.finite(ref_unscaled_num)) ref_unscaled_num else NA_real_ + results$ref_scaled[i] <- if (is.finite(ref_scaled_num)) ref_scaled_num else NA_real_ + + # --- (1) Direct exponential sum (unscaled) ------------------------------- + results$direct[i] <- 1 + sum(exp(m_vals + (1:K) * r)) + + # --- (2) Current bounded implementation (scaled) ------------------------- + eB <- exp(-bound) + results$bounded[i] <- eB + sum(exp(m_vals + (1:K) * r - bound)) + + # --- (3) Precomputed exp only (unscaled) --------------------------------- + exp_r <- exp(r) + exp_m <- exp(m_vals) + powE <- exp_r + S_pre <- 1.0 + for (c in 1:K) { + S_pre <- S_pre + exp_m[c] * powE + powE <- powE * exp_r + } + results$preexp[i] <- S_pre + + # --- (4) Precomputed exp + bound scaling (scaled) ------------------------ + exp_r <- exp(r) + exp_m <- exp(m_vals) + powE <- exp_r + S_preB <- eB + for (c in 1:K) { + S_preB <- S_preB + exp_m[c] * powE * eB + powE <- powE * exp_r + } + results$preexp_bound[i] <- S_preB + + # --- (5) Relative errors vs references ----------------------------------- + # Unscaled methods + for (m in c("direct", "preexp")) { + val <- results[[m]][i] + if (is.finite(val)) { + val_mp <- mpfr(val, precBits = mpfr_prec) + err_mp <- abs((val_mp - ref_unscaled_mp) / ref_unscaled_mp) + results[[paste0("err_", m)]][i] <- asNumeric(err_mp) + } + } + + # Scaled methods + for (m in c("bounded", "preexp_bound")) { + val <- results[[m]][i] + if (is.finite(val)) { + val_mp <- mpfr(val, precBits = mpfr_prec) + err_mp <- abs((val_mp - ref_scaled_mp) / ref_scaled_mp) + results[[paste0("err_", m)]][i] <- asNumeric(err_mp) + } + } + } + + msg_a <- mean(results$ref_failed_unscaled) + msg_b <- mean(results$ref_failed_scaled) + message(sprintf("Ref (unscaled) non-finite in %.1f%%; Ref (scaled) non-finite in %.1f%% of r-values", + 100 * msg_a, 100 * msg_b)) + results +} + + +################################################################################ +# 2. Plotting: log-scale accuracy with failure marking +################################################################################ +plot_errors <- function(res) { + err_cols <- c("err_bounded", "err_direct", "err_preexp", "err_preexp_bound") + cols <- c("gray", "black", "red", "blue") + names(cols) <- err_cols + + # Compute a robust ylim (1st–99th percentile) + finite_vals <- unlist(res[err_cols]) + finite_vals <- finite_vals[is.finite(finite_vals) & finite_vals > 0] + if (length(finite_vals) > 0) { + lower <- quantile(finite_vals, 0.01, na.rm = TRUE) + upper <- quantile(finite_vals, 0.99, na.rm = TRUE) + ylim <- c(lower / 10, upper * 10) + } else { + ylim <- c(1e-20, 1e-12) + } + + # Baseline curve: bounded + plot(res$r, res$err_bounded, type = "l", log = "y", + col = cols["err_bounded"], lwd = 2, + ylim = ylim, + xlab = "r", ylab = "Relative error", + main = "Accuracy and failure regions") + + # Add other methods + for (e in setdiff(err_cols, "err_bounded")) + lines(res$r, res[[e]], col = cols[e], lwd = 2) + + abline(h = .Machine$double.eps, col = "darkgray", lty = 2) + + legend("bottomright", + legend = c("Current bounded", "Direct exp", + "Preexp only", "Preexp + bound"), + col = cols, lwd = 2, bty = "n") + + # Mark numeric failures + for (e in err_cols) { + bad <- which(!is.finite(res[[e]]) | res[[e]] <= 0) + if (length(bad) > 0) + points(res$r[bad], rep(ylim[1], length(bad)), + pch = 21, col = cols[e], bg = cols[e], cex = 0.6) + } + + legend("bottomleft", legend = "dots = 0/Inf/NaN failures", bty = "n") +} + + +################################################################################ +# 3. Summarize accuracy across r +################################################################################ +summarize_accuracy <- function(res) { + err_cols <- c("err_direct", "err_bounded", "err_preexp", "err_preexp_bound") + + summary <- data.frame( + Method = c("Direct exp", "Current bounded", + "Preexp only", "Preexp + bound"), + Mean_error = NA_real_, + Median_error = NA_real_, + Max_error = NA_real_, + Finite_fraction = NA_real_, + Zero_or_Inf_fraction = NA_real_ + ) + + for (j in seq_along(err_cols)) { + e <- res[[err_cols[j]]] + finite_mask <- is.finite(e) & e > 0 + summary$Mean_error[j] <- mean(e[finite_mask], na.rm = TRUE) + summary$Median_error[j] <- median(e[finite_mask], na.rm = TRUE) + summary$Max_error[j] <- max(e[finite_mask], na.rm = TRUE) + summary$Finite_fraction[j] <- mean(finite_mask) + summary$Zero_or_Inf_fraction[j] <- 1 - mean(finite_mask) + } + + summary +} + + +################################################################################ +# 4. Alternate jitter plot for fine-scale comparison +################################################################################ +plot_errors_jitter <- function(res, offset_for_visibility = TRUE) { + err_cols <- c("err_bounded", "err_direct", "err_preexp", "err_preexp_bound") + cols <- c("gray", "black", "red", "blue") + + message("Plotting columns:") + for (i in seq_along(err_cols)) + message(sprintf(" %-15s -> %s", err_cols[i], cols[i])) + + offset_factor <- if (offset_for_visibility) c(1, 5, 100, 1e4) else rep(1, 4) + + finite_vals <- unlist(res[err_cols]) + finite_vals <- finite_vals[is.finite(finite_vals) & finite_vals > 0] + if (length(finite_vals) > 0) { + lower <- quantile(finite_vals, 0.01, na.rm = TRUE) + upper <- quantile(finite_vals, 0.99, na.rm = TRUE) + ylim <- c(lower / 10, upper * 10) + } else ylim <- c(1e-20, 1e-12) + + plot(res$r, res$err_bounded * offset_factor[1], + type = "l", log = "y", lwd = 2, col = cols[1], + ylim = ylim, + xlab = "r", ylab = "Relative error", + main = "Accuracy (offset for visibility)") + + for (j in 2:length(err_cols)) + lines(res$r, res[[err_cols[j]]] * offset_factor[j], col = cols[j], lwd = 2) + + abline(h = .Machine$double.eps, col = "darkgray", lty = 2) + legend("bottomright", + legend = c("Current bounded", "Direct exp", "Preexp only", "Preexp + bound"), + col = cols, lwd = 2) +} + + +################################################################################ +# 5. Example usage +################################################################################ +# Run test for a moderate K and r-range. +# Expand range (e.g. seq(-100, 80, 1)) to probe overflow/underflow limits. +# res <- compare_all_methods(K = 10, r_vals = seq(-71, 71, length.out = 1e4)) +# +# # Plot and summarize +# plot_errors(res) +# summary_table <- summarize_accuracy(res) +# print(summary_table, digits = 3) +# plot_errors_jitter(res) # optional visualization with offsets +################################################################################ + + +################################################################################ +# 6. Ratio stability check (direct vs preexp) × (bound vs clipped) +################################################################################ +compare_prob_ratios <- function(K = 5, + r_vals = seq(-20, 20, length.out = 200), + m_vals = NULL, + mpfr_prec = 256) { + + if (!requireNamespace("Rmpfr", quietly = TRUE)) + stop("Please install Rmpfr: install.packages('Rmpfr')") + + if (is.null(m_vals)) m_vals <- runif(K, -1, 1) + + res <- data.frame( + r = numeric(length(r_vals)), + err_direct_bound = numeric(length(r_vals)), + err_direct_clip = numeric(length(r_vals)), + err_preexp_bound = numeric(length(r_vals)), + err_preexp_clip = numeric(length(r_vals)) + ) + + for (i in seq_along(r_vals)) { + r <- r_vals[i] + b_raw <- K * r + b_clip <- max(0, b_raw) + + # --- High-precision reference --------------------------------------------- + r_mp <- Rmpfr::mpfr(r, precBits = mpfr_prec) + m_mp <- Rmpfr::mpfr(m_vals, precBits = mpfr_prec) + exp_terms_ref <- exp(m_mp + (1:K) * r_mp) + denom_ref <- 1 + sum(exp_terms_ref) + p_ref_num <- as.numeric(exp_terms_ref / denom_ref) + + # --- (1) Direct, un-clipped bound ---------------------------------------- + exp_terms_dB <- exp(m_vals + (1:K) * r - b_raw) + denom_dB <- exp(-b_raw) + sum(exp_terms_dB) + p_dB <- exp_terms_dB / denom_dB + res$err_direct_bound[i] <- max(abs(p_dB - p_ref_num) / p_ref_num) + + # --- (2) Direct, clipped bound ------------------------------------------- + exp_terms_dC <- exp(m_vals + (1:K) * r - b_clip) + denom_dC <- exp(-b_clip) + sum(exp_terms_dC) + p_dC <- exp_terms_dC / denom_dC + res$err_direct_clip[i] <- max(abs(p_dC - p_ref_num) / p_ref_num) + + # --- (3) Preexp, un-clipped bound --------------------------------------- + eR <- exp(r) + eM <- exp(m_vals) + eB <- exp(-b_raw) + powE <- eR + S_preB <- eB + terms_preB <- numeric(K) + for (c in 1:K) { + term <- eM[c] * powE * eB + terms_preB[c] <- term + S_preB <- S_preB + term + powE <- powE * eR + } + p_preB <- terms_preB / S_preB + res$err_preexp_bound[i] <- max(abs(p_preB - p_ref_num) / p_ref_num) + + # --- (4) Preexp, clipped bound ------------------------------------------ + eR <- exp(r) + eM <- exp(m_vals) + eB <- exp(-b_clip) + powE <- eR + S_preC <- eB + terms_preC <- numeric(K) + for (c in 1:K) { + term <- eM[c] * powE * eB + terms_preC[c] <- term + S_preC <- S_preC + term + powE <- powE * eR + } + p_preC <- terms_preC / S_preC + res$err_preexp_clip[i] <- max(abs(p_preC - p_ref_num) / p_ref_num) + + res$r[i] <- r + } + + return(res) +} + + +################################################################################ +# 7. Example usage: compare probability ratio stability +################################################################################ + +# K <- 10 +# r_vals <- seq(-75, 75, length.out = 1e4) +# set.seed(123) +# m_vals <- runif(K, -1, 1) +# +# res_ratio <- compare_prob_ratios(K = K, r_vals = r_vals, m_vals = m_vals) +# +# eps <- .Machine$double.eps +# plot(res_ratio$r, pmax(res_ratio$err_direct_bound, eps), +# type = "l", log = "y", lwd = 2, col = "red", +# xlab = "r", ylab = "Relative error (vs MPFR reference)", +# main = "Numerical stability of p_c ratio computations — 4 variants") +# +# lines(res_ratio$r, pmax(res_ratio$err_direct_clip, eps), col = "blue", lwd = 2) +# lines(res_ratio$r, pmax(res_ratio$err_preexp_bound, eps), col = "orange", lwd = 2) +# lines(res_ratio$r, pmax(res_ratio$err_preexp_clip, eps), col = "purple", lwd = 2) +# +# abline(h = .Machine$double.eps, col = "darkgray", lty = 2) +# legend("top", +# legend = c("Direct + Bound", "Direct + Clipped Bound", +# "Preexp + Bound", "Preexp + Clipped Bound"), +# col = c("red", "blue", "orange", "purple"), +# lwd = 2, bty = "n") +# +# abline(v = -70) +# abline(v = 70) +# +# # Summarize numeric accuracy +# summary_df <- data.frame( +# Method = c("Direct + Bound", "Direct + Clipped Bound", +# "Preexp + Bound", "Preexp + Clipped Bound"), +# Mean_error = c(mean(res_ratio$err_direct_bound, na.rm = TRUE), +# mean(res_ratio$err_direct_clip, na.rm = TRUE), +# mean(res_ratio$err_preexp_bound, na.rm = TRUE), +# mean(res_ratio$err_preexp_clip, na.rm = TRUE)), +# Median_error = c(median(res_ratio$err_direct_bound, na.rm = TRUE), +# median(res_ratio$err_direct_clip, na.rm = TRUE), +# median(res_ratio$err_preexp_bound, na.rm = TRUE), +# median(res_ratio$err_preexp_clip, na.rm = TRUE)), +# Max_error = c(max(res_ratio$err_direct_bound, na.rm = TRUE), +# max(res_ratio$err_direct_clip, na.rm = TRUE), +# max(res_ratio$err_preexp_bound, na.rm = TRUE), +# max(res_ratio$err_preexp_clip, na.rm = TRUE)) +# ) +# print(summary_df, digits = 3) +################################################################################ + +############################################################ +# Blume–Capel probabilities: +# Numerical comparison of FAST vs SAFE methods +# +# Objective +# --------- +# For a single Blume–Capel configuration (max_cat, ref, theta_lin, theta_quad), +# and a grid of residual scores r, we compare +# +# p_s(r) ∝ exp( theta_part(s) + s * r ), s = 0..max_cat +# +# with +# +# theta_part(s) = theta_lin * (s - ref) + theta_quad * (s - ref)^2 +# +# computed three ways: +# +# (1) MPFR reference softmax (high precision) +# (2) SAFE : double, direct exponentials with bound (subtract M(r)) +# (3) FAST : double, preexp(theta_part) + power chain for exp(s*r - M(r)) +# +# We record, for each r: +# +# - numeric bound M(r) = max_s [theta_part(s) + s * r] +# - pow_bound = max_cat * r - M(r) +# - max relative error of SAFE +# - max relative error of FAST +# +# No fallbacks, no patching of non-finite values: we let under/overflow +# show up as Inf/NaN in the errors and inspect those. +############################################################ + +library(Rmpfr) # for high-precision reference + +############################################################ +# 1. Reference probabilities using MPFR +############################################################ + +bc_prob_ref_mpfr <- function(max_cat, ref, theta_lin, theta_quad, + r_vals, + mpfr_prec = 256) { + # categories and centered scores + s_vals <- 0:max_cat + c_vals <- s_vals - ref + + # MPFR parameters + tl <- mpfr(theta_lin, precBits = mpfr_prec) + tq <- mpfr(theta_quad, precBits = mpfr_prec) + s_mp <- mpfr(s_vals, precBits = mpfr_prec) + c_mp <- mpfr(c_vals, precBits = mpfr_prec) + + n_r <- length(r_vals) + n_s <- length(s_vals) + + P_ref <- matrix(NA_real_, nrow = n_r, ncol = n_s) + + for (i in seq_len(n_r)) { + r_mp <- mpfr(r_vals[i], precBits = mpfr_prec) + + # exponent(s) = theta_part(s) + s * r + term_mp <- tl * c_mp + tq * c_mp * c_mp + s_mp * r_mp + + # numeric bound M(r) + M_num <- max(asNumeric(term_mp)) + M_mp <- mpfr(M_num, precBits = mpfr_prec) + + # scaled numerators + num_mp <- exp(term_mp - M_mp) + Z_mp <- sum(num_mp) + p_mp <- num_mp / Z_mp + + P_ref[i, ] <- asNumeric(p_mp) + } + + P_ref +} + +############################################################ +# 2. SAFE probabilities (double, direct + bound) +############################################################ + +bc_prob_safe <- function(max_cat, ref, theta_lin, theta_quad, + r_vals) { + s_vals <- 0:max_cat + c_vals <- s_vals - ref + + theta_part <- theta_lin * c_vals + theta_quad * c_vals^2 + + n_r <- length(r_vals) + n_s <- length(s_vals) + + P_safe <- matrix(NA_real_, nrow = n_r, ncol = n_s) + bound <- numeric(n_r) + + for (i in seq_len(n_r)) { + r <- r_vals[i] + + exps <- theta_part + s_vals * r + M <- max(exps) + bound[i] <- M + + numer <- exp(exps - M) + denom <- sum(numer) + + # no fallback here; denom can be 0 or Inf + P_safe[i, ] <- numer / denom + } + + list( + probs = P_safe, + bound = bound + ) +} + +############################################################ +# 3. FAST probabilities (double, preexp + power chain) +############################################################ + +bc_prob_fast <- function(max_cat, ref, theta_lin, theta_quad, + r_vals) { + s_vals <- 0:max_cat + c_vals <- s_vals - ref + + theta_part <- theta_lin * c_vals + theta_quad * c_vals^2 + exp_theta <- exp(theta_part) + + n_r <- length(r_vals) + n_s <- length(s_vals) + + P_fast <- matrix(NA_real_, nrow = n_r, ncol = n_s) + bound <- numeric(n_r) + pow_bound <- numeric(n_r) + + for (i in seq_len(n_r)) { + r <- r_vals[i] + + # exponents before scaling + exps <- theta_part + s_vals * r + M <- max(exps) + bound[i] <- M + + # pow_bound = max_s (s*r - M) attained at s = max_cat + pow_bound[i] <- max_cat * r - M + + eR <- exp(r) + pow <- exp(-M) + + numer <- numeric(n_s) + denom <- 0 + + for (j in seq_len(n_s)) { + numer[j] <- exp_theta[j] * pow + denom <- denom + numer[j] + pow <- pow * eR + } + + # again: no fallback; denom can be 0/Inf + P_fast[i, ] <- numer / denom + } + + list( + probs = P_fast, + bound = bound, + pow_bound = pow_bound + ) +} + +############################################################ +# 4. Core comparison function (one BC config) +############################################################ + +compare_bc_prob_methods <- function(max_cat = 4, + ref = 2, + theta_lin = 0.0, + theta_quad = 0.0, + r_vals = seq(-20, 20, length.out = 200), + mpfr_prec = 256) { + # MPFR reference + P_ref <- bc_prob_ref_mpfr( + max_cat = max_cat, + ref = ref, + theta_lin = theta_lin, + theta_quad = theta_quad, + r_vals = r_vals, + mpfr_prec = mpfr_prec + ) + + # SAFE + safe_res <- bc_prob_safe( + max_cat = max_cat, + ref = ref, + theta_lin = theta_lin, + theta_quad = theta_quad, + r_vals = r_vals + ) + P_safe <- safe_res$probs + bound_safe <- safe_res$bound + + # FAST + fast_res <- bc_prob_fast( + max_cat = max_cat, + ref = ref, + theta_lin = theta_lin, + theta_quad = theta_quad, + r_vals = r_vals + ) + P_fast <- fast_res$probs + bound_fast <- fast_res$bound + pow_bound <- fast_res$pow_bound + + stopifnot(all.equal(bound_safe, bound_fast)) + + n_r <- length(r_vals) + + res <- data.frame( + r = r_vals, + bound = bound_fast, + pow_bound = pow_bound, + err_safe = NA_real_, + err_fast = NA_real_ + ) + + for (i in seq_len(n_r)) { + p_ref <- P_ref[i, ] + p_safe <- P_safe[i, ] + p_fast <- P_fast[i, ] + + # max relative error vs MPFR reference + # (this is exactly in the spirit of compare_prob_ratios) + res$err_safe[i] <- max(abs(p_safe - p_ref) / p_ref) + res$err_fast[i] <- max(abs(p_fast - p_ref) / p_ref) + } + + res +} + +############################################################ +# 5. Example usage +############################################################ + +# Example: small BC variable +# max_cat <- 4 +# ref <- 2 +# theta_lin <- 0.3 +# theta_quad <- -0.1 +# r_vals <- seq(-80, 80, length.out = 2000) +# +# res_bc <- compare_bc_prob_methods( +# max_cat = max_cat, +# ref = ref, +# theta_lin = theta_lin, +# theta_quad = theta_quad, +# r_vals = r_vals, +# mpfr_prec = 256 +# ) +# +# # Quick inspection: log10 errors +# eps <- .Machine$double.eps +# plot(res_bc$r, pmax(res_bc$err_safe, eps), +# type = "l", log = "y", col = "black", lwd = 2, +# xlab = "r", ylab = "Relative error (vs MPFR)", +# main = "Blume–Capel probabilities: SAFE vs FAST") +# lines(res_bc$r, pmax(res_bc$err_fast, eps), col = "red", lwd = 2) +# abline(h = eps, col = "darkgray", lty = 2) +# legend("topright", +# legend = c("SAFE (direct + bound)", "FAST (preexp + power chain)"), +# col = c("black", "red"), +# lwd = 2, bty = "n") +# +# # You can then condition on bound/pow_bound just like in the +# # Blume–Capel normalization script to decide where FAST is safe. +############################################################ + + + + + diff --git a/inst/REFERENCES.bib b/inst/REFERENCES.bib index c824fadb..b34e1fde 100644 --- a/inst/REFERENCES.bib +++ b/inst/REFERENCES.bib @@ -86,10 +86,9 @@ @article{MarsmanVandenBerghHaslbeck_2024 @article{MarsmanWaldorpSekulovskiHaslbeck_2024, author = {Marsman, M. and Waldorp, L. J. and Sekulovski, N. and Haslbeck, J. M. B.}, - journal = {Retrieved from https://osf.io/preprints/osf/f4pk9}, + journal = {Psychometrika}, title = {Bayes factor tests for group differences in ordinal and binary graphical models.}, - note = {OSF preprint}, - year = {2024}} + year = {in press}} @article{McNallyEtAl_2015, author = {{McNally}, R. J. and Robinaugh, D. J. and Wu, G. W. Y. and Wang, L. and Deserno, M. K. and Borsboom, D.}, diff --git a/inst/generate_makevars_sources.R b/inst/generate_makevars_sources.R new file mode 100644 index 00000000..cb238c47 --- /dev/null +++ b/inst/generate_makevars_sources.R @@ -0,0 +1,26 @@ +cpp <- list.files( + "src", + pattern = "\\.cpp$", + recursive = TRUE, + full.names = TRUE +) + +# strip leading "src/" (MacOs/ Linux) or leading "src\\" (Windows) +cpp <- sub("^src[\\\\/]", "", cpp) + +con <- file(file.path("src", "sources.mk"), open = "w") + +writeLines(c( + "# ------------------------------------------------------------------", + "# THIS FILE IS AUTO-GENERATED - DO NOT EDIT", + "# Generated by configure", + "# To add C++ code, place .cpp files anywhere under src/", + "# ------------------------------------------------------------------", + "SOURCES = \\" +), con) + +writeLines(paste0(" ", cpp, " \\"), con) +writeLines("", con) + +close(con) + diff --git a/man/bgm.Rd b/man/bgm.Rd index b0982d9a..4a3be522 100644 --- a/man/bgm.Rd +++ b/man/bgm.Rd @@ -242,7 +242,7 @@ category responses. Assume a baseline category (e.g., a “neutral” response) and score responses by distance from this baseline. Category thresholds are modeled as: -\deqn{\mu_{c} = \alpha \cdot c + \beta \cdot (c - b)^2} +\deqn{\mu_{c} = \alpha \cdot (c-b) + \beta \cdot (c - b)^2} where: \itemize{ @@ -257,6 +257,8 @@ where: } \item \eqn{b}: baseline category } +Accordingly, pairwise interactions between Blume-Capel variables are modeled +in terms of \eqn{c-b} scores. } \section{Edge Selection}{ diff --git a/man/bgms-package.Rd b/man/bgms-package.Rd index aaad0021..06c1fe09 100644 --- a/man/bgms-package.Rd +++ b/man/bgms-package.Rd @@ -69,6 +69,7 @@ For tutorials and worked examples, see: Useful links: \itemize{ \item \url{https://Bayesian-Graphical-Modelling-Lab.github.io/bgms/} + \item \url{https://github.com/Bayesian-Graphical-Modelling-Lab/bgms} \item Report bugs at \url{https://github.com/Bayesian-Graphical-Modelling-Lab/bgms/issues} } diff --git a/man/mrfSampler.Rd b/man/mrfSampler.Rd index 8c50c434..d4b233ce 100644 --- a/man/mrfSampler.Rd +++ b/man/mrfSampler.Rd @@ -11,7 +11,7 @@ mrfSampler( interactions, thresholds, variable_type = "ordinal", - reference_category, + baseline_category, iter = 1000 ) } @@ -43,8 +43,8 @@ for each variable separately. Currently, bgm supports ``ordinal'' and ``blume-capel''. Binary variables are automatically treated as ``ordinal’’. Defaults to \code{variable_type = "ordinal"}.} -\item{reference_category}{An integer vector of length \code{no_variables} specifying the -reference_category category that is used for the Blume-Capel model (details below). +\item{baseline_category}{An integer vector of length \code{no_variables} specifying the +baseline_category category that is used for the Blume-Capel model (details below). Can be any integer value between \code{0} and \code{no_categories} (or \code{no_categories[i]}).} @@ -71,7 +71,7 @@ useful for any type of ordinal variable and gives the user the most freedom in specifying their model. The Blume-Capel option is specifically designed for ordinal variables that -have a special type of reference_category category, such as the neutral +have a special type of baseline_category category, such as the neutral category in a Likert scale. The Blume-Capel model specifies the following quadratic model for the threshold parameters: \deqn{\mu_{\text{c}} = \alpha \times \text{c} + \beta \times (\text{c} - \text{r})^2,}{{\mu_{\text{c}} = \alpha \times \text{c} + \beta \times (\text{c} - \text{r})^2,}} @@ -81,8 +81,8 @@ across categories (increasing threshold values if \eqn{\alpha > 0}{\alpha > 0} and decreasing threshold values if \eqn{\alpha <0}{\alpha <0}), if \eqn{\beta < 0}{\beta < 0}, it offers an increasing penalty for responding in a category further away from the -reference_category category r, while \eqn{\beta > 0}{\beta > 0} suggests a -preference for responding in the reference_category category. +baseline_category category r, while \eqn{\beta > 0}{\beta > 0} suggests a +preference for responding in the baseline_category category. } \examples{ # Generate responses from a network of five binary and ordinal variables. @@ -95,11 +95,13 @@ Interactions[2, 1] = Interactions[4, 1] = Interactions[3, 2] = Interactions = Interactions + t(Interactions) Thresholds = matrix(0, nrow = no_variables, ncol = max(no_categories)) -x = mrfSampler(no_states = 1e3, - no_variables = no_variables, - no_categories = no_categories, - interactions = Interactions, - thresholds = Thresholds) +x = mrfSampler( + no_states = 1e3, + no_variables = no_variables, + no_categories = no_categories, + interactions = Interactions, + thresholds = Thresholds +) # Generate responses from a network of 2 ordinal and 3 Blume-Capel variables. no_variables = 5 @@ -116,12 +118,14 @@ Thresholds[, 2] = -1 Thresholds[3, ] = sort(-abs(rnorm(4)), decreasing = TRUE) Thresholds[5, ] = sort(-abs(rnorm(4)), decreasing = TRUE) -x = mrfSampler(no_states = 1e3, - no_variables = no_variables, - no_categories = no_categories, - interactions = Interactions, - thresholds = Thresholds, - variable_type = c("b","b","o","b","o"), - reference_category = 2) +x = mrfSampler( + no_states = 1e3, + no_variables = no_variables, + no_categories = no_categories, + interactions = Interactions, + thresholds = Thresholds, + variable_type = c("b", "b", "o", "b", "o"), + baseline_category = 2 +) } diff --git a/src/Makevars.in b/src/Makevars.in index 6a26be9a..ea013a0f 100644 --- a/src/Makevars.in +++ b/src/Makevars.in @@ -1,5 +1,9 @@ CXX_STD = CXX20 -PKG_CPPFLAGS = @RCPP_PARALLEL_CPPFLAGS@ -DARMA_NO_DEBUG +include sources.mk + +OBJECTS = $(SOURCES:.cpp=.o) + +PKG_CPPFLAGS = @RCPP_PARALLEL_CPPFLAGS@ -DARMA_NO_DEBUG -I. PKG_LIBS = $(LAPACK_LIBS) $(BLAS_LIBS) $(FLIBS) @RCPP_PARALLEL_LIBS@ diff --git a/src/RcppExports.cpp b/src/RcppExports.cpp index f3fc9664..084816e8 100644 --- a/src/RcppExports.cpp +++ b/src/RcppExports.cpp @@ -101,38 +101,6 @@ BEGIN_RCPP return rcpp_result_gen; END_RCPP } -// get_explog_switch -Rcpp::String get_explog_switch(); -RcppExport SEXP _bgms_get_explog_switch() { -BEGIN_RCPP - Rcpp::RObject rcpp_result_gen; - Rcpp::RNGScope rcpp_rngScope_gen; - rcpp_result_gen = Rcpp::wrap(get_explog_switch()); - return rcpp_result_gen; -END_RCPP -} -// rcpp_ieee754_exp -Rcpp::NumericVector rcpp_ieee754_exp(Rcpp::NumericVector x); -RcppExport SEXP _bgms_rcpp_ieee754_exp(SEXP xSEXP) { -BEGIN_RCPP - Rcpp::RObject rcpp_result_gen; - Rcpp::RNGScope rcpp_rngScope_gen; - Rcpp::traits::input_parameter< Rcpp::NumericVector >::type x(xSEXP); - rcpp_result_gen = Rcpp::wrap(rcpp_ieee754_exp(x)); - return rcpp_result_gen; -END_RCPP -} -// rcpp_ieee754_log -Rcpp::NumericVector rcpp_ieee754_log(Rcpp::NumericVector x); -RcppExport SEXP _bgms_rcpp_ieee754_log(SEXP xSEXP) { -BEGIN_RCPP - Rcpp::RObject rcpp_result_gen; - Rcpp::RNGScope rcpp_rngScope_gen; - Rcpp::traits::input_parameter< Rcpp::NumericVector >::type x(xSEXP); - rcpp_result_gen = Rcpp::wrap(rcpp_ieee754_log(x)); - return rcpp_result_gen; -END_RCPP -} // sample_omrf_gibbs IntegerMatrix sample_omrf_gibbs(int no_states, int no_variables, IntegerVector no_categories, NumericMatrix interactions, NumericMatrix thresholds, int iter); RcppExport SEXP _bgms_sample_omrf_gibbs(SEXP no_statesSEXP, SEXP no_variablesSEXP, SEXP no_categoriesSEXP, SEXP interactionsSEXP, SEXP thresholdsSEXP, SEXP iterSEXP) { @@ -150,8 +118,8 @@ BEGIN_RCPP END_RCPP } // sample_bcomrf_gibbs -IntegerMatrix sample_bcomrf_gibbs(int no_states, int no_variables, IntegerVector no_categories, NumericMatrix interactions, NumericMatrix thresholds, StringVector variable_type, IntegerVector reference_category, int iter); -RcppExport SEXP _bgms_sample_bcomrf_gibbs(SEXP no_statesSEXP, SEXP no_variablesSEXP, SEXP no_categoriesSEXP, SEXP interactionsSEXP, SEXP thresholdsSEXP, SEXP variable_typeSEXP, SEXP reference_categorySEXP, SEXP iterSEXP) { +IntegerMatrix sample_bcomrf_gibbs(int no_states, int no_variables, IntegerVector no_categories, NumericMatrix interactions, NumericMatrix thresholds, StringVector variable_type, IntegerVector baseline_category, int iter); +RcppExport SEXP _bgms_sample_bcomrf_gibbs(SEXP no_statesSEXP, SEXP no_variablesSEXP, SEXP no_categoriesSEXP, SEXP interactionsSEXP, SEXP thresholdsSEXP, SEXP variable_typeSEXP, SEXP baseline_categorySEXP, SEXP iterSEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; Rcpp::RNGScope rcpp_rngScope_gen; @@ -161,9 +129,29 @@ BEGIN_RCPP Rcpp::traits::input_parameter< NumericMatrix >::type interactions(interactionsSEXP); Rcpp::traits::input_parameter< NumericMatrix >::type thresholds(thresholdsSEXP); Rcpp::traits::input_parameter< StringVector >::type variable_type(variable_typeSEXP); - Rcpp::traits::input_parameter< IntegerVector >::type reference_category(reference_categorySEXP); + Rcpp::traits::input_parameter< IntegerVector >::type baseline_category(baseline_categorySEXP); Rcpp::traits::input_parameter< int >::type iter(iterSEXP); - rcpp_result_gen = Rcpp::wrap(sample_bcomrf_gibbs(no_states, no_variables, no_categories, interactions, thresholds, variable_type, reference_category, iter)); + rcpp_result_gen = Rcpp::wrap(sample_bcomrf_gibbs(no_states, no_variables, no_categories, interactions, thresholds, variable_type, baseline_category, iter)); + return rcpp_result_gen; +END_RCPP +} +// sample_ggm +Rcpp::List sample_ggm(const Rcpp::List& inputFromR, const arma::mat& prior_inclusion_prob, const arma::imat& initial_edge_indicators, const int no_iter, const int no_warmup, const int no_chains, const bool edge_selection, const int seed, const int no_threads, const int progress_type); +RcppExport SEXP _bgms_sample_ggm(SEXP inputFromRSEXP, SEXP prior_inclusion_probSEXP, SEXP initial_edge_indicatorsSEXP, SEXP no_iterSEXP, SEXP no_warmupSEXP, SEXP no_chainsSEXP, SEXP edge_selectionSEXP, SEXP seedSEXP, SEXP no_threadsSEXP, SEXP progress_typeSEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< const Rcpp::List& >::type inputFromR(inputFromRSEXP); + Rcpp::traits::input_parameter< const arma::mat& >::type prior_inclusion_prob(prior_inclusion_probSEXP); + Rcpp::traits::input_parameter< const arma::imat& >::type initial_edge_indicators(initial_edge_indicatorsSEXP); + Rcpp::traits::input_parameter< const int >::type no_iter(no_iterSEXP); + Rcpp::traits::input_parameter< const int >::type no_warmup(no_warmupSEXP); + Rcpp::traits::input_parameter< const int >::type no_chains(no_chainsSEXP); + Rcpp::traits::input_parameter< const bool >::type edge_selection(edge_selectionSEXP); + Rcpp::traits::input_parameter< const int >::type seed(seedSEXP); + Rcpp::traits::input_parameter< const int >::type no_threads(no_threadsSEXP); + Rcpp::traits::input_parameter< const int >::type progress_type(progress_typeSEXP); + rcpp_result_gen = Rcpp::wrap(sample_ggm(inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, seed, no_threads, progress_type)); return rcpp_result_gen; END_RCPP } @@ -185,11 +173,9 @@ END_RCPP static const R_CallMethodDef CallEntries[] = { {"_bgms_run_bgmCompare_parallel", (DL_FUNC) &_bgms_run_bgmCompare_parallel, 36}, {"_bgms_run_bgm_parallel", (DL_FUNC) &_bgms_run_bgm_parallel, 34}, - {"_bgms_get_explog_switch", (DL_FUNC) &_bgms_get_explog_switch, 0}, - {"_bgms_rcpp_ieee754_exp", (DL_FUNC) &_bgms_rcpp_ieee754_exp, 1}, - {"_bgms_rcpp_ieee754_log", (DL_FUNC) &_bgms_rcpp_ieee754_log, 1}, {"_bgms_sample_omrf_gibbs", (DL_FUNC) &_bgms_sample_omrf_gibbs, 6}, {"_bgms_sample_bcomrf_gibbs", (DL_FUNC) &_bgms_sample_bcomrf_gibbs, 8}, + {"_bgms_sample_ggm", (DL_FUNC) &_bgms_sample_ggm, 10}, {"_bgms_compute_Vn_mfm_sbm", (DL_FUNC) &_bgms_compute_Vn_mfm_sbm, 4}, {NULL, NULL, 0} }; diff --git a/src/bgm_helper.cpp b/src/bgm/bgm_helper.cpp similarity index 99% rename from src/bgm_helper.cpp rename to src/bgm/bgm_helper.cpp index 8f37cda9..a0696518 100644 --- a/src/bgm_helper.cpp +++ b/src/bgm/bgm_helper.cpp @@ -1,6 +1,6 @@ #include -#include "bgm_helper.h" -#include "common_helpers.h" +#include "bgm/bgm_helper.h" +#include "utils/common_helpers.h" diff --git a/src/bgm_helper.h b/src/bgm/bgm_helper.h similarity index 98% rename from src/bgm_helper.h rename to src/bgm/bgm_helper.h index 0cafba85..fe6e890b 100644 --- a/src/bgm_helper.h +++ b/src/bgm/bgm_helper.h @@ -1,7 +1,7 @@ #pragma once #include -#include "rng_utils.h" +#include "rng/rng_utils.h" // Vectorize main_effect matrix arma::vec vectorize_main_effects_bgm( diff --git a/src/bgm_logp_and_grad.cpp b/src/bgm/bgm_logp_and_grad.cpp similarity index 75% rename from src/bgm_logp_and_grad.cpp rename to src/bgm/bgm_logp_and_grad.cpp index a6c21b64..c0ef86da 100644 --- a/src/bgm_logp_and_grad.cpp +++ b/src/bgm/bgm_logp_and_grad.cpp @@ -1,8 +1,9 @@ #include -#include "bgm_helper.h" -#include "bgm_logp_and_grad.h" -#include "common_helpers.h" -#include "explog_switch.h" +#include "bgm/bgm_helper.h" +#include "bgm/bgm_logp_and_grad.h" +#include "utils/common_helpers.h" +#include "math/explog_switch.h" +#include "utils/variable_helpers.h" @@ -61,6 +62,7 @@ double log_pseudoposterior_main_effects_component ( }; const int num_cats = num_categories(variable); + arma::vec bound = num_cats * residual_matrix.col(variable); // numerical bound vector if (is_ordinal_variable(variable)) { // Prior contribution + sufficient statistic @@ -68,22 +70,12 @@ double log_pseudoposterior_main_effects_component ( log_posterior += value * counts_per_category(category + 1, variable); log_posterior += log_beta_prior (value); - // Vectorized likelihood contribution - // For each person, we compute the unnormalized log-likelihood denominator: - // denom = exp (-bound) + sum_c exp (main_effect_param_c + (c+1) * residual_score - bound) - // Where: - // - residual_score is the summed interaction score excluding the variable itself - // - bound = num_cats * residual_score (for numerical stability) - // - main_effect_param_c is the main_effect parameter for category c (0-based) - arma::vec residual_score = residual_matrix.col (variable); // rest scores for all persons - arma::vec bound = num_cats * residual_score; // numerical bound vector - arma::vec denom = ARMA_MY_EXP (-bound); // initialize with base term + arma::vec residual_score = residual_matrix.col (variable); // rest scores for all persons arma::vec main_effect_param = main_effects.row (variable).cols (0, num_cats - 1).t (); // main_effect parameters - for (int cat = 0; cat < num_cats; cat++) { - arma::vec exponent = main_effect_param(cat) + (cat + 1) * residual_score - bound; // exponent per person - denom += ARMA_MY_EXP (exponent); // accumulate exp terms - } + arma::vec denom = compute_denom_ordinal( + residual_score, main_effect_param, bound + ); // We then compute the total log-likelihood contribution as: // log_posterior -= bound + log (denom), summed over all persons @@ -98,24 +90,14 @@ double log_pseudoposterior_main_effects_component ( log_posterior += value * blume_capel_stats(parameter, variable); log_posterior += log_beta_prior(value); - // Vectorized likelihood contribution - // For each person, we compute the unnormalized log-likelihood denominator: - // denom = sum_c exp (θ_lin * c + θ_quad * (c - ref)^2 + c * residual_score - bound) - // Where: - // - θ_lin, θ_quad are linear and quadratic main_effects - // - ref is the reference category (used for centering) - // - bound = num_cats * residual_score (stabilizes exponentials) arma::vec residual_score = residual_matrix.col(variable); // rest scores for all persons - arma::vec bound = num_cats * residual_score; // numerical bound vector arma::vec denom(num_persons, arma::fill::zeros); // initialize denominator - for (int cat = 0; cat <= num_cats; cat++) { - int centered = cat - ref; // centered category - double quad_term = quadratic_main_effect * centered * centered; // precompute quadratic term - double lin_term = linear_main_effect * cat; // precompute linear term - arma::vec exponent = lin_term + quad_term + cat * residual_score - bound; - denom += ARMA_MY_EXP (exponent); // accumulate over categories - } + denom = compute_denom_blume_capel( + residual_score, linear_main_effect, quadratic_main_effect, ref, + num_cats, bound + ); + // The final log-likelihood contribution is then: // log_posterior -= bound + log (denom), summed over all persons @@ -175,37 +157,33 @@ double log_pseudoposterior_interactions_component ( double log_pseudo_posterior = 2.0 * pairwise_effects(var1, var2) * pairwise_stats(var1, var2); for (int var : {var1, var2}) { - int num_categories_var = num_categories (var); + int num_cats = num_categories (var); // Compute rest score: contribution from other variables - arma::vec residual_scores = observations * pairwise_effects.col (var); - arma::vec bounds = arma::max (residual_scores, arma::zeros (num_observations)) * num_categories_var; + arma::vec residual_score = observations * pairwise_effects.col (var); arma::vec denominator = arma::zeros (num_observations); + arma::vec bound = num_cats * residual_score; // numerical bound vector if (is_ordinal_variable (var)) { - // Ordinal variable: denominator includes exp (-bounds) + arma::vec main_effect_param = main_effects.row (var).cols (0, num_cats - 1).t (); // main_effect parameters - denominator += ARMA_MY_EXP (-bounds); - for (int category = 0; category < num_categories_var; category++) { - arma::vec exponent = main_effects (var, category) + (category + 1) * residual_scores - bounds; - denominator += ARMA_MY_EXP(exponent); - } + denominator += compute_denom_ordinal( + residual_score, main_effect_param, bound + ); } else { - // Binary/categorical variable: quadratic + linear term - const int ref_cat = baseline_category (var); - for (int category = 0; category <= num_categories_var; category++) { - int centered_cat = category - ref_cat; - double lin_term = main_effects (var, 0) * category; - double quad_term = main_effects (var, 1) * centered_cat * centered_cat; - arma::vec exponent = lin_term + quad_term + category * residual_scores - bounds; - denominator += ARMA_MY_EXP (exponent); - } + const int ref = baseline_category (var); + + denominator = compute_denom_blume_capel( + residual_score, main_effects (var, 0), main_effects (var, 1), ref, + num_cats, bound + ); + } // Subtract log partition function and bounds adjustment log_pseudo_posterior -= arma::accu (ARMA_MY_LOG (denominator)); - log_pseudo_posterior -= arma::accu (bounds); + log_pseudo_posterior -= arma::accu (bound); } // Add Cauchy prior terms for included pairwise effects @@ -311,34 +289,26 @@ double log_pseudoposterior ( // Calculate the log denominators for (int variable = 0; variable < num_variables; variable++) { const int num_cats = num_categories(variable); - arma::vec residual_score = residual_matrix.col (variable); // rest scores for all persons - arma::vec bound = num_cats * residual_score; // numerical bound vector - bound = arma::clamp(bound, 0.0, arma::datum::inf); // only positive bounds to prevent overflow + arma::vec residual_score = residual_matrix.col (variable); // rest scores for all persons + arma::vec bound = num_cats * residual_score; // numerical bound vector - arma::vec denom; + arma::vec denom(num_persons, arma::fill::zeros); if (is_ordinal_variable(variable)) { - denom = ARMA_MY_EXP (-bound); // initialize with base term arma::vec main_effect_param = main_effects.row (variable).cols (0, num_cats - 1).t (); // main_effect parameters for variable - for (int cat = 0; cat < num_cats; cat++) { - arma::vec exponent = main_effect_param(cat) + (cat + 1) * residual_score - bound; // exponent per person - denom += ARMA_MY_EXP (exponent); // accumulate exp terms - } + denom += compute_denom_ordinal( + residual_score, main_effect_param, bound + ); } else { + const int ref = baseline_category(variable); const double lin_effect = main_effects(variable, 0); const double quad_effect = main_effects(variable, 1); - const int ref = baseline_category(variable); - denom.zeros(num_persons); - for (int cat = 0; cat <= num_cats; cat++) { - int centered = cat - ref; // centered category - double quad = quad_effect * centered * centered; // precompute quadratic term - double lin = lin_effect * cat; // precompute linear term - arma::vec exponent = lin + quad + cat * residual_score - bound; - denom += ARMA_MY_EXP (exponent); // accumulate over categories - } + //This updates bound + denom = compute_denom_blume_capel( + residual_score, lin_effect, quad_effect, ref, num_cats, bound + ); } - - log_pseudoposterior -= arma::accu (bound + ARMA_MY_LOG (denom)); // total contribution + log_pseudoposterior -= arma::accu (bound + ARMA_MY_LOG (denom)); // total contribution } return log_pseudoposterior; @@ -346,39 +316,7 @@ double log_pseudoposterior ( -/** - * Computes the gradient of the log-pseudoposterior for main and active pairwise parameters. - * - * Gradient components: - * - Observed sufficient statistics (from counts_per_category, blume_capel_stats, pairwise_stats). - * - Minus expected sufficient statistics (computed via probabilities over categories). - * - Plus gradient contributions from priors: - * * Beta priors on main effects. - * * Cauchy priors on active pairwise effects. - * - * Inputs: - * - main_effects: Matrix of main-effect parameters (variables × categories). - * - pairwise_effects: Symmetric matrix of pairwise interaction strengths. - * - inclusion_indicator: Symmetric binary matrix of active pairwise effects. - * - observations: Matrix of categorical observations (persons × variables). - * - num_categories: Number of categories per variable. - * - counts_per_category: Category counts per variable (for ordinal variables). - * - blume_capel_stats: Sufficient statistics for Blume–Capel variables. - * - baseline_category: Reference categories for Blume–Capel variables. - * - is_ordinal_variable: Indicator (1 = ordinal, 0 = Blume–Capel). - * - main_alpha, main_beta: Hyperparameters for Beta priors. - * - pairwise_scale: Scale parameter of the Cauchy prior on interactions. - * - pairwise_stats: Sufficient statistics for pairwise effects. - * - residual_matrix: Matrix of residual scores (persons × variables). - * - * Returns: - * - A vector containing the gradient of the log-pseudoposterior with respect to - * all main and active pairwise parameters, in the same order as - * `vectorize_model_parameters_bgm()`. - */ -arma::vec gradient_log_pseudoposterior( - const arma::mat& main_effects, - const arma::mat& pairwise_effects, +std::pair gradient_observed_active( const arma::imat& inclusion_indicator, const arma::imat& observations, const arma::ivec& num_categories, @@ -386,15 +324,10 @@ arma::vec gradient_log_pseudoposterior( const arma::imat& blume_capel_stats, const arma::ivec& baseline_category, const arma::uvec& is_ordinal_variable, - const double main_alpha, - const double main_beta, - const double pairwise_scale, - const arma::imat& pairwise_stats, - const arma::mat& residual_matrix + const arma::imat& pairwise_stats ) { const int num_variables = observations.n_cols; - const int num_persons = observations.n_rows; - const int num_main = count_num_main_effects(num_categories, is_ordinal_variable); + const int num_main = count_num_main_effects(num_categories, is_ordinal_variable); arma::imat index_matrix(num_variables, num_variables, arma::fill::zeros); // Count active pairwise effects + Index map for pairwise parameters @@ -434,38 +367,86 @@ arma::vec gradient_log_pseudoposterior( } } + return {gradient, index_matrix}; +} + + + +/** + * Computes the gradient of the log-pseudoposterior for main and active pairwise parameters. + * + * Gradient components: + * - Observed sufficient statistics (from counts_per_category, blume_capel_stats, pairwise_stats). + * - Minus expected sufficient statistics (computed via probabilities over categories). + * - Plus gradient contributions from priors: + * * Beta priors on main effects. + * * Cauchy priors on active pairwise effects. + * + * Inputs: + * - main_effects: Matrix of main-effect parameters (variables × categories). + * - pairwise_effects: Symmetric matrix of pairwise interaction strengths. + * - inclusion_indicator: Symmetric binary matrix of active pairwise effects. + * - observations: Matrix of categorical observations (persons × variables). + * - num_categories: Number of categories per variable. + * - counts_per_category: Category counts per variable (for ordinal variables). + * - blume_capel_stats: Sufficient statistics for Blume–Capel variables. + * - baseline_category: Reference categories for Blume–Capel variables. + * - is_ordinal_variable: Indicator (1 = ordinal, 0 = Blume–Capel). + * - main_alpha, main_beta: Hyperparameters for Beta priors. + * - pairwise_scale: Scale parameter of the Cauchy prior on interactions. + * - pairwise_stats: Sufficient statistics for pairwise effects. + * - residual_matrix: Matrix of residual scores (persons × variables). + * + * Returns: + * - A vector containing the gradient of the log-pseudoposterior with respect to + * all main and active pairwise parameters, in the same order as + * `vectorize_model_parameters_bgm()`. + */ +arma::vec gradient_log_pseudoposterior( + const arma::mat& main_effects, + const arma::mat& pairwise_effects, + const arma::imat& inclusion_indicator, + const arma::imat& observations, + const arma::ivec& num_categories, + const arma::ivec& baseline_category, + const arma::uvec& is_ordinal_variable, + const double main_alpha, + const double main_beta, + const double pairwise_scale, + const arma::mat& residual_matrix, + const arma::imat index_matrix, + const arma::vec grad_obs +) { + const int num_variables = observations.n_cols; + const int num_persons = observations.n_rows; + + // Allocate gradient vector (main + active pairwise only) + arma::vec gradient = grad_obs; + // ---- STEP 2: Expected statistics ---- - offset = 0; + int offset = 0; for (int variable = 0; variable < num_variables; variable++) { const int num_cats = num_categories(variable); arma::vec residual_score = residual_matrix.col(variable); arma::vec bound = num_cats * residual_score; - bound = arma::clamp(bound, 0.0, arma::datum::inf); if (is_ordinal_variable(variable)) { arma::vec main_param = main_effects.row(variable).cols(0, num_cats - 1).t(); - bound += main_param.max(); - - arma::mat exponents(num_persons, num_cats); - for (int cat = 0; cat < num_cats; cat++) { - exponents.col(cat) = main_param(cat) + (cat + 1) * residual_score - bound; - } - - arma::mat probs = ARMA_MY_EXP(exponents); - arma::vec denom = arma::sum(probs, 1) + ARMA_MY_EXP(-bound); - probs.each_col() /= denom; + arma::mat probs = compute_probs_ordinal( + main_param, residual_score, bound, num_cats + ); // main effects for (int cat = 0; cat < num_cats; cat++) { - gradient(offset + cat) -= arma::accu(probs.col(cat)); + gradient(offset + cat) -= arma::accu(probs.col(cat + 1)); } // pairwise effects for (int j = 0; j < num_variables; j++) { if (inclusion_indicator(variable, j) == 0 || variable == j) continue; arma::vec expected_value = arma::zeros(num_persons); - for (int cat = 0; cat < num_cats; cat++) { - expected_value += (cat + 1) * probs.col(cat) % observations.col(j); + for (int cat = 1; cat <= num_cats; cat++) { + expected_value += cat * probs.col(cat) % observations.col(j); } int location = (variable < j) ? index_matrix(variable, j) : index_matrix(j, variable); gradient(location) -= arma::accu(expected_value); @@ -476,33 +457,30 @@ arma::vec gradient_log_pseudoposterior( const double lin_eff = main_effects(variable, 0); const double quad_eff = main_effects(variable, 1); - arma::mat exponents(num_persons, num_cats + 1); - for (int cat = 0; cat <= num_cats; cat++) { - int score = cat; - int centered = score - ref; - double lin = lin_eff * score; - double quad = quad_eff * centered * centered; - exponents.col(cat) = lin + quad + score * residual_score - bound; - } - arma::mat probs = ARMA_MY_EXP(exponents); - arma::vec denom = arma::sum(probs, 1); - probs.each_col() /= denom; + arma::mat probs = compute_probs_blume_capel( + residual_score, lin_eff, quad_eff, ref, num_cats, bound + ); - arma::ivec lin_score = arma::regspace(0, num_cats); - arma::ivec quad_score = arma::square(lin_score - ref); + arma::vec score = arma::regspace(0, num_cats) - double(ref); + arma::vec sq_score = arma::square(score); // main effects - gradient(offset) -= arma::accu(probs * lin_score); - gradient(offset + 1) -= arma::accu(probs * quad_score); + gradient(offset) -= arma::accu(probs * score); + gradient(offset + 1) -= arma::accu(probs * sq_score); // pairwise effects for (int j = 0; j < num_variables; j++) { if (inclusion_indicator(variable, j) == 0 || variable == j) continue; arma::vec expected_value = arma::zeros(num_persons); - for (int cat = 0; cat < num_cats; cat++) { - expected_value += (cat + 1) * probs.col(cat + 1) % observations.col(j); + for (int cat = 0; cat <= num_cats; cat++) { + int s = score(cat); + expected_value += s * probs.col(cat) % observations.col(j); } - int location = (variable < j) ? index_matrix(variable, j) : index_matrix(j, variable); + + int location = (variable < j) + ? index_matrix(variable, j) + : index_matrix(j, variable); + gradient(location) -= arma::accu(expected_value); } offset += 2; @@ -589,48 +567,43 @@ double compute_log_likelihood_ratio_for_variable ( arma::vec interaction = arma::conv_to::from (interacting_score); const int num_persons = residual_matrix.n_rows; - const int num_categories_var = num_categories (variable); + const int num_cats = num_categories (variable); // Compute adjusted linear predictors without the current interaction - arma::vec residual_scores = residual_matrix.col (variable) - interaction * current_state; - - // Stability bound for softmax (scaled by number of categories) - arma::vec bounds = arma::max (residual_scores, arma::zeros (num_persons)) * num_categories_var; + arma::vec residual_score = residual_matrix.col (variable) - interaction * current_state; + arma::vec bounds = residual_score * num_cats; arma::vec denom_current = arma::zeros (num_persons); arma::vec denom_proposed = arma::zeros (num_persons); if (is_ordinal_variable (variable)) { - denom_current += ARMA_MY_EXP(-bounds); - denom_proposed += ARMA_MY_EXP(-bounds); - - for (int category = 0; category < num_categories_var; category++) { - const double main = main_effects(variable, category); - const int score = category + 1; + arma::vec main_param = main_effects.row(variable).cols(0, num_cats - 1).t(); - for (int person = 0; person < num_persons; person++) { - const double base = main + score * residual_scores[person] - bounds[person]; + // ---- main change: use safe helper ---- + denom_current += compute_denom_ordinal( + residual_score + interaction * current_state, main_param, bounds + ); + denom_proposed += compute_denom_ordinal( + residual_score + interaction * proposed_state, main_param, bounds + ); - const double exp_current = MY_EXP(base + score * interaction[person] * current_state); - const double exp_proposed = MY_EXP(base + score * interaction[person] * proposed_state); - - denom_current[person] += exp_current; - denom_proposed[person] += exp_proposed; - } - } } else { // Binary or categorical variable: linear + quadratic score const int ref_cat = baseline_category (variable); - for (int category = 0; category <= num_categories_var; category++) { - int centered = category - ref_cat; - double lin_term = main_effects (variable, 0) * category; - double quad_term = main_effects (variable, 1) * centered * centered; - arma::vec exponent = lin_term + quad_term + category * residual_scores - bounds; + denom_current = compute_denom_blume_capel( + residual_score + interaction * current_state, main_effects (variable, 0), + main_effects (variable, 1), ref_cat, num_cats, bounds + ); + double log_ratio = arma::accu(ARMA_MY_LOG (denom_current) + bounds); - denom_current += ARMA_MY_EXP (exponent + category * interaction * current_state); - denom_proposed += ARMA_MY_EXP (exponent + category * interaction * proposed_state); - } + denom_proposed = compute_denom_blume_capel( + residual_score + interaction * proposed_state, main_effects (variable, 0), + main_effects (variable, 1), ref_cat, num_cats, bounds + ); + log_ratio -= arma::accu(ARMA_MY_LOG (denom_proposed) + bounds); + + return log_ratio; } // Accumulated log-likelihood difference across persons diff --git a/src/bgm_logp_and_grad.h b/src/bgm/bgm_logp_and_grad.h similarity index 88% rename from src/bgm_logp_and_grad.h rename to src/bgm/bgm_logp_and_grad.h index 6e72812c..2affd1a5 100644 --- a/src/bgm_logp_and_grad.h +++ b/src/bgm/bgm_logp_and_grad.h @@ -51,21 +51,31 @@ double log_pseudoposterior ( const arma::mat& residual_matrix ); +std::pair gradient_observed_active( + const arma::imat& inclusion_indicator, + const arma::imat& observations, + const arma::ivec& num_categories, + const arma::imat& counts_per_category, + const arma::imat& blume_capel_stats, + const arma::ivec& baseline_category, + const arma::uvec& is_ordinal_variable, + const arma::imat& pairwise_stats +); + arma::vec gradient_log_pseudoposterior( const arma::mat& main_effects, const arma::mat& pairwise_effects, const arma::imat& inclusion_indicator, const arma::imat& observations, const arma::ivec& num_categories, - const arma::imat& counts_per_category, - const arma::imat& blume_capel_stats, const arma::ivec& baseline_category, const arma::uvec& is_ordinal_variable, const double main_alpha, const double main_beta, const double pairwise_scale, - const arma::imat& pairwise_stats, - const arma::mat& residual_matrix + const arma::mat& residual_matrix, + const arma::imat index_matrix, + const arma::vec grad_obs ); // Pseudolikelihood ratio for a single variable diff --git a/src/bgm/bgm_output.h b/src/bgm/bgm_output.h new file mode 100644 index 00000000..0768945c --- /dev/null +++ b/src/bgm/bgm_output.h @@ -0,0 +1,21 @@ +#pragma once +#include + +struct bgmOutput { + // required + arma::mat main_samples; + arma::mat pairwise_samples; + + // optional (only if edge_selection) + arma::imat indicator_samples; + arma::imat allocation_samples; + + // optional (only if NUTS) + arma::ivec treedepth_samples; + arma::ivec divergent_samples; + arma::vec energy_samples; + + // metadata + int chain_id = -1; + bool userInterrupt = false; +}; diff --git a/src/bgm_sampler.cpp b/src/bgm/bgm_sampler.cpp similarity index 95% rename from src/bgm_sampler.cpp rename to src/bgm/bgm_sampler.cpp index b1397f24..680386bc 100644 --- a/src/bgm_sampler.cpp +++ b/src/bgm/bgm_sampler.cpp @@ -1,18 +1,19 @@ #include -#include "bgm_helper.h" -#include "bgm_logp_and_grad.h" -#include "bgm_sampler.h" -#include "common_helpers.h" -#include "mcmc_adaptation.h" -#include "mcmc_hmc.h" -#include "mcmc_leapfrog.h" -#include "mcmc_nuts.h" -#include "mcmc_rwm.h" -#include "mcmc_utils.h" -#include "sbm_edge_prior.h" -#include "rng_utils.h" -#include "progress_manager.h" -#include "chainResults.h" +#include "bgm/bgm_helper.h" +#include "bgm/bgm_logp_and_grad.h" +#include "bgm/bgm_sampler.h" +#include "bgm/bgm_output.h" +#include "mcmc/mcmc_adaptation.h" +#include "mcmc/mcmc_hmc.h" +#include "mcmc/mcmc_leapfrog.h" +#include "mcmc/mcmc_nuts.h" +#include "mcmc/mcmc_rwm.h" +#include "mcmc/mcmc_utils.h" +#include "priors/sbm_edge_prior.h" +#include "sbm_edge_prior_interface.h" +#include "rng/rng_utils.h" +#include "utils/common_helpers.h" +#include "utils/progress_manager.h" @@ -100,18 +101,19 @@ void impute_missing_bgm ( // Compute probabilities for Blume-Capel variable const int ref = baseline_category (variable); - cumsum = MY_EXP (main_effects (variable, 1) * ref * ref); + cumsum = MY_EXP ( + main_effects (variable, 0) * ref + main_effects (variable, 1) * ref * ref + ); category_probabilities[0] = cumsum; - for (int cat = 0; cat < num_cats; cat++) { - const int score = cat + 1; - const int centered = score - ref; + for (int cat = 0; cat <= num_cats; cat++) { + const int score = cat - ref; const double exponent = main_effects (variable, 0) * score + - main_effects (variable, 1) * centered * centered + + main_effects (variable, 1) * score * score + score * residual_score; cumsum += MY_EXP (exponent); - category_probabilities[score] = cumsum; + category_probabilities[cat] = cumsum; } } @@ -122,7 +124,9 @@ void impute_missing_bgm ( sampled_score++; } - const int new_value = sampled_score; + int new_value = sampled_score; + if(!is_ordinal) + new_value -= baseline_category (variable); const int old_value = observations(person, variable); if (new_value != old_value) { @@ -133,11 +137,8 @@ void impute_missing_bgm ( counts_per_category(old_value, variable)--; counts_per_category(new_value, variable)++; } else { - const int ref = baseline_category(variable); const int delta = new_value - old_value; - const int delta_sq = - (new_value - ref) * (new_value - ref) - - (old_value - ref) * (old_value - ref); + const int delta_sq = new_value * new_value - old_value * old_value; blume_capel_stats(0, variable) += delta; blume_capel_stats(1, variable) += delta_sq; @@ -212,27 +213,35 @@ double find_initial_stepsize_bgm( num_categories, is_ordinal_variable ); + arma::vec grad_obs; + arma::imat index_matrix; + + std::tie(grad_obs, index_matrix) = gradient_observed_active( + inclusion_indicator, observations, num_categories, counts_per_category, + blume_capel_stats, baseline_category, is_ordinal_variable, pairwise_stats + ); + arma::mat current_main = main_effects; arma::mat current_pair = pairwise_effects; - auto log_post = [&](const arma::vec& theta_vec) { - unvectorize_model_parameters_bgm(theta_vec, current_main, current_pair, - inclusion_indicator, - num_categories, is_ordinal_variable); + auto grad = [&](const arma::vec& theta_vec) { + unvectorize_model_parameters_bgm(theta_vec, current_main, current_pair, inclusion_indicator, + num_categories, is_ordinal_variable); arma::mat rm = observations * current_pair; - return log_pseudoposterior( - current_main, current_pair, inclusion_indicator, observations, - num_categories, counts_per_category, blume_capel_stats, - baseline_category, is_ordinal_variable, main_alpha, main_beta, - pairwise_scale, pairwise_stats, rm + + return gradient_log_pseudoposterior ( + current_main, current_pair, inclusion_indicator, observations, + num_categories, baseline_category, is_ordinal_variable, main_alpha, + main_beta, pairwise_scale, rm, index_matrix, grad_obs ); }; - auto grad = [&](const arma::vec& theta_vec) { - unvectorize_model_parameters_bgm(theta_vec, current_main, current_pair, inclusion_indicator, - num_categories, is_ordinal_variable); + auto log_post = [&](const arma::vec& theta_vec) { + unvectorize_model_parameters_bgm(theta_vec, current_main, current_pair, + inclusion_indicator, + num_categories, is_ordinal_variable); arma::mat rm = observations * current_pair; - return gradient_log_pseudoposterior( + return log_pseudoposterior( current_main, current_pair, inclusion_indicator, observations, num_categories, counts_per_category, blume_capel_stats, baseline_category, is_ordinal_variable, main_alpha, main_beta, @@ -521,6 +530,14 @@ void update_hmc_bgm( num_categories, is_ordinal_variable ); + arma::vec grad_obs; + arma::imat index_matrix; + + std::tie(grad_obs, index_matrix) = gradient_observed_active( + inclusion_indicator, observations, num_categories, counts_per_category, + blume_capel_stats, baseline_category, is_ordinal_variable, pairwise_stats + ); + arma::mat current_main = main_effects; arma::mat current_pair = pairwise_effects; @@ -531,9 +548,8 @@ void update_hmc_bgm( return gradient_log_pseudoposterior ( current_main, current_pair, inclusion_indicator, observations, - num_categories, counts_per_category, blume_capel_stats, - baseline_category, is_ordinal_variable, main_alpha, - main_beta, pairwise_scale, pairwise_stats, rm + num_categories, baseline_category, is_ordinal_variable, main_alpha, + main_beta, pairwise_scale, rm, index_matrix, grad_obs ); }; @@ -641,20 +657,26 @@ SamplerResult update_nuts_bgm( num_categories, is_ordinal_variable ); + arma::vec grad_obs; + arma::imat index_matrix; + + std::tie(grad_obs, index_matrix) = gradient_observed_active( + inclusion_indicator, observations, num_categories, counts_per_category, + blume_capel_stats, baseline_category, is_ordinal_variable, pairwise_stats + ); + arma::mat current_main = main_effects; arma::mat current_pair = pairwise_effects; auto grad = [&](const arma::vec& theta_vec) { - unvectorize_model_parameters_bgm(theta_vec, current_main, current_pair, - inclusion_indicator, num_categories, - is_ordinal_variable); + unvectorize_model_parameters_bgm(theta_vec, current_main, current_pair, inclusion_indicator, + num_categories, is_ordinal_variable); arma::mat rm = observations * current_pair; - return gradient_log_pseudoposterior( - current_main, current_pair, inclusion_indicator, observations, - num_categories, counts_per_category, blume_capel_stats, - baseline_category, is_ordinal_variable, main_alpha, - main_beta, pairwise_scale, pairwise_stats, rm + return gradient_log_pseudoposterior ( + current_main, current_pair, inclusion_indicator, observations, + num_categories, baseline_category, is_ordinal_variable, main_alpha, + main_beta, pairwise_scale, rm, index_matrix, grad_obs ); }; @@ -1165,8 +1187,8 @@ void gibbs_update_step_bgm ( * - Parallel execution across chains is handled by `run_bgm_parallel()`; * this function is for one chain only. */ -void run_gibbs_sampler_bgm( - ChainResult& chain_result, +bgmOutput run_gibbs_sampler_bgm( + int chain_id, arma::imat observations, const arma::ivec& num_categories, const double pairwise_scale, @@ -1200,9 +1222,6 @@ void run_gibbs_sampler_bgm( SafeRNG& rng, ProgressManager& pm ) { - - int chain_id = chain_result.chain_id; - // --- Setup: dimensions and storage structures const int num_variables = observations.n_cols; const int num_persons = observations.n_rows; @@ -1411,22 +1430,22 @@ void run_gibbs_sampler_bgm( } } + bgmOutput chain_result; + chain_result.chain_id = chain_id; chain_result.userInterrupt = userInterrupt; - chain_result.main_effect_samples = main_effect_samples; - chain_result.pairwise_effect_samples = pairwise_effect_samples; - + chain_result.main_samples = main_effect_samples; + chain_result.pairwise_samples = pairwise_effect_samples; if (update_method == nuts) { chain_result.treedepth_samples = treedepth_samples; chain_result.divergent_samples = divergent_samples; chain_result.energy_samples = energy_samples; } - if (edge_selection) { chain_result.indicator_samples = indicator_samples; - if (edge_prior == Stochastic_Block) chain_result.allocation_samples = allocation_samples; } + return chain_result; } \ No newline at end of file diff --git a/src/bgm_sampler.h b/src/bgm/bgm_sampler.h similarity index 91% rename from src/bgm_sampler.h rename to src/bgm/bgm_sampler.h index 2fcbf357..3e280ee1 100644 --- a/src/bgm_sampler.h +++ b/src/bgm/bgm_sampler.h @@ -1,13 +1,14 @@ #pragma once #include -#include "common_helpers.h" +#include "utils/common_helpers.h" +#include "bgm/bgm_output.h" + // forward declaration struct SafeRNG; class ProgressManager; -struct ChainResult; -void run_gibbs_sampler_bgm( - ChainResult& chain_result, +bgmOutput run_gibbs_sampler_bgm( + int chain_id, arma::imat observations, const arma::ivec& num_categories, const double pairwise_scale, diff --git a/src/bgmCompare_helper.cpp b/src/bgmCompare/bgmCompare_helper.cpp similarity index 99% rename from src/bgmCompare_helper.cpp rename to src/bgmCompare/bgmCompare_helper.cpp index f73fc5fb..10706afa 100644 --- a/src/bgmCompare_helper.cpp +++ b/src/bgmCompare/bgmCompare_helper.cpp @@ -1,7 +1,7 @@ #include #include -#include "bgmCompare_helper.h" -#include "common_helpers.h" +#include "bgmCompare/bgmCompare_helper.h" +#include "utils/common_helpers.h" diff --git a/src/bgmCompare_helper.h b/src/bgmCompare/bgmCompare_helper.h similarity index 99% rename from src/bgmCompare_helper.h rename to src/bgmCompare/bgmCompare_helper.h index 491609fd..05759b3c 100644 --- a/src/bgmCompare_helper.h +++ b/src/bgmCompare/bgmCompare_helper.h @@ -1,7 +1,7 @@ #pragma once #include -#include "rng_utils.h" +#include "rng/rng_utils.h" diff --git a/src/bgmCompare_logp_and_grad.cpp b/src/bgmCompare/bgmCompare_logp_and_grad.cpp similarity index 89% rename from src/bgmCompare_logp_and_grad.cpp rename to src/bgmCompare/bgmCompare_logp_and_grad.cpp index 789c6af1..5b4c15b6 100644 --- a/src/bgmCompare_logp_and_grad.cpp +++ b/src/bgmCompare/bgmCompare_logp_and_grad.cpp @@ -1,9 +1,10 @@ #include -#include "bgmCompare_helper.h" -#include "bgmCompare_logp_and_grad.h" +#include "bgmCompare/bgmCompare_helper.h" +#include "bgmCompare/bgmCompare_logp_and_grad.h" #include -#include "explog_switch.h" -#include "common_helpers.h" +#include "math/explog_switch.h" +#include "utils/common_helpers.h" +#include "utils/variable_helpers.h" @@ -91,7 +92,7 @@ double log_pseudoposterior( const arma::vec proj_g = projection.row(group).t(); // length = num_groups-1 // ---- build group-specific main & pairwise effects ---- - for (int v = 0; v < num_variables; ++v) { + for (int v = 0; v < num_variables; v++) { arma::vec me = compute_group_main_effects( v, num_groups, main_effects, main_effect_indices, proj_g ); @@ -100,7 +101,7 @@ double log_pseudoposterior( main_group(v, arma::span(0, me.n_elem - 1)) = me.t(); // upper triangle incl. base value; mirror to keep symmetry - for (int u = v; u < num_variables; ++u) { // Combines with loop over v + for (int u = v; u < num_variables; u++) { // Combines with loop over v if(u == v) continue; double w = compute_group_pairwise_effects( v, u, num_groups, pairwise_effects, pairwise_effect_indices, @@ -114,7 +115,7 @@ double log_pseudoposterior( const int num_cats = num_categories(v); if (is_ordinal_variable(v)) { // use group-specific main_effects - for (int c = 0; c < num_cats; ++c) { + for (int c = 0; c < num_cats; c++) { const double val = main_group(v, c); log_pp += static_cast(counts_per_category(c, v)) * val; } @@ -141,32 +142,25 @@ double log_pseudoposterior( // bound to stabilize exp; use group-specific params consistently arma::vec bound = num_cats * rest_score; - bound = arma::clamp(bound, 0.0, arma::datum::inf); - arma::vec denom(rest_score.n_elem, arma::fill::zeros); if (is_ordinal_variable(v)) { - // base term exp(-bound) - denom = ARMA_MY_EXP(-bound); - // main_effects from main_group - for (int c = 0; c < num_cats; ++c) { - const double th = main_group(v, c); - const arma::vec exponent = th + (c + 1) * rest_score - bound; - denom += ARMA_MY_EXP(exponent); - } + arma::vec main_eff = main_group.row(v).cols(0, num_cats - 1).t(); + denom = compute_denom_ordinal( + rest_score, main_eff, bound + ); } else { // linear/quadratic main effects from main_group const double lin_effect = main_group(v, 0); const double quad_effect = main_group(v, 1); const int ref = baseline_category(v); - for (int c = 0; c <= num_cats; ++c) { - const int centered = c - ref; - const double quad = quad_effect * centered * centered; - const double lin = lin_effect * c; - const arma::vec exponent = lin + quad + c * rest_score - bound; - denom += ARMA_MY_EXP(exponent); - } + + denom = compute_denom_blume_capel( + rest_score, lin_effect, quad_effect, ref, num_cats, + /*updated in place:*/bound + ); } + // - sum_i [ bound_i + log denom_i ] log_pp -= arma::accu(bound + ARMA_MY_LOG(denom)); } @@ -178,27 +172,27 @@ double log_pseudoposterior( }; // Main effects prior - for (int v = 0; v < num_variables; ++v) { + for (int v = 0; v < num_variables; v++) { const int row0 = main_effect_indices(v, 0); const int row1 = main_effect_indices(v, 1); - for (int r = row0; r <= row1; ++r) { + for (int r = row0; r <= row1; r++) { log_pp += log_beta_prior(main_effects(r, 0)); if (inclusion_indicator(v, v) == 0) continue; - for (int eff = 1; eff < num_groups; ++eff) { + for (int eff = 1; eff < num_groups; eff++) { log_pp += R::dcauchy(main_effects(r, eff), 0.0, difference_scale, true); } } } // Pairwise effects prior - for (int v1 = 0; v1 < num_variables - 1; ++v1) { - for (int v2 = v1 + 1; v2 < num_variables; ++v2) { + for (int v1 = 0; v1 < num_variables - 1; v1++) { + for (int v2 = v1 + 1; v2 < num_variables; v2++) { const int idx = pairwise_effect_indices(v1, v2); log_pp += R::dcauchy(pairwise_effects(idx, 0), 0.0, interaction_scale, true); if (inclusion_indicator(v1, v2) == 0) continue; - for (int eff = 1; eff < num_groups; ++eff) { + for (int eff = 1; eff < num_groups; eff++) { log_pp += R::dcauchy(pairwise_effects(idx, eff), 0.0, difference_scale, true); } } @@ -344,19 +338,19 @@ arma::vec gradient_observed_active( // ------------------------------- // Observed sufficient statistics // ------------------------------- - for (int g = 0; g < num_groups; ++g) { + for (int g = 0; g < num_groups; g++) { // list access arma::imat counts_per_category = counts_per_category_group[g]; arma::imat blume_capel_stats = blume_capel_stats_group[g]; const arma::vec proj_g = projection.row(g).t(); // length = num_groups-1 // Main effects - for (int v = 0; v < num_variables; ++v) { - const int base = main_effect_indices(v, 0); + for (int v = 0; v < num_variables; v++) { + const int base = main_effect_indices(v, 0); const int num_cats = num_categories(v); if (is_ordinal_variable(v)) { - for (int c = 0; c < num_cats; ++c) { + for (int c = 0; c < num_cats; c++) { const int count = counts_per_category(c, v); // overall off = main_index(base + c, 0); @@ -364,7 +358,7 @@ arma::vec gradient_observed_active( // diffs if(inclusion_indicator(v, v) != 0) { - for (int k = 1; k < num_groups; ++k) { + for (int k = 1; k < num_groups; k++) { off = main_index(base + c, k); grad_obs(off) += count * proj_g(k-1); } @@ -383,7 +377,7 @@ arma::vec gradient_observed_active( // diffs if(inclusion_indicator(v, v) != 0) { - for (int k = 1; k < num_groups; ++k) { + for (int k = 1; k < num_groups; k++) { off = main_index(base, k); grad_obs(off) += bc_0 * proj_g(k-1); @@ -396,8 +390,8 @@ arma::vec gradient_observed_active( // Pairwise (observed) arma::mat pairwise_stats = pairwise_stats_group[g]; - for (int v1 = 0; v1 < num_variables - 1; ++v1) { - for (int v2 = v1 + 1; v2 < num_variables; ++v2) { + for (int v1 = 0; v1 < num_variables - 1; v1++) { + for (int v2 = v1 + 1; v2 < num_variables; v2++) { const int row = pairwise_effect_indices(v1, v2); const double pw_stats = 2.0 * pairwise_stats(v1, v2); @@ -405,7 +399,7 @@ arma::vec gradient_observed_active( grad_obs(off) += pw_stats; // upper tri counted once if(inclusion_indicator(v1, v2) != 0){ - for (int k = 1; k < num_groups; ++k) { + for (int k = 1; k < num_groups; k++) { off = pair_index(row, k); grad_obs(off) += pw_stats * proj_g(k-1); } @@ -548,39 +542,30 @@ arma::vec gradient( const int num_group_obs = obs.n_rows; for (int v = 0; v < num_variables; ++v) { - const int K = num_categories(v); + const int K = num_categories(v); const int ref = baseline_category(v); arma::vec rest_score = residual_matrix.col(v); - arma::vec bound = K * rest_score; - bound.clamp(0.0, arma::datum::inf); - - arma::mat exponents(num_group_obs, K + 1, arma::fill::none); + arma::vec bound = K * rest_score; + arma::mat probs; if (is_ordinal_variable(v)) { - exponents.col(0) = -bound; - for (int j = 0; j < K; ++j) { - exponents.col(j + 1) = main_group(v, j) + (j + 1) * rest_score - bound; - } + arma::vec main_param = main_group.row(v).cols(0, K - 1).t(); + probs = compute_probs_ordinal( + main_param, rest_score, bound, K + ); } else { - const double lin_effect = main_group(v, 0); + const double lin_effect = main_group(v, 0); const double quad_effect = main_group(v, 1); - for (int s = 0; s <= K; ++s) { - const int centered = s - ref; - const double lin = lin_effect * s; - const double quad = quad_effect * centered * centered; - exponents.col(s) = lin + quad + s * rest_score - bound; - } + probs = compute_probs_blume_capel( + rest_score, lin_effect, quad_effect, ref, K, bound + ); } - arma::mat probs = ARMA_MY_EXP(exponents); - arma::vec denom = arma::sum(probs, 1); // base term - probs.each_col() /= denom; - // ---- MAIN expected ---- const int base = main_effect_indices(v, 0); if (is_ordinal_variable(v)) { - for (int s = 1; s <= K; ++s) { + for (int s = 1; s <= K; s++) { const int j = s - 1; double sum_col_s = arma::accu(probs.col(s)); @@ -588,14 +573,14 @@ arma::vec gradient( grad(off) -= sum_col_s; if (inclusion_indicator(v, v) == 0) continue; - for (int k = 1; k < num_groups; ++k) { + for (int k = 1; k < num_groups; k++) { off = main_index(base + j, k); grad(off) -= proj_g(k - 1) * sum_col_s; } } } else { - arma::vec lin_score = arma::regspace(0, K); // length K+1 - arma::vec quad_score = arma::square(lin_score - ref); + arma::vec lin_score = arma::regspace(0 - ref, K - ref); // length K+1 + arma::vec quad_score = arma::square(lin_score); double sum_lin = arma::accu(probs * lin_score); double sum_quad = arma::accu(probs * quad_score); @@ -606,7 +591,7 @@ arma::vec gradient( grad(off) -= sum_quad; if (inclusion_indicator(v, v) == 0) continue; - for (int k = 1; k < num_groups; ++k) { + for (int k = 1; k < num_groups; k++) { off = main_index(base, k); grad(off) -= proj_g(k - 1) * sum_lin; off = main_index(base + 1, k); @@ -615,12 +600,19 @@ arma::vec gradient( } // ---- PAIRWISE expected ---- - for (int v2 = 0; v2 < num_variables; ++v2) { + for (int v2 = 0; v2 < num_variables; v2++) { if (v == v2) continue; arma::vec expected_value(num_group_obs, arma::fill::zeros); - for (int s = 1; s <= K; ++s) { - expected_value += s * probs.col(s) % obs.col(v2); + if (is_ordinal_variable(v)) { + for (int s = 1; s <= K; ++s) { + expected_value += s * probs.col(s) % obs.col(v2); + } + } else { + for (int s = 0; s <= K; s++) { + int score = s - ref; + expected_value += score * probs.col(s) % obs.col(v2); + } } double sum_expectation = arma::accu(expected_value); @@ -631,7 +623,7 @@ arma::vec gradient( grad(off) -= sum_expectation; if (inclusion_indicator(v, v2) == 0) continue; - for (int k = 1; k < num_groups; ++k) { + for (int k = 1; k < num_groups; k++) { off = pair_index(row, k); grad(off) -= proj_g(k - 1) * sum_expectation; @@ -782,7 +774,7 @@ double log_pseudoposterior_main_component( int variable, int category, // for ordinal variables only int par, // for Blume-Capel variables only - int h // Overall = 0, differences are 1, .... + int h // Overall = 0, differences are 1,2,... ) { if(h > 0 && inclusion_indicator(variable, variable) == 0) { return 0.0; // No contribution if differences not included @@ -807,7 +799,7 @@ double log_pseudoposterior_main_component( variable, num_groups, main_effects, main_effect_indices, proj_g ); - // store into row v (padded with zeros if variable has < max_num_categories params) + // store into row v main_group(variable, arma::span(0, me.n_elem - 1)) = me.t(); // upper triangle incl. base value; mirror to keep symmetry @@ -824,8 +816,7 @@ double log_pseudoposterior_main_component( // ---- data contribution pseudolikelihood (linear terms) ---- if (is_ordinal_variable(variable)) { const double val = main_group(variable, category); - log_pp += static_cast(counts_per_category(category, variable)) * - val; + log_pp += static_cast(counts_per_category(category, variable)) * val; } else { log_pp += static_cast(blume_capel_stats(par, variable)) * main_group(variable, par); @@ -842,31 +833,25 @@ double log_pseudoposterior_main_component( // bound to stabilize exp; use group-specific params consistently arma::vec bound = num_cats * rest_score; - bound = arma::clamp(bound, 0.0, arma::datum::inf); - arma::vec denom(rest_score.n_elem, arma::fill::zeros); + if (is_ordinal_variable(variable)) { - // base term exp(-bound) - denom = ARMA_MY_EXP(-bound); - // main_effects from main_group - for (int cat = 0; cat < num_cats; cat++) { - const double th = main_group(variable, cat); - const arma::vec exponent = th + (cat + 1) * rest_score - bound; - denom += ARMA_MY_EXP(exponent); - } + arma::vec main_eff = main_group.row(variable).cols(0, num_cats - 1).t(); + denom = compute_denom_ordinal( + rest_score, main_eff, bound + ); } else { // linear/quadratic main effects from main_group const double lin_effect = main_group(variable, 0); const double quad_effect = main_group(variable, 1); const int ref = baseline_category(variable); - for (int cat = 0; cat <= num_cats; cat++) { - const int centered = cat - ref; - const double quad = quad_effect * centered * centered; - const double lin = lin_effect * cat; - const arma::vec exponent = lin + quad + cat * rest_score - bound; - denom += ARMA_MY_EXP(exponent); - } + + denom = compute_denom_blume_capel( + rest_score, lin_effect, quad_effect, ref, num_cats, + /*updated in place:*/bound + ); } + // - sum_i [ bound_i + log denom_i ] log_pp -= arma::accu(bound + ARMA_MY_LOG(denom)); } @@ -1025,32 +1010,25 @@ double log_pseudoposterior_pair_component( // bound to stabilize exp; use group-specific params consistently arma::vec bound = num_cats * rest_score; - bound = arma::clamp(bound, 0.0, arma::datum::inf); - arma::vec denom(rest_score.n_elem, arma::fill::zeros); if (is_ordinal_variable(v)) { - // base term exp(-bound) - denom = ARMA_MY_EXP(-bound); - // main_effects from main_group - for (int c = 0; c < num_cats; ++c) { - const double th = main_group(v, c); - const arma::vec exponent = th + (c + 1) * rest_score - bound; - denom += ARMA_MY_EXP(exponent); - } + arma::vec main_eff = main_group.row(v).cols(0, num_cats - 1).t(); + denom = compute_denom_ordinal( + rest_score, main_eff, bound + ); } else { // linear/quadratic main effects from main_group const double lin_effect = main_group(v, 0); const double quad_effect = main_group(v, 1); const int ref = baseline_category(v); - for (int c = 0; c <= num_cats; ++c) { - const int centered = c - ref; - const double quad = quad_effect * centered * centered; - const double lin = lin_effect * c; - const arma::vec exponent = lin + quad + c * rest_score - bound; - denom += ARMA_MY_EXP(exponent); - } + + denom = compute_denom_blume_capel( + rest_score, lin_effect, quad_effect, ref, num_cats, + /*updated in place:*/bound + ); } + // - sum_i [ bound_i + log denom_i ] log_pp -= arma::accu(bound + ARMA_MY_LOG(denom)); } @@ -1173,40 +1151,31 @@ double log_ratio_pseudolikelihood_constant_variable( arma::vec denom_current(rest_current.n_elem, arma::fill::zeros); arma::vec denom_proposed(rest_proposed.n_elem, arma::fill::zeros); - if (is_ordinal_variable(variable)) { - // regular ordinal/binary - bound_current = num_cats * arma::clamp(rest_current, 0.0, arma::datum::inf); - bound_proposed = num_cats * arma::clamp(rest_proposed, 0.0, arma::datum::inf); - - denom_current = ARMA_MY_EXP(-bound_current); - denom_proposed = ARMA_MY_EXP(-bound_proposed); + if (is_ordinal_variable (variable)) { + bound_current = rest_current * num_cats; + bound_proposed = rest_proposed * num_cats; - for (int c = 0; c < num_cats; ++c) { - denom_current += ARMA_MY_EXP(main_current(c) + (c + 1) * rest_current - bound_current); - denom_proposed += ARMA_MY_EXP(main_proposed(c) + (c + 1) * rest_proposed - bound_proposed); - } + denom_current += compute_denom_ordinal( + rest_current, main_current, bound_current + ); + denom_proposed += compute_denom_ordinal( + rest_proposed, main_proposed, bound_proposed + ); } else { - // Blume-Capel: linear + quadratic - const int ref = baseline_category(variable); - - arma::vec const_current(num_cats + 1, arma::fill::zeros); - arma::vec const_proposed(num_cats + 1, arma::fill::zeros); - for (int s = 0; s <= num_cats; ++s) { - const int centered = s - ref; - const_current(s) = main_current(0) * s + main_current(1) * centered * centered; - const_proposed(s) = main_proposed(0) * s + main_proposed(1) * centered * centered; - } - - double lbound = std::max(const_current.max(), const_proposed.max()); - if (lbound < 0.0) lbound = 0.0; - - bound_current = lbound + num_cats * arma::clamp(rest_current, 0.0, arma::datum::inf); - bound_proposed = lbound + num_cats * arma::clamp(rest_proposed, 0.0, arma::datum::inf); + // Binary or categorical variable: linear + quadratic score + const int ref_cat = baseline_category (variable); + bound_current = rest_current * num_cats; + bound_proposed = rest_proposed * num_cats; + + denom_current = compute_denom_blume_capel( + rest_current, main_current (0), main_current (1), + ref_cat, num_cats, /*Updated in place:*/bound_current + ); - for (int s = 0; s <= num_cats; ++s) { - denom_current += ARMA_MY_EXP(const_current(s) + s * rest_current - bound_current); - denom_proposed += ARMA_MY_EXP(const_proposed(s) + s * rest_proposed - bound_proposed); - } + denom_proposed = compute_denom_blume_capel( + rest_proposed, main_proposed (0), main_proposed (1), + ref_cat, num_cats, /*Updated in place:*/bound_proposed + ); } // --- accumulate contribution --- diff --git a/src/bgmCompare_logp_and_grad.h b/src/bgmCompare/bgmCompare_logp_and_grad.h similarity index 100% rename from src/bgmCompare_logp_and_grad.h rename to src/bgmCompare/bgmCompare_logp_and_grad.h diff --git a/src/sampler_output.h b/src/bgmCompare/bgmCompare_output.h similarity index 97% rename from src/sampler_output.h rename to src/bgmCompare/bgmCompare_output.h index 8f899d00..148a1984 100644 --- a/src/sampler_output.h +++ b/src/bgmCompare/bgmCompare_output.h @@ -21,13 +21,15 @@ * - chain_id: Identifier of the chain. * - has_indicator: True if indicator samples are stored. */ -struct SamplerOutput { +struct bgmCompareOutput { arma::mat main_samples; arma::mat pairwise_samples; arma::imat indicator_samples; + arma::ivec treedepth_samples; arma::ivec divergent_samples; arma::vec energy_samples; + int chain_id; bool has_indicator; bool userInterrupt; diff --git a/src/bgmCompare_sampler.cpp b/src/bgmCompare/bgmCompare_sampler.cpp similarity index 97% rename from src/bgmCompare_sampler.cpp rename to src/bgmCompare/bgmCompare_sampler.cpp index cf6b91d9..cd955e32 100644 --- a/src/bgmCompare_sampler.cpp +++ b/src/bgmCompare/bgmCompare_sampler.cpp @@ -1,20 +1,19 @@ #include -#include "bgmCompare_helper.h" -#include "bgmCompare_logp_and_grad.h" -#include "bgmCompare_sampler.h" -#include "common_helpers.h" -#include "mcmc_adaptation.h" -#include "mcmc_hmc.h" -#include "mcmc_leapfrog.h" -#include "mcmc_nuts.h" -#include "mcmc_rwm.h" -#include "mcmc_utils.h" -#include "rng_utils.h" -#include "sampler_output.h" -#include "explog_switch.h" +#include "bgmCompare/bgmCompare_helper.h" +#include "bgmCompare/bgmCompare_logp_and_grad.h" +#include "bgmCompare/bgmCompare_sampler.h" +#include "bgmCompare/bgmCompare_output.h" +#include "mcmc/mcmc_adaptation.h" +#include "mcmc/mcmc_hmc.h" +#include "mcmc/mcmc_leapfrog.h" +#include "mcmc/mcmc_nuts.h" +#include "mcmc/mcmc_rwm.h" +#include "mcmc/mcmc_utils.h" +#include "rng/rng_utils.h" +#include "math/explog_switch.h" #include -#include "progress_manager.h" - +#include "utils/progress_manager.h" +#include "utils/common_helpers.h" @@ -89,7 +88,7 @@ void impute_missing_bgmcompare( arma::vec category_response_probabilities(max_num_categories + 1); double exponent, cumsum, u; - int score, person, variable, new_observation, old_observation, group; + int score, person, variable, new_value, old_value, group; //Impute missing data for(int missing = 0; missing < num_missings; missing++) { @@ -132,12 +131,12 @@ void impute_missing_bgmcompare( } else { // For Blume-Capel variables cumsum = 0.0; + const int ref = baseline_category[variable]; for(int category = 0; category <= num_categories(variable); category++) { - exponent = group_main_effects[0] * category; - exponent += group_main_effects[1] * - (category - baseline_category[variable]) * - (category - baseline_category[variable]); - exponent += category * rest_score; + score = category - ref; + exponent = group_main_effects[0] * score; + exponent += group_main_effects[1] * score * score; + exponent += rest_score * score; cumsum += MY_EXP(exponent); category_response_probabilities[category] = cumsum; } @@ -149,31 +148,30 @@ void impute_missing_bgmcompare( while (u > category_response_probabilities[score]) { score++; } - new_observation = score; - old_observation = observations(person, variable); - if(old_observation != new_observation) { + new_value = score; + if(!is_ordinal_variable[variable]) + new_value -= baseline_category[variable]; + old_value = observations(person, variable); + + if(old_value != new_value) { // Update raw observations - observations(person, variable) = new_observation; + observations(person, variable) = new_value; // Update sufficient statistics for main effects if(is_ordinal_variable[variable] == true) { arma::imat counts_per_category_group = counts_per_category[group]; - if(old_observation > 0) - counts_per_category_group(old_observation-1, variable)--; - if(new_observation > 0) - counts_per_category_group(new_observation-1, variable)++; + if(old_value > 0) + counts_per_category_group(old_value-1, variable)--; + if(new_value > 0) + counts_per_category_group(new_value-1, variable)++; counts_per_category[group] = counts_per_category_group; } else { arma::imat blume_capel_stats_group = blume_capel_stats[group]; - blume_capel_stats_group(0, variable) -= old_observation; - blume_capel_stats_group(0, variable) += new_observation; - blume_capel_stats_group(1, variable) -= - (old_observation - baseline_category[variable]) * - (old_observation - baseline_category[variable]); - blume_capel_stats_group(1, variable) += - (new_observation - baseline_category[variable]) * - (new_observation - baseline_category[variable]); + blume_capel_stats_group(0, variable) -= old_value; + blume_capel_stats_group(0, variable) += new_value; + blume_capel_stats_group(1, variable) -= old_value * old_value; + blume_capel_stats_group(1, variable) += new_value * new_value; blume_capel_stats[group] = blume_capel_stats_group; } @@ -1566,7 +1564,7 @@ void gibbs_update_step_bgmcompare ( * - This function runs entirely in C++ and is wrapped for parallel execution * via `GibbsCompareChainRunner`. */ -SamplerOutput run_gibbs_sampler_bgmCompare( +bgmCompareOutput run_gibbs_sampler_bgmCompare( int chain_id, arma::imat observations, const int num_groups, @@ -1763,7 +1761,7 @@ SamplerOutput run_gibbs_sampler_bgmCompare( } } - SamplerOutput out; + bgmCompareOutput out; out.chain_id = chain_id; out.main_samples = main_effect_samples; out.pairwise_samples = pairwise_effect_samples; @@ -1776,7 +1774,7 @@ SamplerOutput run_gibbs_sampler_bgmCompare( } else { out.indicator_samples = arma::imat(); } - out.userInterrupt = userInterrupt; + out.userInterrupt = userInterrupt; return out; } \ No newline at end of file diff --git a/src/bgmCompare_sampler.h b/src/bgmCompare/bgmCompare_sampler.h similarity index 91% rename from src/bgmCompare_sampler.h rename to src/bgmCompare/bgmCompare_sampler.h index 781d7f0d..83bcdbaa 100644 --- a/src/bgmCompare_sampler.h +++ b/src/bgmCompare/bgmCompare_sampler.h @@ -1,14 +1,14 @@ #pragma once #include -#include "common_helpers.h" +#include "utils/common_helpers.h" +#include "bgmCompare/bgmCompare_output.h" #include -struct SamplerOutput; struct SafeRNG; class ProgressManager; -SamplerOutput run_gibbs_sampler_bgmCompare( +bgmCompareOutput run_gibbs_sampler_bgmCompare( int chain_id, arma::imat observations, const int num_groups, diff --git a/src/bgmCompare_parallel.cpp b/src/bgmCompare_interface.cpp similarity index 95% rename from src/bgmCompare_parallel.cpp rename to src/bgmCompare_interface.cpp index 24464023..5b0fa983 100644 --- a/src/bgmCompare_parallel.cpp +++ b/src/bgmCompare_interface.cpp @@ -1,15 +1,15 @@ // [[Rcpp::depends(RcppParallel, RcppArmadillo, dqrng)]] #include -#include "bgmCompare_sampler.h" -#include "rng_utils.h" // must be included before RcppParallel +#include "bgmCompare/bgmCompare_sampler.h" +#include "rng/rng_utils.h" // must be included before RcppParallel #include #include #include #include -#include "progress_manager.h" -#include "sampler_output.h" -#include "mcmc_adaptation.h" -#include "common_helpers.h" +#include "utils/progress_manager.h" +#include "bgmCompare/bgmCompare_output.h" +#include "mcmc/mcmc_adaptation.h" +#include "utils/common_helpers.h" using namespace RcppParallel; @@ -22,24 +22,19 @@ using namespace RcppParallel; * - error: True if the chain terminated with an error, false otherwise. * - error_msg: Error message if an error occurred (empty if none). * - chain_id: Integer identifier for the chain (1-based). - * - result: SamplerOutput object containing chain results + * - result: bgmCompareOutput object containing chain results * (samples, diagnostics, metadata). * * Usage: * - Used in parallel execution to collect results from each chain. * - Checked after execution to propagate errors or assemble outputs * into an R-accessible list. - * - * Notes: - * - This struct mirrors `ChainResult` from bgm, but stores a - * `SamplerOutput` instead of an `Rcpp::List`. - * - If `error == true`, the `result` field should be ignored. */ -struct ChainResultCompare { +struct bgmCompareChainResult { bool error; std::string error_msg; int chain_id; - SamplerOutput result; + bgmCompareOutput result; }; @@ -90,7 +85,7 @@ struct ChainResultCompare { * - hmc_num_leapfrogs: Number of leapfrog steps (HMC). * * Output: - * - results: Vector of `ChainResultCompare` objects, one per chain, filled in place. + * - results: Vector of `bgmCompareChainResult` objects, one per chain, filled in place. * * Notes: * - Each worker instance is shared across threads but invoked with different @@ -135,7 +130,7 @@ struct GibbsCompareChainRunner : public Worker { const int hmc_num_leapfrogs; ProgressManager& pm; // output - std::vector& results; + std::vector& results; GibbsCompareChainRunner( const arma::imat& observations_master, @@ -172,7 +167,7 @@ struct GibbsCompareChainRunner : public Worker { const UpdateMethod update_method, const int hmc_num_leapfrogs, ProgressManager& pm, - std::vector& results + std::vector& results ) : observations_master(observations_master), num_groups(num_groups), @@ -213,7 +208,7 @@ struct GibbsCompareChainRunner : public Worker { void operator()(std::size_t begin, std::size_t end) { for (std::size_t i = begin; i < end; ++i) { - ChainResultCompare out; + bgmCompareChainResult out; out.chain_id = static_cast(i + 1); out.error = false; @@ -229,7 +224,7 @@ struct GibbsCompareChainRunner : public Worker { arma::imat observations = observations_master; // run sampler (pure C++) - SamplerOutput result = run_gibbs_sampler_bgmCompare( + bgmCompareOutput result = run_gibbs_sampler_bgmCompare( out.chain_id, observations, num_groups, @@ -386,7 +381,7 @@ Rcpp::List run_bgmCompare_parallel( int hmc_num_leapfrogs, int progress_type ) { - std::vector results(num_chains); + std::vector results(num_chains); // per-chain seeds std::vector chain_rngs(num_chains); diff --git a/src/bgm_parallel.cpp b/src/bgm_interface.cpp similarity index 84% rename from src/bgm_parallel.cpp rename to src/bgm_interface.cpp index f0e6c941..86834b16 100644 --- a/src/bgm_parallel.cpp +++ b/src/bgm_interface.cpp @@ -1,20 +1,43 @@ // [[Rcpp::depends(RcppParallel, RcppArmadillo, dqrng)]] #include -#include "rng_utils.h" // must be included before RcppParallel +#include "rng/rng_utils.h" // must be included before RcppParallel #include -#include "bgm_sampler.h" +#include "bgm/bgm_sampler.h" +#include "bgm/bgm_output.h" #include #include #include -#include "progress_manager.h" -#include "mcmc_adaptation.h" -#include "common_helpers.h" -#include "chainResults.h" +#include "mcmc/mcmc_adaptation.h" +#include "utils/progress_manager.h" +#include "utils/common_helpers.h" using namespace RcppParallel; +/** + * Container for the result of a single MCMC chain (bgm model). + * + * Fields: + * - error: True if the chain terminated with an error, false otherwise. + * - error_msg: Error message if an error occurred (empty if none). + * - chain_id: Integer identifier for the chain (1-based). + * - result: bgmOutput object containing chain results + * (samples, diagnostics, metadata). + * + * Usage: + * - Used in parallel execution to collect results from each chain. + * - Checked after execution to propagate errors or assemble outputs + * into an R-accessible list. + */ +struct bgmChainResult { + bool error; + std::string error_msg; + int chain_id; + bgmOutput result; +}; + + /** * Worker struct for running a single Gibbs sampling chain in parallel (bgm model). @@ -40,7 +63,7 @@ using namespace RcppParallel; struct GibbsChainRunner : public Worker { const arma::imat& observations; const arma::ivec& num_categories; - double pairwise_scale; + double pairwise_scale; const EdgePrior edge_prior; const arma::mat& inclusion_probability; double beta_bernoulli_alpha; @@ -74,12 +97,12 @@ struct GibbsChainRunner : public Worker { ProgressManager& pm; // output buffer - std::vector& results; + std::vector& results; GibbsChainRunner( const arma::imat& observations, const arma::ivec& num_categories, - double pairwise_scale, + double pairwise_scale, const EdgePrior edge_prior, const arma::mat& inclusion_probability, double beta_bernoulli_alpha, @@ -109,11 +132,11 @@ struct GibbsChainRunner : public Worker { bool learn_mass_matrix, const std::vector& chain_rngs, ProgressManager& pm, - std::vector& results + std::vector& results ) : observations(observations), num_categories(num_categories), - pairwise_scale( pairwise_scale), + pairwise_scale( pairwise_scale), edge_prior(edge_prior), inclusion_probability(inclusion_probability), beta_bernoulli_alpha(beta_bernoulli_alpha), @@ -148,16 +171,16 @@ struct GibbsChainRunner : public Worker { void operator()(std::size_t begin, std::size_t end) { for (std::size_t i = begin; i < end; ++i) { - - ChainResult& chain_result = results[i]; + bgmChainResult chain_result; chain_result.chain_id = static_cast(i + 1); chain_result.error = false; - SafeRNG rng = chain_rngs[i]; try { + // per-chain RNG + SafeRNG rng = chain_rngs[i]; - run_gibbs_sampler_bgm( - chain_result, + bgmOutput result = run_gibbs_sampler_bgm( + chain_result.chain_id, observations, num_categories, pairwise_scale, @@ -192,6 +215,8 @@ struct GibbsChainRunner : public Worker { pm ); + chain_result.result = result; + } catch (std::exception& e) { chain_result.error = true; chain_result.error_msg = e.what(); @@ -199,6 +224,8 @@ struct GibbsChainRunner : public Worker { chain_result.error = true; chain_result.error_msg = "Unknown error"; } + + results[i] = chain_result; } } }; @@ -288,7 +315,7 @@ Rcpp::List run_bgm_parallel( int seed, int progress_type ) { - std::vector results(num_chains); + std::vector results(num_chains); // Prepare one independent RNG per chain via jump() std::vector chain_rngs(num_chains); @@ -328,31 +355,32 @@ Rcpp::List run_bgm_parallel( Rcpp::Named("chain_id") = results[i].chain_id ); } else { - Rcpp::List chain_i; - chain_i["main_samples"] = results[i].main_effect_samples; - chain_i["pairwise_samples"] = results[i].pairwise_effect_samples; + const auto& r = results[i].result; + Rcpp::List chain_i; + chain_i["main_samples"] = r.main_samples; + chain_i["pairwise_samples"] = r.pairwise_samples; - if (update_method_enum == nuts) { - chain_i["treedepth__"] = results[i].treedepth_samples; - chain_i["divergent__"] = results[i].divergent_samples; - chain_i["energy__"] = results[i].energy_samples; - } + if (update_method_enum == nuts) { + chain_i["treedepth__"] = r.treedepth_samples; + chain_i["divergent__"] = r.divergent_samples; + chain_i["energy__"] = r.energy_samples; + } - if (edge_selection) { - chain_i["indicator_samples"] = results[i].indicator_samples; + if (edge_selection) { + chain_i["indicator_samples"] = r.indicator_samples; - if (edge_prior_enum == Stochastic_Block) - chain_i["allocations"] = results[i].allocation_samples; - } + if (edge_prior_enum == Stochastic_Block) + chain_i["allocations"] = r.allocation_samples; + } - chain_i["userInterrupt"] = results[i].userInterrupt; - chain_i["chain_id"] = results[i].chain_id; + chain_i["userInterrupt"] = r.userInterrupt; + chain_i["chain_id"] = r.chain_id; - output[i] = chain_i; + output[i] = chain_i; } } pm.finish(); return output; -} +} \ No newline at end of file diff --git a/src/chainResultNew.h b/src/chainResultNew.h new file mode 100644 index 00000000..fa269a54 --- /dev/null +++ b/src/chainResultNew.h @@ -0,0 +1,32 @@ +#pragma once + +#include +#include + +class ChainResultNew { + +public: + ChainResultNew() {} + + bool error = false, + userInterrupt = false; + std::string error_msg; + int chain_id; + + arma::mat samples; + + void reserve(const size_t param_dim, const size_t n_iter) { + samples.set_size(param_dim, n_iter); + } + void store_sample(const size_t iter, const arma::vec& sample) { + samples.col(iter) = sample; + } + + // arma::imat indicator_samples; + + // other samples + // arma::ivec treedepth_samples; + // arma::ivec divergent_samples; + // arma::vec energy_samples; + // arma::imat allocation_samples; +}; diff --git a/src/chainResults.h b/src/chainResults.h deleted file mode 100644 index 395e9d42..00000000 --- a/src/chainResults.h +++ /dev/null @@ -1,31 +0,0 @@ -#pragma once - -#include -#include - -/** - * Container for the result of a single MCMC chain. - * - * Fields: - * - error: True if the chain terminated with an error, false otherwise. - * - error_msg: Error message in case of failure (empty if no error). - * - chain_id: Integer identifier for the chain (1-based). - * - result: Rcpp::List containing the chain’s outputs (samples, diagnostics, etc.). - * - * Usage: - * - Used in parallel samplers to collect per-chain results. - * - Checked after execution to propagate errors or assemble outputs into R. - */ -struct ChainResult { - bool error; - std::string error_msg; - int chain_id; - bool userInterrupt; - arma::mat main_effect_samples; - arma::mat pairwise_effect_samples; - arma::ivec treedepth_samples; - arma::ivec divergent_samples; - arma::vec energy_samples; - arma::imat indicator_samples; - arma::imat allocation_samples; -}; diff --git a/src/data_simulation.cpp b/src/data_simulation.cpp index 2e50c820..b4bf38c3 100644 --- a/src/data_simulation.cpp +++ b/src/data_simulation.cpp @@ -1,4 +1,4 @@ -#include "explog_switch.h" +#include "math/explog_switch.h" #include using namespace Rcpp; @@ -80,7 +80,7 @@ IntegerMatrix sample_bcomrf_gibbs(int no_states, NumericMatrix interactions, NumericMatrix thresholds, StringVector variable_type, - IntegerVector reference_category, + IntegerVector baseline_category, int iter) { IntegerMatrix observations(no_states, no_variables); @@ -118,21 +118,26 @@ IntegerMatrix sample_bcomrf_gibbs(int no_states, for(int person = 0; person < no_states; person++) { rest_score = 0.0; for(int vertex = 0; vertex < no_variables; vertex++) { - rest_score += observations(person, vertex) * - interactions(vertex, variable); + if(variable_type[vertex] != "blume-capel") { + rest_score += observations(person, vertex) * interactions(vertex, variable); + } else { + int ref = baseline_category[vertex]; + int obs = observations(person, vertex); + rest_score += (obs - ref) * interactions(vertex, variable); + } } if(variable_type[variable] == "blume-capel") { cumsum = 0.0; + int ref = baseline_category[variable]; for(int category = 0; category < no_categories[variable] + 1; category++) { + const int score = category - ref; //The linear term of the Blume-Capel variable - exponent = thresholds(variable, 0) * category; + exponent = thresholds(variable, 0) * score; //The quadratic term of the Blume-Capel variable - exponent += thresholds(variable, 1) * - (category - reference_category[variable]) * - (category - reference_category[variable]); + exponent += thresholds(variable, 1) * score * score; //The pairwise interactions - exponent += category * rest_score; + exponent += rest_score * score; cumsum += MY_EXP(exponent); probabilities[category] = cumsum; } diff --git a/src/custom_exp.cpp b/src/math/custom_exp.cpp similarity index 85% rename from src/custom_exp.cpp rename to src/math/custom_exp.cpp index e6fe3164..2abe8cb3 100644 --- a/src/custom_exp.cpp +++ b/src/math/custom_exp.cpp @@ -1,7 +1,6 @@ -#include "explog_switch.h" +#include "math/explog_switch.h" #include "Rcpp.h" -// [[Rcpp::export]] Rcpp::String get_explog_switch() { #if USE_CUSTOM_LOG return "custom"; @@ -10,7 +9,6 @@ Rcpp::String get_explog_switch() { #endif } -// [[Rcpp::export]] Rcpp::NumericVector rcpp_ieee754_exp(Rcpp::NumericVector x) { Rcpp::NumericVector y(x.size()); for (int i = 0; i < x.size(); i++) { @@ -19,7 +17,6 @@ Rcpp::NumericVector rcpp_ieee754_exp(Rcpp::NumericVector x) { return y; } -// [[Rcpp::export]] Rcpp::NumericVector rcpp_ieee754_log(Rcpp::NumericVector x) { Rcpp::NumericVector y(x.size()); for (int i = 0; i < x.size(); i++) { diff --git a/src/e_arma_exp.h b/src/math/e_arma_exp.h similarity index 100% rename from src/e_arma_exp.h rename to src/math/e_arma_exp.h diff --git a/src/e_exp.cpp b/src/math/e_exp.cpp similarity index 100% rename from src/e_exp.cpp rename to src/math/e_exp.cpp diff --git a/src/e_exp.h b/src/math/e_exp.h similarity index 100% rename from src/e_exp.h rename to src/math/e_exp.h diff --git a/src/explog_switch.h b/src/math/explog_switch.h similarity index 94% rename from src/explog_switch.h rename to src/math/explog_switch.h index 9c174cc4..3fce818f 100644 --- a/src/explog_switch.h +++ b/src/math/explog_switch.h @@ -20,8 +20,8 @@ #if USE_CUSTOM_LOG -#include "e_exp.h" -#include "e_arma_exp.h" +#include "math/e_exp.h" +#include "math/e_arma_exp.h" #define MY_EXP __ieee754_exp #define MY_LOG __ieee754_log diff --git a/src/mcmc_adaptation.h b/src/mcmc/mcmc_adaptation.h similarity index 99% rename from src/mcmc_adaptation.h rename to src/mcmc/mcmc_adaptation.h index 4aa172c9..cf6184ff 100644 --- a/src/mcmc_adaptation.h +++ b/src/mcmc/mcmc_adaptation.h @@ -5,9 +5,9 @@ #include #include #include -#include "mcmc_utils.h" -#include "mcmc_rwm.h" -#include "explog_switch.h" +#include "mcmc/mcmc_utils.h" +#include "mcmc/mcmc_rwm.h" +#include "math/explog_switch.h" class DualAveraging { public: diff --git a/src/mcmc_hmc.cpp b/src/mcmc/mcmc_hmc.cpp similarity index 89% rename from src/mcmc_hmc.cpp rename to src/mcmc/mcmc_hmc.cpp index e851a24a..adb3c022 100644 --- a/src/mcmc_hmc.cpp +++ b/src/mcmc/mcmc_hmc.cpp @@ -1,9 +1,9 @@ #include #include -#include "mcmc_hmc.h" -#include "mcmc_leapfrog.h" -#include "mcmc_utils.h" -#include "rng_utils.h" +#include "mcmc/mcmc_hmc.h" +#include "mcmc/mcmc_leapfrog.h" +#include "mcmc/mcmc_utils.h" +#include "rng/rng_utils.h" diff --git a/src/mcmc_hmc.h b/src/mcmc/mcmc_hmc.h similarity index 92% rename from src/mcmc_hmc.h rename to src/mcmc/mcmc_hmc.h index 6492cd56..32e9d4a1 100644 --- a/src/mcmc_hmc.h +++ b/src/mcmc/mcmc_hmc.h @@ -2,7 +2,7 @@ #include #include -#include "mcmc_utils.h" +#include "mcmc/mcmc_utils.h" struct SafeRNG; SamplerResult hmc_sampler( diff --git a/src/mcmc_leapfrog.cpp b/src/mcmc/mcmc_leapfrog.cpp similarity index 97% rename from src/mcmc_leapfrog.cpp rename to src/mcmc/mcmc_leapfrog.cpp index 6d5fb4d0..ceadcfe8 100644 --- a/src/mcmc_leapfrog.cpp +++ b/src/mcmc/mcmc_leapfrog.cpp @@ -1,7 +1,8 @@ #include #include -#include "mcmc_leapfrog.h" -#include "mcmc_memoization.h" +#include "mcmc/mcmc_leapfrog.h" +#include "mcmc/mcmc_memoization.h" + /** diff --git a/src/mcmc_leapfrog.h b/src/mcmc/mcmc_leapfrog.h similarity index 95% rename from src/mcmc_leapfrog.h rename to src/mcmc/mcmc_leapfrog.h index b69164f1..bd7a5397 100644 --- a/src/mcmc_leapfrog.h +++ b/src/mcmc/mcmc_leapfrog.h @@ -3,7 +3,7 @@ #include #include -#include "mcmc_memoization.h" +#include "mcmc/mcmc_memoization.h" diff --git a/src/mcmc_memoization.h b/src/mcmc/mcmc_memoization.h similarity index 100% rename from src/mcmc_memoization.h rename to src/mcmc/mcmc_memoization.h diff --git a/src/mcmc_nuts.cpp b/src/mcmc/mcmc_nuts.cpp similarity index 98% rename from src/mcmc_nuts.cpp rename to src/mcmc/mcmc_nuts.cpp index 96884dc6..2335fabc 100644 --- a/src/mcmc_nuts.cpp +++ b/src/mcmc/mcmc_nuts.cpp @@ -1,10 +1,10 @@ #include #include -#include "mcmc_leapfrog.h" -#include "mcmc_memoization.h" -#include "mcmc_nuts.h" -#include "mcmc_utils.h" -#include "rng_utils.h" +#include "mcmc/mcmc_leapfrog.h" +#include "mcmc/mcmc_memoization.h" +#include "mcmc/mcmc_nuts.h" +#include "mcmc/mcmc_utils.h" +#include "rng/rng_utils.h" /** diff --git a/src/mcmc_nuts.h b/src/mcmc/mcmc_nuts.h similarity index 94% rename from src/mcmc_nuts.h rename to src/mcmc/mcmc_nuts.h index 613ffc2b..f678cf4b 100644 --- a/src/mcmc_nuts.h +++ b/src/mcmc/mcmc_nuts.h @@ -3,9 +3,9 @@ #include #include #include -#include "mcmc_leapfrog.h" -#include "mcmc_memoization.h" -#include "mcmc_utils.h" +#include "mcmc/mcmc_leapfrog.h" +#include "mcmc/mcmc_memoization.h" +#include "mcmc/mcmc_utils.h" struct SafeRNG; diff --git a/src/mcmc_rwm.cpp b/src/mcmc/mcmc_rwm.cpp similarity index 93% rename from src/mcmc_rwm.cpp rename to src/mcmc/mcmc_rwm.cpp index a91dc8f4..a07ad105 100644 --- a/src/mcmc_rwm.cpp +++ b/src/mcmc/mcmc_rwm.cpp @@ -1,9 +1,9 @@ #include #include #include -#include "mcmc_utils.h" -#include "mcmc_rwm.h" -#include "rng_utils.h" +#include "mcmc/mcmc_utils.h" +#include "mcmc/mcmc_rwm.h" +#include "rng/rng_utils.h" /** diff --git a/src/mcmc_rwm.h b/src/mcmc/mcmc_rwm.h similarity index 89% rename from src/mcmc_rwm.h rename to src/mcmc/mcmc_rwm.h index 73227642..fcce0254 100644 --- a/src/mcmc_rwm.h +++ b/src/mcmc/mcmc_rwm.h @@ -3,7 +3,7 @@ #include #include #include -#include "mcmc_utils.h" +#include "mcmc/mcmc_utils.h" struct SafeRNG; diff --git a/src/mcmc_utils.cpp b/src/mcmc/mcmc_utils.cpp similarity index 96% rename from src/mcmc_utils.cpp rename to src/mcmc/mcmc_utils.cpp index e2104969..25c616c4 100644 --- a/src/mcmc_utils.cpp +++ b/src/mcmc/mcmc_utils.cpp @@ -1,9 +1,10 @@ #include #include #include -#include "mcmc_leapfrog.h" -#include "mcmc_utils.h" -#include "rng_utils.h" +#include "mcmc/mcmc_leapfrog.h" +#include "mcmc/mcmc_utils.h" +#include "rng/rng_utils.h" + /** diff --git a/src/mcmc_utils.h b/src/mcmc/mcmc_utils.h similarity index 99% rename from src/mcmc_utils.h rename to src/mcmc/mcmc_utils.h index 9c935706..c84c0393 100644 --- a/src/mcmc_utils.h +++ b/src/mcmc/mcmc_utils.h @@ -5,7 +5,7 @@ #include #include #include -#include "explog_switch.h" +#include "math/explog_switch.h" struct SafeRNG; // (only if didn’t already provide it under C++17) diff --git a/src/models/adaptiveMetropolis.h b/src/models/adaptiveMetropolis.h new file mode 100644 index 00000000..d8d9cb6c --- /dev/null +++ b/src/models/adaptiveMetropolis.h @@ -0,0 +1,68 @@ +#pragma once + +#include +#include + +class AdaptiveProposal { + +public: + + AdaptiveProposal(size_t num_params, size_t adaption_window = 50, double target_accept = 0.44) { + proposal_sds_ = arma::vec(num_params, arma::fill::ones) * 0.25; // Initial SD, need to tweak this somehow? + acceptance_counts_ = arma::ivec(num_params, arma::fill::zeros); + adaptation_window_ = adaption_window; + target_accept_ = target_accept; + } + + double get_proposal_sd(size_t param_index) const { + validate_index(param_index); + return proposal_sds_[param_index]; + } + + void update_proposal_sd(size_t param_index) { + + if (!adapting_) { + return; + } + + double current_sd = get_proposal_sd(param_index); + double observed_acceptance_probability = acceptance_counts_[param_index] / static_cast(iterations_ + 1); + double rm_weight = std::pow(iterations_, -decay_rate_); + + // Robbins-Monro update step + double updated_sd = current_sd + (observed_acceptance_probability - target_accept_) * rm_weight; + updated_sd = std::clamp(updated_sd, rm_lower_bound, rm_upper_bound); + + proposal_sds_(param_index) = updated_sd; + } + + void increment_accepts(size_t param_index) { + validate_index(param_index); + acceptance_counts_[param_index]++; + } + + void increment_iteration() { + iterations_++; + if (iterations_ >= adaptation_window_) { + adapting_ = false; + } + } + +private: + arma::vec proposal_sds_; + arma::ivec acceptance_counts_; + int iterations_ = 0, + adaptation_window_; + double target_accept_ = 0.44, + decay_rate_ = 0.75, + rm_lower_bound = 0.001, + rm_upper_bound = 2.0; + bool adapting_ = true; + + void validate_index(size_t index) const { + if (index >= proposal_sds_.n_elem) { + throw std::out_of_range("Parameter index out of range"); + } + } + +}; diff --git a/src/models/base_model.cpp b/src/models/base_model.cpp new file mode 100644 index 00000000..e69de29b diff --git a/src/models/base_model.h b/src/models/base_model.h new file mode 100644 index 00000000..10f443dc --- /dev/null +++ b/src/models/base_model.h @@ -0,0 +1,60 @@ +#pragma once + +#include +#include +#include + +class BaseModel { +public: + virtual ~BaseModel() = default; + + // Capability queries + virtual bool has_gradient() const { return false; } + virtual bool has_adaptive_mh() const { return false; } + + // Core methods (to be overridden by derived classes) + virtual double logp(const arma::vec& parameters) = 0; + + virtual arma::vec gradient(const arma::vec& parameters) { + if (!has_gradient()) { + throw std::runtime_error("Gradient not implemented for this model"); + } + throw std::runtime_error("Gradient method must be implemented in derived class"); + } + + virtual std::pair logp_and_gradient( + const arma::vec& parameters) { + if (!has_gradient()) { + throw std::runtime_error("Gradient not implemented for this model"); + } + return {logp(parameters), gradient(parameters)}; + } + + // For Metropolis-Hastings (model handles parameter groups internally) + virtual void do_one_mh_step() { + throw std::runtime_error("do_one_mh_step method must be implemented in derived class"); + } + + virtual arma::vec get_vectorized_parameters() { + throw std::runtime_error("get_vectorized_parameters method must be implemented in derived class"); + } + + virtual arma::ivec get_vectorized_indicator_parameters() { + throw std::runtime_error("get_vectorized_indicator_parameters method must be implemented in derived class"); + } + + // Return dimensionality of the parameter space + virtual size_t parameter_dimension() const = 0; + + virtual void set_seed(int seed) { + throw std::runtime_error("set_seed method must be implemented in derived class"); + } + + virtual std::unique_ptr clone() const { + throw std::runtime_error("clone method must be implemented in derived class"); + } + + +protected: + BaseModel() = default; +}; diff --git a/src/models/ggm/cholupdate.cpp b/src/models/ggm/cholupdate.cpp new file mode 100644 index 00000000..2aed567e --- /dev/null +++ b/src/models/ggm/cholupdate.cpp @@ -0,0 +1,129 @@ +#include "models/ggm/cholupdate.h" + +extern "C" { + +// from mgcv: https://github.com/cran/mgcv/blob/1b6a4c8374612da27e36420b4459e93acb183f2d/src/mat.c#L1876-L1883 +static inline double hypote(double x, double y) { +/* stable computation of sqrt(x^2 + y^2) */ + double t; + x = fabs(x);y=fabs(y); + if (y>x) { t = x;x = y; y = t;} + if (x==0) return(y); else t = y/x; + return(x*sqrt(1+t*t)); +} /* hypote */ + +// from mgcv: https://github.com/cran/mgcv/blob/1b6a4c8374612da27e36420b4459e93acb183f2d/src/mat.c#L1956 +void chol_up(double *R,double *u, int *n,int *up,double *eps) { +/* Rank 1 update of a cholesky factor. Works as follows: + + [up=1] R'R + uu' = [u,R'][u,R']' = [u,R']Q'Q[u,R']', and then uses Givens rotations to + construct Q such that Q[u,R']' = [0,R1']'. Hence R1'R1 = R'R + uu'. The construction + operates from first column to last. + + [up=0] uses an almost identical sequence, but employs hyperbolic rotations + in place of Givens. See Golub and van Loan (2013, 4e 6.5.4) + + Givens rotations are of form [c,-s] where c = cos(theta), s = sin(theta). + [s,c] + + Assumes R upper triangular, and that it is OK to use first two columns + below diagonal as temporary strorage for Givens rotations (the storage is + needed to ensure algorithm is column oriented). + + For downdate returns a negative value in R[1] (R[1,0]) if not +ve definite. +*/ + double c0,s0,*c,*s,z,*x,z0,*c1; + int j,j1,n1; + n1 = *n - 1; + if (*up) for (j1=-1,j=0;j<*n;j++,u++,j1++) { /* loop over columns of R */ + z = *u; /* initial element of u */ + x = R + *n * j; /* current column */ + c = R + 2;s = R + *n + 2; /* Storage for first n-2 Givens rotations */ + for (c1=c+j1;c R[j,j] */ + z0 = hypote(z,*x); /* sqrt(z^2+R[j,j]^2) */ + c0 = *x/z0; s0 = z/z0; /* need to zero z */ + /* now apply this rotation and this column is finished (so no need to update z) */ + *x = s0 * z + c0 * *x; + } else for (j1=-1,j=0;j<*n;j++,u++,j1++) { /* loop over columns of R for down-dating */ + z = *u; /* initial element of u */ + x = R + *n * j; /* current column */ + c = R + 2;s = R + *n + 2; /* Storage for first n-2 hyperbolic rotations */ + for (c1=c+j1;c R[j,j] */ + z0 = z / *x; /* sqrt(z^2+R[j,j]^2) */ + if (fabs(z0)>=1) { /* downdate not +ve def */ + //Rprintf("j = %d d = %g ",j,z0); + if (*n>1) R[1] = -2.0; + return; /* signals error */ + } + if (z0 > 1 - *eps) z0 = 1 - *eps; + c0 = 1/sqrt(1-z0*z0);s0 = c0 * z0; + /* now apply this rotation and this column is finished (so no need to update z) */ + *x = -s0 * z + c0 * *x; + } + + /* now zero c and s storage */ + c = R + 2;s = R + *n + 2; + for (x = c + *n - 2;c + +void cholesky_update( arma::mat& R, arma::vec& u, double eps = 1e-12); +void cholesky_downdate(arma::mat& R, arma::vec& u, double eps = 1e-12); diff --git a/src/models/ggm/ggm_model.cpp b/src/models/ggm/ggm_model.cpp new file mode 100644 index 00000000..eaa50f57 --- /dev/null +++ b/src/models/ggm/ggm_model.cpp @@ -0,0 +1,644 @@ +#include "models/ggm/ggm_model.h" +#include "models/adaptiveMetropolis.h" +#include "rng/rng_utils.h" +#include "models/ggm/cholupdate.h" + +double GGMModel::compute_inv_submatrix_i(const arma::mat& A, const size_t i, const size_t ii, const size_t jj) const { + return(A(ii, jj) - A(ii, i) * A(jj, i) / A(i, i)); +} + +void GGMModel::get_constants(size_t i, size_t j) { + + // TODO: helper function? + double logdet_omega = get_log_det(phi_); + + double log_adj_omega_ii = logdet_omega + std::log(std::abs(inv_omega_(i, i))); + double log_adj_omega_ij = logdet_omega + std::log(std::abs(inv_omega_(i, j))); + double log_adj_omega_jj = logdet_omega + std::log(std::abs(inv_omega_(j, j))); + + double inv_omega_sub_j1j1 = compute_inv_submatrix_i(inv_omega_, i, j, j); + double log_abs_inv_omega_sub_jj = log_adj_omega_ii + std::log(std::abs(inv_omega_sub_j1j1)); + double Phi_q1q = (2 * std::signbit(inv_omega_(i, j)) - 1) * std::exp( + (log_adj_omega_ij - (log_adj_omega_jj + log_abs_inv_omega_sub_jj) / 2) + ); + double Phi_q1q1 = std::exp((log_adj_omega_jj - log_abs_inv_omega_sub_jj) / 2); + + constants_[1] = Phi_q1q; + constants_[2] = Phi_q1q1; + constants_[3] = omega_(i, j) - Phi_q1q * Phi_q1q1; + constants_[4] = Phi_q1q1; + constants_[5] = omega_(j, j) - Phi_q1q * Phi_q1q; + constants_[6] = constants_[5] + constants_[3] * constants_[3] / (constants_[4] * constants_[4]); + +} + +double GGMModel::R(const double x) const { + if (x == 0) { + return constants_[6]; + } else { + return constants_[5] + std::pow((x - constants_[3]) / constants_[4], 2); + } +} + +double GGMModel::get_log_det(arma::mat triangular_A) const { + // assume A is an (upper) triangular cholesky factor + // returns the log determinant of A'A + + // TODO: should we just do + // log_det(val, sign, trimatu(A))? + return 2 * arma::accu(arma::log(triangular_A.diag())); +} + +double GGMModel::log_density_impl(const arma::mat& omega, const arma::mat& phi) const { + + double logdet_omega = get_log_det(phi); + // TODO: why not just dot(omega, suf_stat_)? + double trace_prod = arma::accu(omega % suf_stat_); + + double log_likelihood = n_ * (p_ * log(2 * arma::datum::pi) / 2 + logdet_omega / 2) - trace_prod / 2; + + return log_likelihood; +} + +double GGMModel::log_density_impl_edge(size_t i, size_t j) const { + + // this is the log likelihood ratio, not the full log likelihood like GGMModel::log_density_impl + + double Ui2 = omega_(i, j) - omega_prop_(i, j); + // only reached from R + // if (omega_(j, j) == omega_prop_(j, j)) { + // k = i; + // i = j; + // j = k; + // } + double Uj2 = (omega_(j, j) - omega_prop_(j, j)) / 2; + + + // W <- matrix(c(0, 1, 1, 0), 2, 2) + // U0 <- matrix(c(0, -1, Ui2, Uj2)) + // U <- matrix(0, nrow(aOmega), 2) + // U[c(i, j), 1] <- c(0, -1) + // U[c(i, j), 2] <- c(Ui2, Uj2) + // aOmega_prop - (aOmega + U %*% W %*% t(U)) + // det(aOmega_prop) - det(aOmega + U %*% W %*% t(U)) + // det(aOmega_prop) - det(W + t(U) %*% inv_aOmega %*% U) * det(W) * det(aOmega) + // below computes logdet(W + t(U) %*% inv_aOmega %*% U) directly (this is a 2x2 matrix) + + double cc11 = 0 + inv_omega_(j, j); + double cc12 = 1 - (inv_omega_(i, j) * Ui2 + inv_omega_(j, j) * Uj2); + double cc22 = 0 + Ui2 * Ui2 * inv_omega_(i, i) + 2 * Ui2 * Uj2 * inv_omega_(i, j) + Uj2 * Uj2 * inv_omega_(j, j); + + double logdet = std::log(std::abs(cc11 * cc22 - cc12 * cc12)); + // logdet - (logdet(aOmega_prop) - logdet(aOmega)) + + double trace_prod = -2 * (suf_stat_(j, j) * Uj2 + suf_stat_(i, j) * Ui2); + + double log_likelihood_ratio = (n_ * logdet - trace_prod) / 2; + return log_likelihood_ratio; + +} + +double GGMModel::log_density_impl_diag(size_t j) const { + // same as above but for i == j, so Ui2 = 0 + double Uj2 = (omega_(j, j) - omega_prop_(j, j)) / 2; + + double cc11 = 0 + inv_omega_(j, j); + double cc12 = 1 - inv_omega_(j, j) * Uj2; + double cc22 = 0 + Uj2 * Uj2 * inv_omega_(j, j); + + double logdet = std::log(std::abs(cc11 * cc22 - cc12 * cc12)); + double trace_prod = -2 * suf_stat_(j, j) * Uj2; + + // This function uses the fact that the determinant doesn't change during edge updates. + // double trace_prod = 0.0; + // // TODO: we only need one of the two lines below, but it's not entirely clear which one + // trace_prod += suf_stat_(j, j) * (omega_prop(j, j) - omega(j, j)); + // trace_prod += suf_stat_(i, i) * (omega_prop(i, i) - omega(i, i)); + // trace_prod += 2 * suf_stat_(i, j) * (omega_prop(i, j) - omega(i, j)); + // trace_prod - sum((aOmega_prop - aOmega) * SufStat) + + double log_likelihood_ratio = (n_ * logdet - trace_prod) / 2; + return log_likelihood_ratio; + +} + +void GGMModel::update_edge_parameter(size_t i, size_t j) { + + if (edge_indicators_(i, j) == 0) { + return; // Edge is not included; skip update + } + + get_constants(i, j); + double Phi_q1q = constants_[1]; + double Phi_q1q1 = constants_[2]; + + size_t e = i * (i + 1) / 2 + j; // parameter index in vectorized form + double proposal_sd = proposal_.get_proposal_sd(e); + + double phi_prop = rnorm(rng_, Phi_q1q, proposal_sd); + double omega_prop_q1q = constants_[3] + constants_[4] * phi_prop; + double omega_prop_qq = R(omega_prop_q1q); + + // form full proposal matrix for Omega + omega_prop_ = omega_; // TODO: needs to be a copy! + omega_prop_(i, j) = omega_prop_q1q; + omega_prop_(j, i) = omega_prop_q1q; + omega_prop_(j, j) = omega_prop_qq; + + // Rcpp::Rcout << "i: " << i << ", j: " << j << + // ", proposed phi: " << phi_prop << + // ", proposal_sd omega_ij: " << proposal_sd << + // ", proposed omega_ij: " << omega_prop_q1q << + // ", proposed omega_jj: " << omega_prop_qq << std::endl; + // constants_.print(Rcpp::Rcout, "Constants:"); + // omega_prop_.print(Rcpp::Rcout, "Proposed omega:"); + + // arma::vec eigval = eig_sym(omega_prop_); + // if (arma::any(eigval <= 0)) { + // Rcpp::Rcout << "Warning: omega_prop_ is not positive definite for edge (" << i << ", " << j << ")" << std::endl; + + // Rcpp::Rcout << + // ", proposed phi: " << phi_prop << + // ", proposal_sd omega_ij: " << proposal_sd << + // ", proposed omega_ij: " << omega_prop_q1q << + // ", proposed omega_jj: " << omega_prop_qq << std::endl; + // constants_.print(Rcpp::Rcout, "Constants:"); + // omega_prop_.print(Rcpp::Rcout, "Proposed omega:"); + // omega_.print(Rcpp::Rcout, "Current omega:"); + // phi_.print(Rcpp::Rcout, "Phi:"); + // inv_omega_.print(Rcpp::Rcout, "Inv(Omega):"); + + // } + + // double ln_alpha = log_density(omega_prop_) - log_density(); + double ln_alpha = log_density_impl_edge(i, j); + + // { + // double ln_alpha_ref = log_density(omega_prop_) - log_density(); + // if (std::abs(ln_alpha - ln_alpha_ref) > 1e-6) { + // Rcpp::Rcout << "Warning: log density implementations do not match for edge (" << i << ", " << j << ")" << std::endl; + // omega_.print(Rcpp::Rcout, "Current omega:"); + // omega_prop_.print(Rcpp::Rcout, "Proposed omega:"); + // Rcpp::Rcout << "ln_alpha: " << ln_alpha << ", ln_alpha_ref: " << ln_alpha_ref << std::endl; + // } + // } + + ln_alpha += R::dcauchy(omega_prop_(i, j), 0.0, 2.5, true); + ln_alpha -= R::dcauchy(omega_(i, j), 0.0, 2.5, true); + + if (std::log(runif(rng_)) < ln_alpha) { + // accept proposal + proposal_.increment_accepts(e); + + double omega_ij_old = omega_(i, j); + double omega_jj_old = omega_(j, j); + + + omega_(i, j) = omega_prop_q1q; + omega_(j, i) = omega_prop_q1q; + omega_(j, j) = omega_prop_qq; + + cholesky_update_after_edge(omega_ij_old, omega_jj_old, i, j); + + // // TODO: preallocate? + // // find v for low rank update + // arma::vec v1 = {0, -1}; + // arma::vec v2 = {omega_ij - omega_prop_(i, j), (omega_jj - omega_prop_(j, j)) / 2}; + + // arma::vec vf1 = arma::zeros(p_); + // arma::vec vf2 = arma::zeros(p_); + // vf1[i] = v1[0]; + // vf1[j] = v1[1]; + // vf2[i] = v2[0]; + // vf2[j] = v2[1]; + + // // we now have + // // aOmega_prop - (aOmega + vf1 %*% t(vf2) + vf2 %*% t(vf1)) + + // arma::vec u1 = (vf1 + vf2) / sqrt(2); + // arma::vec u2 = (vf1 - vf2) / sqrt(2); + + // // we now have + // // omega_prop_ - (aOmega + u1 %*% t(u1) - u2 %*% t(u2)) + // // and also + // // aOmega_prop - (aOmega + cbind(vf1, vf2) %*% matrix(c(0, 1, 1, 0), 2, 2) %*% t(cbind(vf1, vf2))) + + // // update phi (2x O(p^2)) + // cholesky_update(phi_, u1); + // cholesky_downdate(phi_, u2); + + // // update inverse (2x O(p^2)) + // arma::inv(inv_phi_, arma::trimatu(phi_)); + // inv_omega_ = inv_phi_ * inv_phi_.t(); + + } + + proposal_.update_proposal_sd(e); +} + +void GGMModel::cholesky_update_after_edge(double omega_ij_old, double omega_jj_old, size_t i, size_t j) +{ + + v2_[0] = omega_ij_old - omega_prop_(i, j); + v2_[1] = (omega_jj_old - omega_prop_(j, j)) / 2; + + vf1_[i] = v1_[0]; + vf1_[j] = v1_[1]; + vf2_[i] = v2_[0]; + vf2_[j] = v2_[1]; + + // we now have + // aOmega_prop - (aOmega + vf1 %*% t(vf2) + vf2 %*% t(vf1)) + + u1_ = (vf1_ + vf2_) / sqrt(2); + u2_ = (vf1_ - vf2_) / sqrt(2); + + // we now have + // omega_prop_ - (aOmega + u1 %*% t(u1) - u2 %*% t(u2)) + // and also + // aOmega_prop - (aOmega + cbind(vf1, vf2) %*% matrix(c(0, 1, 1, 0), 2, 2) %*% t(cbind(vf1, vf2))) + + // update phi (2x O(p^2)) + cholesky_update(phi_, u1_); + cholesky_downdate(phi_, u2_); + + // update inverse (2x O(p^2)) + arma::inv(inv_phi_, arma::trimatu(phi_)); + inv_omega_ = inv_phi_ * inv_phi_.t(); + + // reset for next iteration + vf1_[i] = 0.0; + vf1_[j] = 0.0; + vf2_[i] = 0.0; + vf2_[j] = 0.0; + +} + +void GGMModel::update_diagonal_parameter(size_t i) { + // Implementation of diagonal parameter update + // 1-3) from before + double logdet_omega = get_log_det(phi_); + double logdet_omega_sub_ii = logdet_omega + std::log(inv_omega_(i, i)); + + size_t e = i * (i + 1) / 2 + i; // parameter index in vectorized form + double proposal_sd = proposal_.get_proposal_sd(e); + + double theta_curr = (logdet_omega - logdet_omega_sub_ii) / 2; + double theta_prop = rnorm(rng_, theta_curr, proposal_sd); + + //4) Replace and rebuild omega + omega_prop_ = omega_; + omega_prop_(i, i) = omega_(i, i) - std::exp(theta_curr) * std::exp(theta_curr) + std::exp(theta_prop) * std::exp(theta_prop); + + // Rcpp::Rcout << "i: " << i << + // ", current theta: " << theta_curr << + // ", proposed theta: " << theta_prop << + // ", proposal_sd: " << proposal_sd << std::endl; + // omega_prop_.print(Rcpp::Rcout, "Proposed omega:"); + + // 5) Acceptance ratio + // double ln_alpha = log_density(omega_prop_) - log_density(); + double ln_alpha = log_density_impl_diag(i); + // { + // double ln_alpha_ref = log_density(omega_prop_) - log_density(); + // if (std::abs(ln_alpha - ln_alpha_ref) > 1e-6) { + // Rcpp::Rcout << "Warning: log density implementations do not match for diag (" << i << ", " << i << ")" << std::endl; + // // omega_.print(Rcpp::Rcout, "Current omega:"); + // // omega_prop_.print(Rcpp::Rcout, "Proposed omega:"); + // Rcpp::Rcout << "ln_alpha: " << ln_alpha << ", ln_alpha_ref: " << ln_alpha_ref << std::endl; + // Rcpp::Rcout << "1e4 * diff: " << 10000 * (ln_alpha - ln_alpha_ref) << std::endl; + // } + // } + + ln_alpha += R::dgamma(exp(theta_prop), 1.0, 1.0, true); + ln_alpha -= R::dgamma(exp(theta_curr), 1.0, 1.0, true); + ln_alpha += theta_prop - theta_curr; // Jacobian adjustment ? + + if (std::log(runif(rng_)) < ln_alpha) { + + proposal_.increment_accepts(e); + + double omega_ii = omega_(i, i); + omega_(i, i) = omega_prop_(i, i); + + cholesky_update_after_diag(omega_ii, i); + + // arma::vec u(p_, arma::fill::zeros); + // double delta = omega_ii - omega_prop_(i, i); + // bool s = delta > 0; + // u(i) = std::sqrt(std::abs(delta)); + + + // if (s) + // cholesky_downdate(phi_, u); + // else + // cholesky_update(phi_, u); + + // // update inverse (2x O(p^2)) + // arma::inv(inv_phi_, arma::trimatu(phi_)); + // inv_omega_ = inv_phi_ * inv_phi_.t(); + + + } + + proposal_.update_proposal_sd(e); +} + +void GGMModel::cholesky_update_after_diag(double omega_ii_old, size_t i) +{ + + double delta = omega_ii_old - omega_prop_(i, i); + + bool s = delta > 0; + vf1_(i) = std::sqrt(std::abs(delta)); + + if (s) + cholesky_downdate(phi_, vf1_); + else + cholesky_update(phi_, vf1_); + + // update inverse (2x O(p^2)) + arma::inv(inv_phi_, arma::trimatu(phi_)); + inv_omega_ = inv_phi_ * inv_phi_.t(); + + // reset for next iteration + vf1_(i) = 0.0; +} + + +void GGMModel::update_edge_indicator_parameter_pair(size_t i, size_t j) { + + size_t e = i * (i + 1) / 2 + j; // parameter index in vectorized form + double proposal_sd = proposal_.get_proposal_sd(e); + + if (edge_indicators_(i, j) == 1) { + // Propose to turn OFF the edge + omega_prop_ = omega_; + omega_prop_(i, j) = 0.0; + omega_prop_(j, i) = 0.0; + + // Update diagonal using R function with omega_ij = 0 + get_constants(i, j); + omega_prop_(j, j) = R(0.0); + + // double ln_alpha = log_density(omega_prop_) - log_density(); + double ln_alpha = log_density_impl_edge(i, j); + // { + // double ln_alpha_ref = log_density(omega_prop_) - log_density(); + // if (std::abs(ln_alpha - ln_alpha_ref) > 1e-6) { + // Rcpp::Rcout << "Warning: log density implementations do not match for edge indicator (" << i << ", " << j << ")" << std::endl; + // omega_.print(Rcpp::Rcout, "Current omega:"); + // omega_prop_.print(Rcpp::Rcout, "Proposed omega:"); + // Rcpp::Rcout << "ln_alpha: " << ln_alpha << ", ln_alpha_ref: " << ln_alpha_ref << std::endl; + // } + // } + + + ln_alpha += std::log(1.0 - prior_inclusion_prob_(i, j)) - std::log(prior_inclusion_prob_(i, j)); + + ln_alpha += R::dnorm(omega_(i, j) / constants_[4], 0.0, proposal_sd, true) - std::log(constants_[4]); + ln_alpha -= R::dcauchy(omega_(i, j), 0.0, 2.5, true); + + if (std::log(runif(rng_)) < ln_alpha) { + + // Store old values for Cholesky update + double omega_ij_old = omega_(i, j); + double omega_jj_old = omega_(j, j); + + // Update omega + omega_(i, j) = 0.0; + omega_(j, i) = 0.0; + omega_(j, j) = omega_prop_(j, j); + + // Update edge indicator + edge_indicators_(i, j) = 0; + edge_indicators_(j, i) = 0; + + cholesky_update_after_edge(omega_ij_old, omega_jj_old, i, j); + // // Cholesky update vectors + // arma::vec v1 = {0, -1}; + // arma::vec v2 = {omega_ij_old - 0.0, (omega_jj_old - omega_(j, j)) / 2}; + + // arma::vec vf1 = arma::zeros(p_); + // arma::vec vf2 = arma::zeros(p_); + // vf1[i] = v1[0]; + // vf1[j] = v1[1]; + // vf2[i] = v2[0]; + // vf2[j] = v2[1]; + + // arma::vec u1 = (vf1 + vf2) / sqrt(2); + // arma::vec u2 = (vf1 - vf2) / sqrt(2); + + // // Update Cholesky factor + // cholesky_update(phi_, u1); + // cholesky_downdate(phi_, u2); + + // // Update inverse + // arma::inv(inv_phi_, arma::trimatu(phi_)); + // inv_omega_ = inv_phi_ * inv_phi_.t(); + } + + } else { + // Propose to turn ON the edge + double epsilon = rnorm(rng_, 0.0, proposal_sd); + + // Get constants for current state (with edge OFF) + get_constants(i, j); + double omega_prop_ij = constants_[4] * epsilon; + double omega_prop_jj = R(omega_prop_ij); + + omega_prop_ = omega_; + omega_prop_(i, j) = omega_prop_ij; + omega_prop_(j, i) = omega_prop_ij; + omega_prop_(j, j) = omega_prop_jj; + + // double ln_alpha = log_density(omega_prop_) - log_density(); + double ln_alpha = log_density_impl_edge(i, j); + // { + // double ln_alpha_ref = log_density(omega_prop_) - log_density(); + // if (std::abs(ln_alpha - ln_alpha_ref) > 1e-6) { + // Rcpp::Rcout << "Warning: log density implementations do not match for edge indicator (" << i << ", " << j << ")" << std::endl; + // omega_.print(Rcpp::Rcout, "Current omega:"); + // omega_prop_.print(Rcpp::Rcout, "Proposed omega:"); + // Rcpp::Rcout << "ln_alpha: " << ln_alpha << ", ln_alpha_ref: " << ln_alpha_ref << std::endl; + // } + // } + ln_alpha += std::log(prior_inclusion_prob_(i, j)) - std::log(1.0 - prior_inclusion_prob_(i, j)); + + // Prior change: add slab (Cauchy prior) + ln_alpha += R::dcauchy(omega_prop_ij, 0.0, 2.5, true); + + // Proposal term: proposed edge value given it was generated from truncated normal + ln_alpha -= R::dnorm(omega_prop_ij / constants_[4], 0.0, proposal_sd, true) - std::log(constants_[4]); + + // TODO: this can be factored out? + if (std::log(runif(rng_)) < ln_alpha) { + // Accept: turn ON the edge + proposal_.increment_accepts(e); + + // Store old values for Cholesky update + double omega_ij_old = omega_(i, j); + double omega_jj_old = omega_(j, j); + + // Update omega + omega_(i, j) = omega_prop_ij; + omega_(j, i) = omega_prop_ij; + omega_(j, j) = omega_prop_jj; + + // Update edge indicator + edge_indicators_(i, j) = 1; + edge_indicators_(j, i) = 1; + + cholesky_update_after_edge(omega_ij_old, omega_jj_old, i, j); + // // Cholesky update vectors + // arma::vec v1 = {0, -1}; + // arma::vec v2 = {omega_ij_old - omega_(i, j), (omega_jj_old - omega_(j, j)) / 2}; + + // arma::vec vf1 = arma::zeros(p_); + // arma::vec vf2 = arma::zeros(p_); + // vf1[i] = v1[0]; + // vf1[j] = v1[1]; + // vf2[i] = v2[0]; + // vf2[j] = v2[1]; + + // arma::vec u1 = (vf1 + vf2) / sqrt(2); + // arma::vec u2 = (vf1 - vf2) / sqrt(2); + + // // Update Cholesky factor + // cholesky_update(phi_, u1); + // cholesky_downdate(phi_, u2); + + // // Update inverse + // arma::inv(inv_phi_, arma::trimatu(phi_)); + // inv_omega_ = inv_phi_ * inv_phi_.t(); + } + } +} + +void GGMModel::do_one_mh_step() { + + // Update off-diagonals (upper triangle) + for (size_t i = 0; i < p_ - 1; ++i) { + for (size_t j = i + 1; j < p_; ++j) { + // Rcpp::Rcout << "Updating edge parameter (" << i << ", " << j << ")" << std::endl; + update_edge_parameter(i, j); + // if (!arma:: approx_equal(omega_ * inv_omega_, arma::eye(p_, p_), "absdiff", 1e-6)) { + // Rcpp::Rcout << "Warning: Omega * Inv(Omega) not equal to identity after updating edge (" << i << ", " << j << ")" << std::endl; + // (omega_ * inv_omega_).print(Rcpp::Rcout, "Omega * Inv(Omega):"); + // (phi_ * inv_phi_).print(Rcpp::Rcout, "Phi * Inv(Phi):"); + // } + // if (!arma:: approx_equal(omega_, phi_.t() * phi_, "absdiff", 1e-6)) { + // Rcpp::Rcout << "Warning: Omega not equal to Phi.t() * Phi after updating edge (" << i << ", " << j << ")" << std::endl; + // omega_.print(Rcpp::Rcout, "Omega:"); + // phi_.print(Rcpp::Rcout, "Phi:"); + // } + // if (!arma:: approx_equal(phi_ * inv_phi_, arma::eye(p_, p_), "absdiff", 1e-6)) { + // Rcpp::Rcout << "Warning: Phi * Inv(Phi) not equal to identity after updating edge (" << i << ", " << j << ")" << std::endl; + // (omega_ * inv_omega_).print(Rcpp::Rcout, "Omega * Inv(Omega):"); + // (phi_ * inv_phi_).print(Rcpp::Rcout, "Phi * Inv(Phi):"); + // } + // if (!arma:: approx_equal(inv_omega_, inv_phi_ * inv_phi_.t(), "absdiff", 1e-6)) { + // Rcpp::Rcout << "Warning: Inv(Omega) not equal to Inv(Phi) * Inv(Phi).t() after updating edge (" << i << ", " << j << ")" << std::endl; + // omega_.print(Rcpp::Rcout, "Omega:"); + // phi_.print(Rcpp::Rcout, "Phi:"); + // inv_omega_.print(Rcpp::Rcout, "Inv(Omega):"); + // inv_phi_.print(Rcpp::Rcout, "Inv(Phi):"); + // } + } + } + + // Update diagonals + for (size_t i = 0; i < p_; ++i) { + // Rcpp::Rcout << "Updating diagonal parameter " << i << std::endl; + update_diagonal_parameter(i); + + // if (!arma:: approx_equal(omega_ * inv_omega_, arma::eye(p_, p_), "absdiff", 1e-6)) { + // Rcpp::Rcout << "Warning: Omega * Inv(Omega) not equal to identity after updating diagonal " << i << std::endl; + // (omega_ * inv_omega_).print(Rcpp::Rcout, "Omega * Inv(Omega):"); + // (phi_ * inv_phi_).print(Rcpp::Rcout, "Phi * Inv(Phi):"); + // } + // if (!arma:: approx_equal(omega_, phi_.t() * phi_, "absdiff", 1e-6)) { + // Rcpp::Rcout << "Warning: Omega not equal to Phi.t() * Phi after updating diagonal " << i << std::endl; + // omega_.print(Rcpp::Rcout, "Omega:"); + // phi_.print(Rcpp::Rcout, "Phi:"); + // } + // if (!arma:: approx_equal(phi_ * inv_phi_, arma::eye(p_, p_), "absdiff", 1e-6)) { + // Rcpp::Rcout << "Warning: Phi * Inv(Phi) not equal to identity after updating diagonal " << i << std::endl; + // (omega_ * inv_omega_).print(Rcpp::Rcout, "Omega * Inv(Omega):"); + // (phi_ * inv_phi_).print(Rcpp::Rcout, "Phi * Inv(Phi):"); + // } + // if (!arma:: approx_equal(inv_omega_, inv_phi_ * inv_phi_.t(), "absdiff", 1e-6)) { + // Rcpp::Rcout << "Warning: Inv(Omega) not equal to Inv(Phi) * Inv(Phi).t() after updating diagonal " << i << std::endl; + // omega_.print(Rcpp::Rcout, "Omega:"); + // phi_.print(Rcpp::Rcout, "Phi:"); + // inv_omega_.print(Rcpp::Rcout, "Inv(Omega):"); + // inv_phi_.print(Rcpp::Rcout, "Inv(Phi):"); + // } + } + + if (edge_selection_) { + for (size_t i = 0; i < p_ - 1; ++i) { + for (size_t j = i + 1; j < p_; ++j) { + // Rcpp::Rcout << "Between model move for edge (" << i << ", " << j << ")" << std::endl; + update_edge_indicator_parameter_pair(i, j); + // if (!arma:: approx_equal(omega_ * inv_omega_, arma::eye(p_, p_), "absdiff", 1e-6)) { + // Rcpp::Rcout << "Warning: Omega * Inv(Omega) not equal to identity after updating edge (" << i << ", " << j << ")" << std::endl; + // (omega_ * inv_omega_).print(Rcpp::Rcout, "Omega * Inv(Omega):"); + // (phi_ * inv_phi_).print(Rcpp::Rcout, "Phi * Inv(Phi):"); + // } + // if (!arma:: approx_equal(omega_, phi_.t() * phi_, "absdiff", 1e-6)) { + // Rcpp::Rcout << "Warning: Omega not equal to Phi.t() * Phi after updating edge (" << i << ", " << j << ")" << std::endl; + // omega_.print(Rcpp::Rcout, "Omega:"); + // phi_.print(Rcpp::Rcout, "Phi:"); + // } + // if (!arma:: approx_equal(phi_ * inv_phi_, arma::eye(p_, p_), "absdiff", 1e-6)) { + // Rcpp::Rcout << "Warning: Phi * Inv(Phi) not equal to identity after updating edge (" << i << ", " << j << ")" << std::endl; + // (omega_ * inv_omega_).print(Rcpp::Rcout, "Omega * Inv(Omega):"); + // (phi_ * inv_phi_).print(Rcpp::Rcout, "Phi * Inv(Phi):"); + // } + // if (!arma:: approx_equal(inv_omega_, inv_phi_ * inv_phi_.t(), "absdiff", 1e-6)) { + // Rcpp::Rcout << "Warning: Inv(Omega) not equal to Inv(Phi) * Inv(Phi).t() after updating edge (" << i << ", " << j << ")" << std::endl; + // omega_.print(Rcpp::Rcout, "Omega:"); + // phi_.print(Rcpp::Rcout, "Phi:"); + // inv_omega_.print(Rcpp::Rcout, "Inv(Omega):"); + // inv_phi_.print(Rcpp::Rcout, "Inv(Phi):"); + // } + } + } + } + + // could also be called in the main MCMC loop + proposal_.increment_iteration(); +} + + +GGMModel createGGMFromR( + const Rcpp::List& inputFromR, + const arma::mat& prior_inclusion_prob, + const arma::imat& initial_edge_indicators, + const bool edge_selection +) { + + if (inputFromR.containsElementNamed("n") && inputFromR.containsElementNamed("suf_stat")) { + int n = Rcpp::as(inputFromR["n"]); + arma::mat suf_stat = Rcpp::as(inputFromR["suf_stat"]); + return GGMModel( + n, + suf_stat, + prior_inclusion_prob, + initial_edge_indicators, + edge_selection + ); + } else if (inputFromR.containsElementNamed("X")) { + arma::mat X = Rcpp::as(inputFromR["X"]); + return GGMModel( + X, + prior_inclusion_prob, + initial_edge_indicators, + edge_selection + ); + } else { + throw std::invalid_argument("Input list must contain either 'X' or both 'n' and 'suf_stat'."); + } + +} diff --git a/src/models/ggm/ggm_model.h b/src/models/ggm/ggm_model.h new file mode 100644 index 00000000..7ade3be3 --- /dev/null +++ b/src/models/ggm/ggm_model.h @@ -0,0 +1,196 @@ +#pragma once + +#include +#include "models/base_model.h" +#include "models/adaptiveMetropolis.h" +#include "rng/rng_utils.h" + + +class GGMModel : public BaseModel { +public: + + GGMModel( + const arma::mat& X, + const arma::mat& prior_inclusion_prob, + const arma::imat& initial_edge_indicators, + const bool edge_selection = true + ) : n_(X.n_rows), + p_(X.n_cols), + dim_((p_ * (p_ + 1)) / 2), + suf_stat_(X.t() * X), + prior_inclusion_prob_(prior_inclusion_prob), + edge_selection_(edge_selection), + proposal_(AdaptiveProposal(dim_, 500)), + omega_(arma::eye(p_, p_)), + phi_(arma::eye(p_, p_)), + inv_phi_(arma::eye(p_, p_)), + inv_omega_(arma::eye(p_, p_)), + edge_indicators_(initial_edge_indicators), + vectorized_parameters_(dim_), + vectorized_indicator_parameters_(edge_selection_ ? dim_ : 0), + omega_prop_(arma::mat(p_, p_, arma::fill::none)), + constants_(6) + {} + + GGMModel( + const int n, + const arma::mat& suf_stat, + const arma::mat& prior_inclusion_prob, + const arma::imat& initial_edge_indicators, + const bool edge_selection = true + ) : n_(n), + p_(suf_stat.n_cols), + dim_((p_ * (p_ + 1)) / 2), + suf_stat_(suf_stat), + prior_inclusion_prob_(prior_inclusion_prob), + edge_selection_(edge_selection), + proposal_(AdaptiveProposal(dim_, 500)), + omega_(arma::eye(p_, p_)), + phi_(arma::eye(p_, p_)), + inv_phi_(arma::eye(p_, p_)), + inv_omega_(arma::eye(p_, p_)), + edge_indicators_(initial_edge_indicators), + vectorized_parameters_(dim_), + vectorized_indicator_parameters_(edge_selection_ ? dim_ : 0), + omega_prop_(arma::mat(p_, p_, arma::fill::none)), + constants_(6) + {} + + GGMModel(const GGMModel& other) + : BaseModel(other), + dim_(other.dim_), + suf_stat_(other.suf_stat_), + n_(other.n_), + p_(other.p_), + prior_inclusion_prob_(other.prior_inclusion_prob_), + edge_selection_(other.edge_selection_), + omega_(other.omega_), + phi_(other.phi_), + inv_phi_(other.inv_phi_), + inv_omega_(other.inv_omega_), + edge_indicators_(other.edge_indicators_), + vectorized_parameters_(other.vectorized_parameters_), + vectorized_indicator_parameters_(other.vectorized_indicator_parameters_), + proposal_(other.proposal_), + rng_(other.rng_), + omega_prop_(other.omega_prop_), + constants_(other.constants_) + {} + + // // rng_ = SafeRNG(123); + + // } + + void set_adaptive_proposal(AdaptiveProposal proposal) { + proposal_ = proposal; + } + + bool has_gradient() const { return false; } + bool has_adaptive_mh() const override { return true; } + + double logp(const arma::vec& parameters) override { + // Implement log probability computation + return 0.0; + } + + // TODO: this can be done more efficiently, no need for the Cholesky! + double log_density(const arma::mat& omega) const { return log_density_impl(omega, arma::chol(omega)); }; + double log_density() const { return log_density_impl(omega_, phi_); } + + void do_one_mh_step() override; + + size_t parameter_dimension() const override { + return dim_; + } + + void set_seed(int seed) override { + rng_ = SafeRNG(seed); + } + + arma::vec get_vectorized_parameters() override { + // upper triangle of omega_ + size_t e = 0; + for (size_t j = 0; j < p_; ++j) { + for (size_t i = 0; i <= j; ++i) { + vectorized_parameters_(e) = omega_(i, j); + ++e; + } + } + return vectorized_parameters_; + } + + arma::ivec get_vectorized_indicator_parameters() override { + // upper triangle of omega_ + size_t e = 0; + for (size_t j = 0; j < p_; ++j) { + for (size_t i = 0; i <= j; ++i) { + vectorized_indicator_parameters_(e) = edge_indicators_(i, j); + ++e; + } + } + return vectorized_indicator_parameters_; + } + + std::unique_ptr clone() const override { + return std::make_unique(*this); // uses copy constructor + } + +private: + // data + size_t n_; + size_t p_; + size_t dim_; + arma::mat suf_stat_; + arma::mat prior_inclusion_prob_; + bool edge_selection_; + + // parameters + arma::mat omega_, phi_, inv_phi_, inv_omega_; + arma::imat edge_indicators_; + arma::vec vectorized_parameters_; + arma::ivec vectorized_indicator_parameters_; + + + AdaptiveProposal proposal_; + SafeRNG rng_; + + // internal helper variables + arma::mat omega_prop_; + arma::vec constants_; // Phi_q1q, Phi_q1q1, c[1], c[2], c[3], c[4] + + arma::vec v1_ = {0, -1}; + arma::vec v2_ = {0, 0}; + arma::vec vf1_ = arma::zeros(p_); + arma::vec vf2_ = arma::zeros(p_); + arma::vec u1_ = arma::zeros(p_); + arma::vec u2_ = arma::zeros(p_); + + // Parameter group updates with optimized likelihood evaluations + void update_edge_parameter(size_t i, size_t j); + void update_diagonal_parameter(size_t i); + void update_edge_indicator_parameter_pair(size_t i, size_t j); + + // Helper methods + void get_constants(size_t i, size_t j); + double compute_inv_submatrix_i(const arma::mat& A, const size_t i, const size_t ii, const size_t jj) const; + double R(const double x) const; + + double log_density_impl(const arma::mat& omega, const arma::mat& phi) const; + double log_density_impl_edge(size_t i, size_t j) const; + double log_density_impl_diag(size_t j) const; + double get_log_det(arma::mat triangular_A) const; + void cholesky_update_after_edge(double omega_ij_old, double omega_jj_old, size_t i, size_t j); + void cholesky_update_after_diag(double omega_ii_old, size_t i); + // double find_reasonable_step_size_edge(const arma::mat& omega, size_t i, size_t j); + // double find_reasonable_step_size_diag(const arma::mat& omega, size_t i); + // double edge_log_ratio(const arma::mat& omega, size_t i, size_t j, double proposal); + // double diag_log_ratio(const arma::mat& omega, size_t i, double proposal); +}; + + +GGMModel createGGMFromR( + const Rcpp::List& inputFromR, + const arma::mat& prior_inclusion_prob, + const arma::imat& initial_edge_indicators, + const bool edge_selection = true +); diff --git a/src/print_mutex.h b/src/print_mutex.h deleted file mode 100644 index bde27291..00000000 --- a/src/print_mutex.h +++ /dev/null @@ -1,11 +0,0 @@ -#ifndef PRINT_MUTEX_H -#define PRINT_MUTEX_H - -#include - -inline tbb::mutex& get_print_mutex() { - static tbb::mutex m; - return m; -} - -#endif // PRINT_MUTEX_H diff --git a/src/sbm_edge_prior.cpp b/src/priors/sbm_edge_prior.cpp similarity index 91% rename from src/sbm_edge_prior.cpp rename to src/priors/sbm_edge_prior.cpp index 182201bc..b24d2f57 100644 --- a/src/sbm_edge_prior.cpp +++ b/src/priors/sbm_edge_prior.cpp @@ -1,6 +1,6 @@ #include -#include "rng_utils.h" -#include "explog_switch.h" +#include "rng/rng_utils.h" +#include "math/explog_switch.h" // ----------------------------------------------------------------------------| // The c++ code below is based on the R code accompanying the paper: @@ -58,30 +58,6 @@ arma::mat add_row_col_block_prob_matrix(arma::mat X, } -// ----------------------------------------------------------------------------| -// Compute partition coefficient for the MFM - SBM -// ----------------------------------------------------------------------------| -// [[Rcpp::export]] -arma::vec compute_Vn_mfm_sbm(arma::uword no_variables, - double dirichlet_alpha, - arma::uword t_max, - double lambda) { - arma::vec log_Vn(t_max); - double r; - - for(arma::uword t = 0; t < t_max; t++) { - r = -INFINITY; // initialize log-coefficient at -Inf - for(arma::uword k = t; k <= 500; k++){ - arma::vec b_linspace_1 = arma::linspace(k-t+1,k+1,t+1); // numerator = b*(b-1)*...*(b-|C|+1) - arma::vec b_linspace_2 = arma::linspace((k+1)*dirichlet_alpha,(k+1)*dirichlet_alpha+no_variables-1, no_variables); // denominator b*e*(b*e+1)*...*(b*e+p-1) - double b = arma::accu(arma::log(b_linspace_1))-arma::accu(arma::log(b_linspace_2)) + R::dpois((k+1)-1, lambda, true); // sum(log(numerator)) - sum(log(denominator)) + log(P=(k+1|lambda)) - double m = std::max(b,r); // scaling factor for log-sum-exp formula - r = MY_LOG(MY_EXP(r-m) + MY_EXP(b-m)) + m; // update r using log-sum-exp formula to ensure numerical stability and avoid underflow - } - log_Vn(t) = r; - } - return log_Vn; -} // ----------------------------------------------------------------------------| // Compute log-likelihood for the MFM - SBM diff --git a/src/sbm_edge_prior.h b/src/priors/sbm_edge_prior.h similarity index 80% rename from src/sbm_edge_prior.h rename to src/priors/sbm_edge_prior.h index a782347b..b4a1821f 100644 --- a/src/sbm_edge_prior.h +++ b/src/priors/sbm_edge_prior.h @@ -4,13 +4,6 @@ struct SafeRNG; -// ----------------------------------------------------------------------------| -// Compute partition coefficient for the MFM - SBM -// ----------------------------------------------------------------------------| -arma::vec compute_Vn_mfm_sbm(arma::uword no_variables, - double dirichlet_alpha, - arma::uword t_max, - double lambda); // ----------------------------------------------------------------------------| // Sample the block allocations for the MFM - SBM diff --git a/src/rng_utils.h b/src/rng/rng_utils.h similarity index 100% rename from src/rng_utils.h rename to src/rng/rng_utils.h diff --git a/src/sample_ggm.cpp b/src/sample_ggm.cpp new file mode 100644 index 00000000..a332a545 --- /dev/null +++ b/src/sample_ggm.cpp @@ -0,0 +1,188 @@ +#include +#include +#include +#include +#include + +#include "models/ggm/ggm_model.h" +#include "utils/progress_manager.h" +#include "chainResultNew.h" + +void run_mcmc_sampler_single_thread( + ChainResultNew& chain_result, + BaseModel& model, + const int no_iter, + const int no_warmup, + const int chain_id, + ProgressManager& pm +) { + + chain_result.chain_id = chain_id + 1; + size_t i = 0; + for (size_t iter = 0; iter < no_iter + no_warmup; ++iter) { + + model.do_one_mh_step(); + + if (iter >= no_warmup) { + + chain_result.store_sample(i, model.get_vectorized_parameters()); + ++i; + } + + pm.update(chain_id); + if (pm.shouldExit()) { + chain_result.userInterrupt = true; + break; + } + } +} + +struct GGMChainRunner : public RcppParallel::Worker { + std::vector& results_; + std::vector>& models_; + size_t no_iter_; + size_t no_warmup_; + int seed_; + ProgressManager& pm_; + + GGMChainRunner( + std::vector& results, + std::vector>& models, + const size_t no_iter, + const size_t no_warmup, + const int seed, + ProgressManager& pm + ) : + results_(results), + models_(models), + no_iter_(no_iter), + no_warmup_(no_warmup), + seed_(seed), + pm_(pm) + {} + + void operator()(std::size_t begin, std::size_t end) { + for (std::size_t i = begin; i < end; ++i) { + + ChainResultNew& chain_result = results_[i]; + BaseModel& model = *models_[i]; + model.set_seed(seed_ + i); + try { + + run_mcmc_sampler_single_thread(chain_result, model, no_iter_, no_warmup_, i, pm_); + + } catch (std::exception& e) { + chain_result.error = true; + chain_result.error_msg = e.what(); + } catch (...) { + chain_result.error = true; + chain_result.error_msg = "Unknown error"; + } + } + } +}; + +void run_mcmc_sampler_threaded( + std::vector& results, + std::vector>& models, + const int no_iter, + const int no_warmup, + const int seed, + const int no_threads, + ProgressManager& pm +) { + + GGMChainRunner runner(results, models, no_iter, no_warmup, seed, pm); + tbb::global_control control(tbb::global_control::max_allowed_parallelism, no_threads); + RcppParallel::parallelFor(0, results.size(), runner); +} + + +std::vector run_mcmc_sampler( + BaseModel& model, + const int no_iter, + const int no_warmup, + const int no_chains, + const int seed, + const int no_threads, + ProgressManager& pm +) { + + Rcpp::Rcout << "Allocating results objects..." << std::endl; + std::vector results(no_chains); + for (size_t c = 0; c < no_chains; ++c) { + results[c].reserve(model.parameter_dimension(), no_iter); + } + + if (no_threads > 1) { + + Rcpp::Rcout << "Running multi-threaded MCMC sampling..." << std::endl; + std::vector> models; + models.reserve(no_chains); + for (size_t c = 0; c < no_chains; ++c) { + models.push_back(model.clone()); // deep copy via virtual clone + } + run_mcmc_sampler_threaded(results, models, no_iter, no_warmup, seed, no_threads, pm); + + } else { + + model.set_seed(seed); + Rcpp::Rcout << "Running single-threaded MCMC sampling..." << std::endl; + // TODO: this is actually not correct, each chain should have its own model object + // now chain 2 continues from chain 1 state + for (size_t c = 0; c < no_chains; ++c) { + run_mcmc_sampler_single_thread(results[c], model, no_iter, no_warmup, c, pm); + } + + } + return results; +} + +Rcpp::List convert_sampler_output_to_ggm_result(const std::vector& results) { + + Rcpp::List output(results.size()); + for (size_t i = 0; i < results.size(); ++i) { + + Rcpp::List chain_i; + chain_i["chain_id"] = results[i].chain_id; + if (results[i].error) { + chain_i["error"] = results[i].error_msg; + } else { + chain_i["samples"] = results[i].samples; + chain_i["userInterrupt"] = results[i].userInterrupt; + + } + output[i] = chain_i; + } + return output; +} + +// [[Rcpp::export]] +Rcpp::List sample_ggm( + const Rcpp::List& inputFromR, + const arma::mat& prior_inclusion_prob, + const arma::imat& initial_edge_indicators, + const int no_iter, + const int no_warmup, + const int no_chains, + const bool edge_selection, + const int seed, + const int no_threads, + const int progress_type +) { + + // should be done dynamically + // also adaptation method should be specified differently + // GGMModel model(X, prior_inclusion_prob, initial_edge_indicators, edge_selection); + GGMModel model = createGGMFromR(inputFromR, prior_inclusion_prob, initial_edge_indicators, edge_selection); + + ProgressManager pm(no_chains, no_iter, no_warmup, 50, progress_type); + + std::vector output = run_mcmc_sampler(model, no_iter, no_warmup, no_chains, seed, no_threads, pm); + + Rcpp::List ggm_result = convert_sampler_output_to_ggm_result(output); + + pm.finish(); + + return ggm_result; +} \ No newline at end of file diff --git a/src/sbm_edge_prior_interface.cpp b/src/sbm_edge_prior_interface.cpp new file mode 100644 index 00000000..8e040b90 --- /dev/null +++ b/src/sbm_edge_prior_interface.cpp @@ -0,0 +1,34 @@ +#include +#include "math/explog_switch.h" + +// ----------------------------------------------------------------------------| +// The c++ code below is based on the R code accompanying the paper: +// Geng, J., Bhattacharya, A., & Pati, D. (2019). Probabilistic Community +// Detection With Unknown Number of Communities, Journal of the American +// Statistical Association, 114:526, 893-905, DOI:10.1080/01621459.2018.1458618 +// ----------------------------------------------------------------------------| + +// ----------------------------------------------------------------------------| +// Compute partition coefficient for the MFM - SBM +// ----------------------------------------------------------------------------| +// [[Rcpp::export]] +arma::vec compute_Vn_mfm_sbm(arma::uword no_variables, + double dirichlet_alpha, + arma::uword t_max, + double lambda) { + arma::vec log_Vn(t_max); + double r; + + for(arma::uword t = 0; t < t_max; t++) { + r = -INFINITY; // initialize log-coefficient at -Inf + for(arma::uword k = t; k <= 500; k++){ + arma::vec b_linspace_1 = arma::linspace(k-t+1,k+1,t+1); // numerator = b*(b-1)*...*(b-|C|+1) + arma::vec b_linspace_2 = arma::linspace((k+1)*dirichlet_alpha,(k+1)*dirichlet_alpha+no_variables-1, no_variables); // denominator b*e*(b*e+1)*...*(b*e+p-1) + double b = arma::accu(ARMA_MY_LOG(b_linspace_1))-arma::accu(ARMA_MY_LOG(b_linspace_2)) + R::dpois((k+1)-1, lambda, true); // sum(log(numerator)) - sum(log(denominator)) + log(P=(k+1|lambda)) + double m = std::max(b,r); // scaling factor for log-sum-exp formula + r = MY_LOG(MY_EXP(r-m) + MY_EXP(b-m)) + m; // update r using log-sum-exp formula to ensure numerical stability and avoid underflow + } + log_Vn(t) = r; + } + return log_Vn; +} \ No newline at end of file diff --git a/src/sbm_edge_prior_interface.h b/src/sbm_edge_prior_interface.h new file mode 100644 index 00000000..ac746a90 --- /dev/null +++ b/src/sbm_edge_prior_interface.h @@ -0,0 +1,9 @@ +#include +#include "math/explog_switch.h" + + + +arma::vec compute_Vn_mfm_sbm(arma::uword no_variables, + double dirichlet_alpha, + arma::uword t_max, + double lambda); \ No newline at end of file diff --git a/src/common_helpers.h b/src/utils/common_helpers.h similarity index 100% rename from src/common_helpers.h rename to src/utils/common_helpers.h diff --git a/src/utils/print_mutex.h b/src/utils/print_mutex.h new file mode 100644 index 00000000..12f9f21f --- /dev/null +++ b/src/utils/print_mutex.h @@ -0,0 +1,21 @@ +#ifndef PRINT_MUTEX_H +#define PRINT_MUTEX_H + +#include + +inline tbb::mutex& get_print_mutex() { + static tbb::mutex m; + return m; +} + +#endif // PRINT_MUTEX_H + +// Add this header to the parallel code you wish to print from +// + the below code to print in parallel code: +// +// { +// tbb::mutex::scoped_lock lock(get_print_mutex()); +// std::cout +// << "print " +// << std::endl; +// } \ No newline at end of file diff --git a/src/progress_manager.cpp b/src/utils/progress_manager.cpp similarity index 99% rename from src/progress_manager.cpp rename to src/utils/progress_manager.cpp index 5ad13fbd..83610a69 100644 --- a/src/progress_manager.cpp +++ b/src/utils/progress_manager.cpp @@ -1,4 +1,4 @@ -#include "progress_manager.h" +#include "utils/progress_manager.h" ProgressManager::ProgressManager(int nChains_, int nIter_, int nWarmup_, int printEvery_, int progress_type_, bool useUnicode_) : nChains(nChains_), nIter(nIter_ + nWarmup_), nWarmup(nWarmup_), printEvery(printEvery_), @@ -398,7 +398,7 @@ void ProgressManager::maybePadToLength(std::string& content) const { // }; -// // [[Rcpp::export]] +// // [[Rcpp::export]] // if uncommented, must move .cpp file to src/ for Rcpp to compile // void runMCMC_parallel(int nChains = 4, int nIter = 100, int nWarmup = 100, int progress_type = 2, bool useUnicode = false, // int delay = 20) { diff --git a/src/progress_manager.h b/src/utils/progress_manager.h similarity index 100% rename from src/progress_manager.h rename to src/utils/progress_manager.h diff --git a/src/utils/variable_helpers.h b/src/utils/variable_helpers.h new file mode 100644 index 00000000..7940d61c --- /dev/null +++ b/src/utils/variable_helpers.h @@ -0,0 +1,395 @@ +#include +#include "math/explog_switch.h" + + + +// ----------------------------------------------------------------------------- +// Compute a numerically stable sum of the form: +// +// denom = exp(-bound) + sum_{cat=0}^{K-1} exp(main_effect_param(cat) +// + (cat + 1) * residual_score - bound) +// +// but evaluated efficiently using precomputed exponentials: +// +// exp_r = exp(residual_score) +// exp_m = exp(main_effect_param) +// denom = exp(-bound) * ( 1 + sum_c exp_m[c] * exp_r^(c+1) ) +// +// If non-finite values arise (overflow, underflow, NaN), a safe fallback +// recomputes the naive version using direct exponentials. +// ---------------------------------------------------------------------------- +inline arma::vec compute_denom_ordinal(const arma::vec& residual, + const arma::vec& main_eff, + const arma::vec& bound) +{ + constexpr double EXP_BOUND = 709.0; + const int K = static_cast(main_eff.n_elem); + + // --- Binary shortcut (K == 1) --------------------------------------------- + if (K == 1) { + return ARMA_MY_EXP(-bound) + ARMA_MY_EXP(main_eff[0] + residual - bound); + } + + const arma::uword N = bound.n_elem; + arma::vec denom(N, arma::fill::none); + const arma::vec eM = ARMA_MY_EXP(main_eff); + + // Fast block: uses eB inside the loop (avoids intermediate overflow) + auto do_fast_block = [&](arma::uword i0, arma::uword i1) { + arma::vec r = residual.rows(i0, i1); + arma::vec b = bound.rows(i0, i1); + arma::vec eR = ARMA_MY_EXP(r); + arma::vec eB = ARMA_MY_EXP(-b); + arma::vec pow = eR; + + arma::vec d = eB; + for (int c = 0; c < K; ++c) { + d += eM[c] * pow % eB; + pow %= eR; + } + denom.rows(i0, i1) = d; + }; + + // Safe block: stabilized exponent; NO clamp here by design + auto do_safe_block = [&](arma::uword i0, arma::uword i1) { + arma::vec r = residual.rows(i0, i1); + arma::vec b = bound.rows(i0, i1); + + arma::vec d = ARMA_MY_EXP(-b); + for (int c = 0; c < K; ++c) { + arma::vec ex = main_eff[c] + (c + 1) * r - b; + d += ARMA_MY_EXP(ex); + } + denom.rows(i0, i1) = d; + }; + + // Single linear scan over contiguous runs + const double* bp = bound.memptr(); + arma::uword i = 0; + while (i < N) { + const bool fast = !(bp[i] < -EXP_BOUND || bp[i] > EXP_BOUND); + arma::uword j = i + 1; + while (j < N) { + const bool fast_j = !(bp[j] < -EXP_BOUND || bp[j] > EXP_BOUND); + if (fast_j != fast) break; + ++j; + } + if (fast) do_fast_block(i, j - 1); + else do_safe_block(i, j - 1); + i = j; + } + + return denom; +} + +// ----------------------------------------------------------------------------- +// Compute denom = Σ_c exp( θ(c) + c*r - b ), with +// θ(c) = lin_eff*(c-ref) + quad_eff*(c-ref)^2 +// b = max_c( θ(c) + c*r ) (vectorized) +// +// Two modes: +// +// FAST (preexp + power-chain): +// denom = Σ_c exp_theta[c] * exp(-b) * exp(r)^c +// Used only when all exponent terms are safe: +// |b| ≤ EXP_BOUND, +// underflow_bound ≥ -EXP_BOUND, +// num_cats*r - b ≤ EXP_BOUND. +// This guarantees the recursive pow-chain stays finite. +// +// SAFE (direct evaluation): +// denom = Σ_c exp(θ(c) + c*r - b) +// Used whenever any FAST-condition fails. Slower but always stable. +// +// FAST gives identical results when safe, otherwise SAFE is used. +// ----------------------------------------------------------------------------- +inline arma::vec compute_denom_blume_capel( + const arma::vec& residual, + const double lin_eff, + const double quad_eff, + const int ref, + const int num_cats, + arma::vec& b // update in place: per-person bound b[i] +) { + + constexpr double EXP_BOUND = 709.0; + const arma::uword N = residual.n_elem; + arma::vec denom(N); + + // ---- 1. Precompute theta_part[cat] and exp(theta_part) ---- + arma::vec cat = arma::regspace(0, num_cats); + arma::vec centered = cat - double(ref); + arma::vec theta = lin_eff * centered + quad_eff * arma::square(centered); + arma::vec exp_theta = ARMA_MY_EXP(theta); + + // ---- 2. Numerical bounds [b] ---- + b.set_size(N); + b.fill(theta[0]); + for (int c = 1; c <= num_cats; c++) + b = arma::max(b, theta[c] + double(c) * residual); + + // ---- 3. Bounds for the FAST power chain: c*r - b ---- + // For fixed i, c*r[i] - b[i] ranges between -b[i] and num_cats*r[i] - b[i]. + // We need max_c (c*r[i] - b[i]) <= EXP_BOUND to avoid overflow in pow. + arma::vec pow_bound = double(num_cats) * residual - b; + + // ---- 4. FAST BLOCK: Preexp + bounded power chain ---- + auto do_fast_block = [&](arma::uword i0, arma::uword i1) { + arma::vec r = residual.rows(i0, i1); + arma::vec bb = b.rows(i0, i1); + + arma::vec eR = ARMA_MY_EXP(r); // exp(r) + arma::vec pow = ARMA_MY_EXP(-bb); // start at cat=0 term: exp(0*r - b) + arma::vec d = exp_theta[0] * pow; + + for (int c = 1; c <= num_cats; c++) { + pow %= eR; // exp(c*r - b) + d += exp_theta[c] * pow; + } + denom.rows(i0, i1) = d; + }; + + // ---- 5. SAFE BLOCK: direct exp(theta[c] + c*r - b) ---- + auto do_safe_block = [&](arma::uword i0, arma::uword i1) { + arma::vec r = residual.rows(i0, i1); + arma::vec bb = b.rows(i0, i1); + + arma::vec d(bb.n_elem, arma::fill::zeros); + for (int c = 0; c <= num_cats; c++) { + arma::vec ex = theta[c] + double(c) * r - bb; + d += ARMA_MY_EXP(ex); + } + + + + denom.rows(i0, i1) = d; + }; + + // ---- 6. BLOCK SCAN: decide FAST vs SAFE per contiguous run ---- + const double* bp = b.memptr(); + const double* pp = pow_bound.memptr(); + + arma::uword i = 0; + while (i < N) { + const bool fast_i = (std::abs(bp[i]) <= EXP_BOUND) && (std::abs(pp[i]) <= EXP_BOUND); + + arma::uword j = i + 1; + while (j < N) { + const bool fast_j = (std::abs(bp[j]) <= EXP_BOUND) && (std::abs(pp[j]) <= EXP_BOUND); + if (fast_j != fast_i) break; + ++j; + } + + if (fast_i) do_fast_block(i, j - 1); + else do_safe_block(i, j - 1); + + i = j; + } + + return denom; +} + + + +/** + * Compute category probabilities in a numerically stable manner. + * + * Uses pre-exp or bounded formulations depending on the magnitude of `bound`. + * - If |bound| < 700: uses cheaper direct pre-exp computation + * - Else: clips bound at zero and applies stabilized scaling + * + * Empirical tests (see R/compare_prob_ratios.R) showed: + * - Clipping necessary for bound < -700 + * - Bounds improve stability when large + * + * Returns: + * probs: num_persons × (num_cats + 1) matrix of probabilities (row-normalized) + */ +inline arma::mat compute_probs_ordinal(const arma::vec& main_param, + const arma::vec& residual_score, + const arma::vec& bound, + int num_cats) +{ + constexpr double EXP_BOUND = 709.0; + const arma::uword N = bound.n_elem; + + if (num_cats == 1) { + arma::vec b = arma::clamp(bound, 0.0, arma::datum::inf); + arma::vec ex = main_param(0) + residual_score - b; + arma::vec t = ARMA_MY_EXP(ex); + arma::vec den = ARMA_MY_EXP(-b) + t; + arma::mat probs(N, 2, arma::fill::none); + probs.col(1) = t / den; + probs.col(0) = 1.0 - probs.col(1); + return probs; + } + + arma::mat probs(N, num_cats + 1, arma::fill::none); + const arma::vec eM = ARMA_MY_EXP(main_param); + + auto do_fast_block = [&](arma::uword i0, arma::uword i1) { + auto P = probs.rows(i0, i1).cols(1, num_cats); + arma::vec r = residual_score.rows(i0, i1); + arma::vec eR = ARMA_MY_EXP(r); + arma::vec pow = eR; + arma::vec den(P.n_rows, arma::fill::ones); + for (int c = 0; c < num_cats; c++) { + arma::vec term = eM[c] * pow; + P.col(c) = term; + den += term; + pow %= eR; + } + P.each_col() /= den; + }; + + auto do_safe_block = [&](arma::uword i0, arma::uword i1) { + auto P = probs.rows(i0, i1).cols(1, num_cats); + arma::vec r = residual_score.rows(i0, i1); + arma::vec b = arma::clamp(bound.rows(i0, i1), 0.0, arma::datum::inf); + arma::vec den = ARMA_MY_EXP(-b); + for (int c = 0; c < num_cats; c++) { + arma::vec ex = main_param(c) + (c + 1) * r - b; + arma::vec t = ARMA_MY_EXP(ex); + P.col(c) = t; + den += t; + } + P.each_col() /= den; + }; + + // Single linear scan; no std::abs + const double* bp = bound.memptr(); + arma::uword i = 0; + while (i < N) { + const bool fast = !(bp[i] < -EXP_BOUND || bp[i] > EXP_BOUND); + arma::uword j = i + 1; + while (j < N) { + const bool fast_j = !(bp[j] < -EXP_BOUND || bp[j] > EXP_BOUND); + if (fast_j != fast) break; + j++; + } + if (fast) do_fast_block(i, j - 1); + else do_safe_block(i, j - 1); + i = j; + } + + probs.col(0) = 1.0 - arma::sum(probs.cols(1, num_cats), 1); + return probs; +} + + + +// ----------------------------------------------------------------------------- +// Blume–Capel probabilities, numerically stable via FAST/SAFE split. +// +// Model: +// θ(c) = lin_eff * (c - ref) + quad_eff * (c - ref)^2, c = 0..num_cats +// exps_i(c) = θ(c) + c * r_i +// b_i = max_c exps_i(c) +// +// Probabilities: +// p_i(c) ∝ exp( exps_i(c) - b_i ) +// +// FAST (preexp + power-chain, same bounds as compute_denom_blume_capel): +// used when |b_i| ≤ EXP_BOUND and pow_bound_i = num_cats * r_i - b_i ≤ EXP_BOUND +// +// SAFE (direct): +// used otherwise: direct exp(θ(c) + c * r_i - b_i) +// +// Under these conditions, denom is finite and > 0, so no one-hot fallback. +// ----------------------------------------------------------------------------- +inline arma::mat compute_probs_blume_capel(const arma::vec& residual, + const double lin_eff, + const double quad_eff, + const int ref, + const int num_cats, + arma::vec& b) // updated in place +{ + constexpr double EXP_BOUND = 709.0; + + const arma::uword N = residual.n_elem; + arma::mat probs(N, num_cats + 1, arma::fill::none); + + // 1. Precompute θ(c) and exp(θ(c)) + arma::vec cat = arma::regspace(0, num_cats); + arma::vec centered = cat - double(ref); + arma::vec theta = lin_eff * centered + quad_eff * arma::square(centered); + arma::vec exp_theta = ARMA_MY_EXP(theta); + + // 2. Compute bounds b[i] = max_c (θ(c) + c * r_i) + b.set_size(N); + b.fill(theta[0]); + for (int c = 1; c <= num_cats; ++c) { + b = arma::max(b, theta[c] + double(c) * residual); + } + + // 3. Bound for the power chain: max_c (c * r_i - b_i) = num_cats * r_i - b_i + arma::vec pow_bound = double(num_cats) * residual - b; + + // FAST block: preexp + bounded power chain + auto do_fast_block = [&](arma::uword i0, arma::uword i1) { + auto P = probs.rows(i0, i1); + arma::vec r = residual.rows(i0, i1); + arma::vec bb = b.rows(i0, i1); + const arma::uword B = bb.n_elem; + + arma::vec eR = ARMA_MY_EXP(r); // exp(r_i) + arma::vec pow = ARMA_MY_EXP(-bb); // exp(0 * r_i - b_i) + arma::vec denom(B, arma::fill::zeros); + + // c = 0 + arma::vec col0 = exp_theta[0] * pow; + P.col(0) = col0; + denom += col0; + + // c = 1..num_cats + for (int c = 1; c <= num_cats; ++c) { + pow %= eR; // exp(c * r_i - b_i) + arma::vec col = exp_theta[c] * pow; + P.col(c) = col; + denom += col; + } + + P.each_col() /= denom; + }; + + // SAFE block: direct exp(θ(c) + c * r_i - b_i) + auto do_safe_block = [&](arma::uword i0, arma::uword i1) { + auto P = probs.rows(i0, i1); + arma::vec r = residual.rows(i0, i1); + arma::vec bb = b.rows(i0, i1); + const arma::uword B = bb.n_elem; + arma::vec denom(B, arma::fill::zeros); + + for (int c = 0; c <= num_cats; ++c) { + arma::vec ex = theta[c] + double(c) * r - bb; + arma::vec col = ARMA_MY_EXP(ex); + P.col(c) = col; + denom += col; + } + P.each_col() /= denom; + }; + + // 4. Single linear scan over contiguous FAST/SAFE runs (same as denom) + const double* bp = b.memptr(); + const double* pp = pow_bound.memptr(); + arma::uword i = 0; + while (i < N) { + const bool fast_i = + (std::abs(bp[i]) <= EXP_BOUND) && (std::abs(pp[i]) <= EXP_BOUND); + + arma::uword j = i + 1; + while (j < N) { + const bool fast_j = + (std::abs(bp[j]) <= EXP_BOUND) && (std::abs(pp[j]) <= EXP_BOUND); + if (fast_j != fast_i) break; + j++; + } + + if (fast_i) do_fast_block(i, j - 1); + else do_safe_block(i, j - 1); + + i = j; + } + + return probs; +} \ No newline at end of file diff --git a/test_ggm.R b/test_ggm.R new file mode 100644 index 00000000..5f1995ca --- /dev/null +++ b/test_ggm.R @@ -0,0 +1,113 @@ +library(bgms) + +# Dimension and true precision +p <- 10 + +adj <- matrix(0, nrow = p, ncol = p) +adj[lower.tri(adj)] <- rbinom(p * (p - 1) / 2, size = 1, prob = 0.3) +adj <- adj + t(adj) +# qgraph::qgraph(adj) +Omega <- BDgraph::rgwish(1, adj = adj, b = p + sample(0:p, 1), D = diag(p)) +Sigma <- solve(Omega) +zapsmall(Omega) + +# Data +n <- 1e3 +x <- mvtnorm::rmvnorm(n = n, mean = rep(0, p), sigma = Sigma) + + +# ---- Run MCMC with warmup and sampling ------------------------------------ + +# debugonce(mbgms:::bgm_gaussian) +sampling_results <- bgms:::sample_ggm( + X = x, + prior_inclusion_prob = matrix(.5, p, p), + initial_edge_indicators = adj, + no_iter = 500, + no_warmup = 500, + no_chains = 3, + edge_selection = FALSE, + no_threads = 1, + seed = 123, + progress_type = 1 +) + +true_values <- zapsmall(Omega[upper.tri(Omega, TRUE)]) +posterior_means <- rowMeans(sampling_results[[2]]$samples) +cbind(true_values, posterior_means) + +plot(true_values, posterior_means) +abline(0, 1) + +sampling_results2 <- bgms:::sample_ggm( + X = x, + prior_inclusion_prob = matrix(.5, p, p), + initial_edge_indicators = adj, + no_iter = 500, + no_warmup = 500, + no_chains = 3, + edge_selection = TRUE, + no_threads = 1, + seed = 123, + progress_type = 1 +) + +true_values <- zapsmall(Omega[upper.tri(Omega, TRUE)]) +posterior_means <- rowMeans(sampling_results2[[2]]$samples) + +plot(true_values, posterior_means) +abline(0, 1) + +plot(posterior_means, rowMeans(sampling_results2[[2]]$samples != 0)) + + +mmm <- matrix(c( + 1.6735, 0, 0, 0, 0, + 0, 1.0000, 0, 0, -3.4524, + 0, 0, 1.0000, 0, 0, + 0, 0, 0, 1.0000, 0, + 0, -3.4524, 0, 0, 9.6674 +), p, p) +mmm +chol(mmm) +base::isSymmetric(mmm) +eigen(mmm) + +profvis::profvis({ + sampling_results <- bgm_gaussian( + x = x, + n = n, + n_iter = 400, + n_warmup = 400, + n_phases = 10 + ) +}) + +# Extract results +aveOmega <- sampling_results$aveOmega +aveGamma <- sampling_results$aveGamma +aOmega <- sampling_results$aOmega +aGamma <- sampling_results$aGamma +prob <- sampling_results$prob +proposal_sd <- sampling_results$proposal_sd + +library(patchwork) +library(ggplot2) +df <- data.frame( + true = aveOmega[lower.tri(aveOmega)], + Omega[lower.tri(Omega)], + estimated = aveOmega[lower.tri(aveOmega)], + p_inclusion = aveGamma[lower.tri(aveGamma)] +) +p1 <- ggplot(df, aes(x = true, y = estimated)) + + geom_point(size = 5, alpha = 0.8, shape = 21, fill = "grey") + + geom_abline(slope = 1, intercept = 0, color = "grey") + + labs(x = "True Values Omega", y = "Estimated Values Omega (Posterior Mean)") +p2 <- ggplot(df, aes(x = estimated, y = p_inclusion)) + + geom_point(size = 5, alpha = 0.8, shape = 21, fill = "grey") + + labs( + x = "Estimated Values Omega (Posterior Mean)", + y = "Estimated Inclusion Probabilities" + ) +(p1 + p2) + plot_layout(ncol = 1) & theme_bw(base_size = 20) +