Skip to content

Commit 2deb255

Browse files
authored
newton raphson/ laplace approximation (#41)
1 parent 145f4e7 commit 2deb255

File tree

5 files changed

+352
-5
lines changed

5 files changed

+352
-5
lines changed

R/RcppExports.R

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,10 @@ sample_bcomrf_gibbs <- function(no_states, no_variables, no_categories, interact
99
.Call(`_bgms_sample_bcomrf_gibbs`, no_states, no_variables, no_categories, interactions, thresholds, variable_type, reference_category, iter)
1010
}
1111

12+
# R-side entry point for the native routine
# `_bgms_optimize_log_pseudoposterior_interaction`.
#
# All fourteen arguments are forwarded unchanged, and in this exact order, to
# the native routine through `.Call()`; the value of that call is returned
# as-is. No validation or type coercion happens on the R side.
#
# NOTE(review): this file is named RcppExports.R, so the wrapper is presumably
# generated by Rcpp::compileAttributes() and hand edits may be overwritten --
# confirm before changing it here.
optimize_log_pseudoposterior_interaction <- function(initial_value, pairwise_effects, main_effects, observations, num_categories, num_persons, variable1, variable2, proposed_state, current_state, residual_matrix, is_ordinal_variable, reference_category, interaction_scale) {
13+
.Call(`_bgms_optimize_log_pseudoposterior_interaction`, initial_value, pairwise_effects, main_effects, observations, num_categories, num_persons, variable1, variable2, proposed_state, current_state, residual_matrix, is_ordinal_variable, reference_category, interaction_scale)
14+
}
15+
1216
# R-side entry point for the native routine `_bgms_run_gibbs_sampler_for_bgm`.
#
# Every argument is forwarded unchanged, and in this exact order, to the native
# routine through `.Call()`; the value of that call is returned as-is.
# Defaults set here: the three `save_*` flags and `display_progress` are FALSE,
# `edge_selection` is TRUE, and `update_method` is "adaptive-metropolis".
#
# NOTE(review): this file is named RcppExports.R, so the wrapper is presumably
# generated by Rcpp::compileAttributes() and hand edits may be overwritten --
# confirm before changing it here.
run_gibbs_sampler_for_bgm <- function(observations, num_categories, interaction_scale, edge_prior, inclusion_probability, beta_bernoulli_alpha, beta_bernoulli_beta, dirichlet_alpha, lambda, interaction_index_matrix, iter, burnin, num_obs_categories, sufficient_blume_capel, threshold_alpha, threshold_beta, na_impute, missing_index, is_ordinal_variable, reference_category, save_main = FALSE, save_pairwise = FALSE, save_indicator = FALSE, display_progress = FALSE, edge_selection = TRUE, update_method = "adaptive-metropolis") {
1317
.Call(`_bgms_run_gibbs_sampler_for_bgm`, observations, num_categories, interaction_scale, edge_prior, inclusion_probability, beta_bernoulli_alpha, beta_bernoulli_beta, dirichlet_alpha, lambda, interaction_index_matrix, iter, burnin, num_obs_categories, sufficient_blume_capel, threshold_alpha, threshold_beta, na_impute, missing_index, is_ordinal_variable, reference_category, save_main, save_pairwise, save_indicator, display_progress, edge_selection, update_method)
1418
}

man/bgm.Rd

Lines changed: 20 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

showcase_newton_raphson.R

Lines changed: 219 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,219 @@
1+
library(bgms)
2+
3+
# Numerator part of the log pseudolikelihood.
#
# Adds, for every variable i in 1..P, the products
# threshold_counts_without_0[i, u] * thresholds[i, u] for u in
# 1..(seen[i] - 1), and then adds sum(interactions * suffstats).
#
# Arguments:
#   thresholds                 matrix; row i holds the thresholds of variable i.
#   interactions               matrix multiplied elementwise with `suffstats`.
#   suffstats                  matrix with the same dimensions as `interactions`.
#   seen                       vector; only the first seen[i] - 1 columns of
#                              row i are used.
#   threshold_counts_without_0 matrix indexed like `thresholds`.
#   P                          number of variables (rows) to iterate over.
#
# Returns a single numeric value.
pseudolikelihood_numerator <- function(thresholds, interactions, suffstats, seen,
                                       threshold_counts_without_0, P) {
  total <- 0.0

  for (i in seq_len(P)) {
    # seq_len() yields an empty index when seen[i] is 1, so such a variable
    # contributes nothing.
    u <- seq_len(seen[i] - 1)
    total <- total + sum(threshold_counts_without_0[i, u] * thresholds[i, u])
  }

  total + sum(interactions * suffstats)
}
16+
17+
18+
# Log pseudoposterior as a function of a single interaction value.
#
# Writes `a` into both interactions[i, j] and interactions[j, i] of a local
# copy, then returns the sum of three terms:
#   * pseudolikelihood_numerator() on the updated matrix,
#   * pseudolikelihood_denominator2() restricted to variables i and j
#     (that function already returns a negated sum, so it is added here),
#   * the Cauchy(0, prior_cauchy_scale) log density summed over the strict
#     lower triangle of the updated interaction matrix.
#
# Arguments:
#   a                  scalar value for the (i, j) interaction.
#   i, j               indices of the two variables.
#   thresholds, interactions, suffstats, seen, threshold_counts_without_0, P
#                      passed through to pseudolikelihood_numerator().
#   X, N               passed through to pseudolikelihood_denominator2().
#   prior_cauchy_scale scale of the Cauchy prior (default 2.5).
#
# Returns a single numeric value.
pseudo_logposterior_full_aij2 <- function(a, i, j, thresholds, interactions, suffstats, seen,
                                          threshold_counts_without_0, X, P, N,
                                          prior_cauchy_scale = 2.5) {
  interactions[i, j] <- interactions[j, i] <- a

  log_numerator <- pseudolikelihood_numerator(
    thresholds = thresholds,
    interactions = interactions,
    suffstats = suffstats,
    seen = seen,
    threshold_counts_without_0 = threshold_counts_without_0,
    P = P
  )

  log_denominator <- pseudolikelihood_denominator2(
    thresholds = thresholds,
    interactions = interactions,
    suffstats = suffstats,
    seen = seen,
    X = X,
    P = P,
    N = N,
    i0 = i,
    j0 = j
  )

  log_prior <- sum(
    dcauchy(interactions[lower.tri(interactions)], 0, prior_cauchy_scale, log = TRUE)
  )

  log_numerator + log_denominator + log_prior
}
44+
45+
# Negative log normalizing constants for the two variables i0 and j0.
#
# For every observation v in 1..N and each variable i in c(i0, j0), computes
#   rest = sum(interactions[i, ] * X[v, ])
# and subtracts
#   log(1 + sum_{u = 1}^{seen[i] - 1} exp(thresholds[i, u] + u * rest))
# from the running total.
#
# The log-sum is evaluated with the log-sum-exp trick (shift by the largest
# exponent) so that large `rest` values give a finite result instead of
# overflowing exp() to Inf.
#
# Arguments:
#   thresholds   matrix; row i holds the thresholds of variable i.
#   interactions matrix of pairwise interactions.
#   suffstats    unused here; kept so the signature stays unchanged.
#   seen         vector; seen[i] - 1 categories enter the sum for variable i.
#   X            observation matrix with N rows.
#   P            unused here; kept so the signature stays unchanged.
#   N            number of observations (rows of X) to iterate over.
#   i0, j0       indices of the two variables whose terms are accumulated.
#
# Returns a single numeric value (a sum of negated log terms).
pseudolikelihood_denominator2 <- function(thresholds, interactions, suffstats, seen, X, P, N, i0, j0) {
  result <- 0.0

  for (v in seq_len(N)) {
    for (i in c(i0, j0)) {
      rest_score <- sum(interactions[i, ] * X[v, ])

      u <- seq_len(seen[i] - 1)
      # The leading 0 is the exponent of the constant 1 in "1 + sum(exp(.))".
      exponents <- c(0, thresholds[i, u] + u * rest_score)
      max_exponent <- max(exponents)

      result <- result - (max_exponent + log(sum(exp(exponents - max_exponent))))
    }
  }

  result
}
63+
64+
# Evaluate the log pseudoposterior of interaction (i, j) at each value in `a`,
# with all other parameters fixed at posterior draw `iter`.
#
# Row `iter` of `Mu` is reshaped row-wise into a
# p x ncol(threshold_counts_without_0) threshold matrix, and row `iter` of
# `Sigma` fills the strict lower triangle of a p x p matrix that is then
# symmetrized.
# NOTE(review): the reshape assumes every variable contributes exactly
# ncol(threshold_counts_without_0) thresholds, stored variable by variable in
# `Mu` -- confirm against the layout of the sampler output.
#
# Arguments:
#   a      numeric vector of candidate interaction values.
#   i, j   indices of the two variables.
#   Mu     matrix of threshold draws, one row per iteration.
#   Sigma  matrix of interaction draws, one row per iteration.
#   iter   row of `Mu` / `Sigma` to use.
#   x      observation matrix (rows = persons, columns = variables).
#   suffstats, seen, threshold_counts_without_0
#          passed through to pseudo_logposterior_full_aij2().
#
# Returns a numeric vector with one value per element of `a` (length zero when
# `a` is empty).
log_pseudolikelihood_full2 <- function(a, i, j, Mu, Sigma, iter, x, suffstats, seen, threshold_counts_without_0) {
  n <- nrow(x)
  p <- ncol(x)

  # Thresholds: fill row by row from the selected draw. Indexing the extracted
  # vector (rather than the matrix) yields NA, not an error, if the draw is
  # shorter than p * n_thresholds.
  n_thresholds <- ncol(threshold_counts_without_0)
  mu_iter <- Mu[iter, ]
  MuMat <- matrix(mu_iter[seq_len(p * n_thresholds)], p, n_thresholds, byrow = TRUE)

  # Interactions: lower triangle from the selected draw, then symmetrize.
  SigmaMat <- matrix(0, p, p)
  SigmaMat[lower.tri(SigmaMat)] <- Sigma[iter, ]
  SigmaMat <- SigmaMat + t(SigmaMat)

  vapply(
    a,
    function(a_d) {
      pseudo_logposterior_full_aij2(
        a_d, i, j,
        thresholds = MuMat, interactions = SigmaMat, X = x, N = n, P = p,
        seen = seen, suffstats = suffstats,
        threshold_counts_without_0 = threshold_counts_without_0
      )
    },
    numeric(1),
    USE.NAMES = FALSE
  )
}
97+
98+
99+
# Use the first 50 rows and the first 5 columns of the Wenchuan data set.
x0 <- Wenchuan[1:50, 1:5]
p <- ncol(x0) # number of variables

samples <- bgm(x0, save = TRUE)
Mu <- samples$main_effect_samples # threshold draws
Sigma <- samples$pairwise_effect_samples # interaction draws

data <- bgms:::reformat_data(
  x = x0,
  na_action = "listwise",
  variable_bool = rep(TRUE, p),
  reference_category = rep(1, p)
)

x <- data$x # reformatted observations
no_categories <- data$no_categories
no_categories <- cumsum(no_categories) # cumulative counts for indexing
start <- 1 + c(0, no_categories[-length(no_categories)]) # start indices
stop <- no_categories # stop indices

# Per-variable category counts; appending 0:K and subtracting 1 makes every
# category 0..K appear in the table, with count 0 if unobserved.
K <- max(x)
threshold_counts_wench <- apply(x, 2, \(y) c(table(c(y, 0:K)) - 1))
# Move zero counts to the end, then drop the first remaining entry.
threshold_counts_without_0_wench <- apply(threshold_counts_wench, 2L, \(y) {
  c(y[y > 0], rep(0, sum(y == 0)))[-1L]
})
threshold_counts_without_0_wench <- t(matrix(threshold_counts_without_0_wench, K, p))

seen_wench <- unname(apply(x, 2, \(y) length(unique(y))))
suffstats_wench <- unname(crossprod(x))

i <- 2
j <- 1
log_pseudolikelihood_full2(
  c(.2, .5), i, j, Mu, Sigma,
  iter = 10000, x = x,
  seen = seen_wench, suffstats = suffstats_wench,
  threshold_counts_without_0 = threshold_counts_without_0_wench
)

# NOTE(review): elsewhere in this script Sigma is indexed as
# [iteration, pair], so Sigma[i, j] is not the (i, j) interaction. The
# "Brent" method searches the interval [lower, upper], so the start value
# presumably only needs to be a single number -- confirm.
optim_res <- optim(Sigma[i, j], function(a) {
  returnVal <- log_pseudolikelihood_full2(
    a, i, j, Mu, Sigma,
    iter = 10000, x = x,
    seen = seen_wench, suffstats = suffstats_wench,
    threshold_counts_without_0 = threshold_counts_without_0_wench
  )
  # Clamp -Inf to the most negative finite double so the optimizer can keep
  # comparing values. The vectorized mask avoids a loop variable that would
  # shadow the pair index `i`, and is FALSE (not an error) for NaN/NA.
  is_neg_inf <- is.infinite(returnVal) & returnVal < 0
  returnVal[is_neg_inf] <- -.Machine$double.xmax
  returnVal
}, method = "Brent", lower = -100, upper = 100, control = list(fnscale = -1, trace = 5))
149+
150+
151+
# Set up the arguments for the C++ optimizer and compare its result with the
# optim() result stored in `optim_res`.
iter <- 10000

# Threshold matrix for draw `iter`, filled row by row.
n_thresholds <- ncol(threshold_counts_without_0_wench)
MuMat <- matrix(Mu[iter, seq_len(p * n_thresholds)], p, n_thresholds, byrow = TRUE)

# Symmetric interaction matrix for draw `iter`.
SigmaMat <- matrix(0, p, p)
SigmaMat[lower.tri(SigmaMat)] <- Sigma[iter, ]
SigmaMat <- SigmaMat + t(SigmaMat)

# The type comments below give the C++ parameter type of each argument.

# const double
# Start from the (i, j) interaction of the same draw that fills
# `pairwise_effects`; Sigma itself is indexed [iteration, pair].
initial_value <- SigmaMat[i, j]
# arma::mat& (non-const reference in the C++ signature)
pairwise_effects <- SigmaMat
# const arma::mat&
main_effects <- MuMat
# const arma::imat&
observations <- x
# const arma::ivec&
num_categories <- seen_wench
# const int
num_persons <- nrow(x)
# const int
variable1 <- i
# const int
variable2 <- j
# TODO: these two are unused?
# const double
proposed_state <- 0.0
# const double
current_state <- 0.0
# const arma::mat&
# NOTE(review): an all-zero residual matrix may not match `pairwise_effects`
# -- confirm what the C++ routine expects here.
residual_matrix <- matrix(0, nrow(x), p)
# const arma::uvec&
is_ordinal_variable <- rep(1, p)
# const arma::ivec&
reference_category <- data$reference_category
# const double
interaction_scale <- 2.5

newton_raphson_x <- bgms:::optimize_log_pseudoposterior_interaction(
  initial_value = c(initial_value),
  pairwise_effects = pairwise_effects,
  main_effects = main_effects,
  observations = observations,
  num_categories = num_categories - 1,
  num_persons = num_persons,
  # Shift the 1-based R indices to the 0-based indices passed to C++.
  variable1 = variable1 - 1,
  variable2 = variable2 - 1,
  proposed_state = proposed_state,
  current_state = current_state,
  residual_matrix = residual_matrix,
  is_ordinal_variable = is_ordinal_variable,
  reference_category = reference_category,
  interaction_scale = interaction_scale
)

# Evaluate the R objective at the Newton-Raphson solution.
newton_raphson_fx <- log_pseudolikelihood_full2(
  newton_raphson_x, i, j, Mu, Sigma,
  iter = iter, x = x,
  seen = seen_wench, suffstats = suffstats_wench,
  threshold_counts_without_0 = threshold_counts_without_0_wench
)

matrix(
  c(newton_raphson_x, newton_raphson_fx, optim_res$par, optim_res$value),
  nrow = 2,
  dimnames = list(c("x", "f(x)"), c("Newton-Raphson", "Optim"))
)

src/RcppExports.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,30 @@ BEGIN_RCPP
4545
return rcpp_result_gen;
4646
END_RCPP
4747
}
48+
// optimize_log_pseudoposterior_interaction
//
// Glue between R's .Call interface and the C++ function
// optimize_log_pseudoposterior_interaction(): each of the fourteen SEXP
// arguments is converted with Rcpp::traits::input_parameter to the C++
// parameter type declared below, the C++ function is called, and its double
// result is wrapped back into an R object.
// NOTE(review): this file is named RcppExports.cpp, so the code is presumably
// generated by Rcpp::compileAttributes() -- confirm before hand-editing.
49+
// Forward declaration of the implementation. Note that pairwise_effects is
// the only matrix taken by non-const reference.
double optimize_log_pseudoposterior_interaction(const double initial_value, arma::mat& pairwise_effects, const arma::mat& main_effects, const arma::imat& observations, const arma::ivec& num_categories, const int num_persons, const int variable1, const int variable2, const double proposed_state, const double current_state, const arma::mat& residual_matrix, const arma::uvec& is_ordinal_variable, const arma::ivec& reference_category, const double interaction_scale);
50+
RcppExport SEXP _bgms_optimize_log_pseudoposterior_interaction(SEXP initial_valueSEXP, SEXP pairwise_effectsSEXP, SEXP main_effectsSEXP, SEXP observationsSEXP, SEXP num_categoriesSEXP, SEXP num_personsSEXP, SEXP variable1SEXP, SEXP variable2SEXP, SEXP proposed_stateSEXP, SEXP current_stateSEXP, SEXP residual_matrixSEXP, SEXP is_ordinal_variableSEXP, SEXP reference_categorySEXP, SEXP interaction_scaleSEXP) {
51+
BEGIN_RCPP
52+
Rcpp::RObject rcpp_result_gen;
53+
// An RNGScope object is created for the duration of the call.
Rcpp::RNGScope rcpp_rngScope_gen;
54+
// Convert each incoming SEXP to its C++ parameter type.
Rcpp::traits::input_parameter< const double >::type initial_value(initial_valueSEXP);
55+
Rcpp::traits::input_parameter< arma::mat& >::type pairwise_effects(pairwise_effectsSEXP);
56+
Rcpp::traits::input_parameter< const arma::mat& >::type main_effects(main_effectsSEXP);
57+
Rcpp::traits::input_parameter< const arma::imat& >::type observations(observationsSEXP);
58+
Rcpp::traits::input_parameter< const arma::ivec& >::type num_categories(num_categoriesSEXP);
59+
Rcpp::traits::input_parameter< const int >::type num_persons(num_personsSEXP);
60+
Rcpp::traits::input_parameter< const int >::type variable1(variable1SEXP);
61+
Rcpp::traits::input_parameter< const int >::type variable2(variable2SEXP);
62+
Rcpp::traits::input_parameter< const double >::type proposed_state(proposed_stateSEXP);
63+
Rcpp::traits::input_parameter< const double >::type current_state(current_stateSEXP);
64+
Rcpp::traits::input_parameter< const arma::mat& >::type residual_matrix(residual_matrixSEXP);
65+
Rcpp::traits::input_parameter< const arma::uvec& >::type is_ordinal_variable(is_ordinal_variableSEXP);
66+
Rcpp::traits::input_parameter< const arma::ivec& >::type reference_category(reference_categorySEXP);
67+
Rcpp::traits::input_parameter< const double >::type interaction_scale(interaction_scaleSEXP);
68+
// Call the implementation and wrap its double result for R.
rcpp_result_gen = Rcpp::wrap(optimize_log_pseudoposterior_interaction(initial_value, pairwise_effects, main_effects, observations, num_categories, num_persons, variable1, variable2, proposed_state, current_state, residual_matrix, is_ordinal_variable, reference_category, interaction_scale));
69+
return rcpp_result_gen;
70+
END_RCPP
71+
}
4872
// run_gibbs_sampler_for_bgm
4973
List run_gibbs_sampler_for_bgm(arma::imat& observations, const arma::ivec& num_categories, const double interaction_scale, const String& edge_prior, arma::mat& inclusion_probability, const double beta_bernoulli_alpha, const double beta_bernoulli_beta, const double dirichlet_alpha, const double lambda, const arma::imat& interaction_index_matrix, const int iter, const int burnin, arma::imat& num_obs_categories, arma::imat& sufficient_blume_capel, const double threshold_alpha, const double threshold_beta, const bool na_impute, const arma::imat& missing_index, const arma::uvec& is_ordinal_variable, const arma::ivec& reference_category, const bool save_main, const bool save_pairwise, const bool save_indicator, const bool display_progress, bool edge_selection, const std::string& update_method);
5074
RcppExport SEXP _bgms_run_gibbs_sampler_for_bgm(SEXP observationsSEXP, SEXP num_categoriesSEXP, SEXP interaction_scaleSEXP, SEXP edge_priorSEXP, SEXP inclusion_probabilitySEXP, SEXP beta_bernoulli_alphaSEXP, SEXP beta_bernoulli_betaSEXP, SEXP dirichlet_alphaSEXP, SEXP lambdaSEXP, SEXP interaction_index_matrixSEXP, SEXP iterSEXP, SEXP burninSEXP, SEXP num_obs_categoriesSEXP, SEXP sufficient_blume_capelSEXP, SEXP threshold_alphaSEXP, SEXP threshold_betaSEXP, SEXP na_imputeSEXP, SEXP missing_indexSEXP, SEXP is_ordinal_variableSEXP, SEXP reference_categorySEXP, SEXP save_mainSEXP, SEXP save_pairwiseSEXP, SEXP save_indicatorSEXP, SEXP display_progressSEXP, SEXP edge_selectionSEXP, SEXP update_methodSEXP) {
@@ -143,6 +167,7 @@ END_RCPP
143167
static const R_CallMethodDef CallEntries[] = {
144168
{"_bgms_sample_omrf_gibbs", (DL_FUNC) &_bgms_sample_omrf_gibbs, 6},
145169
{"_bgms_sample_bcomrf_gibbs", (DL_FUNC) &_bgms_sample_bcomrf_gibbs, 8},
170+
{"_bgms_optimize_log_pseudoposterior_interaction", (DL_FUNC) &_bgms_optimize_log_pseudoposterior_interaction, 14},
146171
{"_bgms_run_gibbs_sampler_for_bgm", (DL_FUNC) &_bgms_run_gibbs_sampler_for_bgm, 26},
147172
{"_bgms_compare_anova_gibbs_sampler", (DL_FUNC) &_bgms_compare_anova_gibbs_sampler, 34},
148173
{"_bgms_compute_Vn_mfm_sbm", (DL_FUNC) &_bgms_compute_Vn_mfm_sbm, 4},

0 commit comments

Comments
 (0)