Skip to content

Commit 23454ce

Browse files
Checkout
1 parent 8c3a173 commit 23454ce

File tree

1 file changed

+22
-75
lines changed

1 file changed

+22
-75
lines changed

src/gibbs_functions.cpp

Lines changed: 22 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -832,16 +832,11 @@ inline void update_fisher_preconditioner (
832832
* The step size is trace-scaled to match the average proposal magnitude across
833833
* dimensions, following Proposition 1 from Titsias (2023).
834834
*
835-
* If the proposal is accepted, the main_effects matrix is updated in-place and
836-
* the gradient cache for the interaction update is invalidated by setting
837-
* `gradient_valid = false`.
838-
*
839835
* Modifies:
840836
* - main_effects
841837
* - step_size
842838
* - dual_averaging_state
843839
* - sqrt_inv_fisher
844-
* - gradient_valid
845840
*/
846841
void update_thresholds_with_fisher_mala (
847842
arma::mat& main_effects,
@@ -858,8 +853,7 @@ void update_thresholds_with_fisher_mala (
858853
arma::mat& sqrt_inv_fisher,
859854
const double threshold_alpha,
860855
const double threshold_beta,
861-
const double initial_step_size,
862-
bool& gradient_valid
856+
const double initial_step_size
863857
) {
864858
// --- Compute current parameter vector and its gradient ---
865859
const arma::vec current_state = vectorize_thresholds (
@@ -938,7 +932,6 @@ void update_thresholds_with_fisher_mala (
938932
// --- Accept or reject proposed move ---
939933
if (std::log (R::unif_rand()) < log_accept) {
940934
main_effects = proposed_main_effects;
941-
gradient_valid = false;
942935
}
943936

944937
// --- Update step size and Fisher matrix ---
@@ -2258,8 +2251,6 @@ void update_indicator_interaction_pair_with_mala (
22582251
* - total_burnin: Total burn-in length.
22592252
* - dual_averaging_state: Dual averaging parameters (updated during burn-in).
22602253
* - sqrt_inv_fisher_pairwise: Square root of inverse Fisher information matrix (updated post-burn-in).
2261-
* - cached_interaction_gradient: Cached gradient vector (read/write).
2262-
* - gradient_valid: Flag indicating whether the cached gradient is valid (updated in-place).
22632254
*
22642255
* Modifies:
22652256
* - pairwise_effects (on accept)
@@ -2285,9 +2276,7 @@ void update_interactions_with_fisher_mala (
22852276
const int iteration,
22862277
const int total_burnin,
22872278
arma::vec& dual_averaging_state,
2288-
arma::mat& sqrt_inv_fisher_pairwise,
2289-
arma::vec& cached_interaction_gradient,
2290-
bool& gradient_valid
2279+
arma::mat& sqrt_inv_fisher_pairwise
22912280
) {
22922281
const int num_variables = pairwise_effects.n_rows;
22932282
const int num_interactions = (num_variables * (num_variables - 1)) / 2;
@@ -2305,16 +2294,12 @@ void update_interactions_with_fisher_mala (
23052294
}
23062295

23072296
// --- Compute gradient and log-posterior at current state
2308-
arma::vec current_grad;
2309-
if (gradient_valid && cached_interaction_gradient.n_elem == num_interactions) {
2310-
current_grad = cached_interaction_gradient;
2311-
} else {
2312-
current_grad = gradient_log_pseudoposterior_interactions(
2297+
arma::vec current_grad = gradient_log_pseudoposterior_interactions(
23132298
pairwise_effects, main_effects, observations, num_categories,
23142299
inclusion_indicator, is_ordinal_variable, reference_category,
23152300
interaction_scale
23162301
);
2317-
}
2302+
23182303

23192304
double current_log_post = log_pseudoposterior_interactions (
23202305
pairwise_effects, main_effects, observations, num_categories,
@@ -2398,13 +2383,6 @@ void update_interactions_with_fisher_mala (
23982383
}
23992384
}
24002385
}
2401-
cached_interaction_gradient = proposed_grad;
2402-
gradient_valid = true;
2403-
} else {
2404-
if(!gradient_valid) {
2405-
cached_interaction_gradient = current_grad;
2406-
gradient_valid = true;
2407-
}
24082386
}
24092387

24102388
// --- Step size and Fisher adaptation
@@ -2470,15 +2448,11 @@ void update_interactions_with_fisher_mala (
24702448
* - num_pairwise: Total number of candidate interactions.
24712449
* - iteration: Current MCMC iteration.
24722450
* - total_burnin: Number of warm-up iterations.
2473-
* - cached_interaction_gradient: Gradient vector reused across proposals (updated in-place).
2474-
* - gradient_valid: Flag indicating whether the gradient cache is valid (set true on exit).
24752451
*
24762452
* Modifies:
24772453
* - indicator
24782454
* - pairwise_effects (on accept)
24792455
* - residual_matrix (on accept)
2480-
* - cached_interaction_gradient (entry-by-entry if accepted)
2481-
* - gradient_valid (set to true)
24822456
*/
24832457
void update_indicator_interaction_pair_with_fisher_mala (
24842458
arma::mat& pairwise_effects,
@@ -2497,9 +2471,7 @@ void update_indicator_interaction_pair_with_fisher_mala (
24972471
const arma::mat& sqrt_inv_fisher_pairwise,
24982472
const int num_pairwise,
24992473
const int iteration,
2500-
const int total_burnin,
2501-
arma::vec& cached_interaction_gradient,
2502-
bool& gradient_valid
2474+
const int total_burnin
25032475
) {
25042476
// --- Set inverse Fisher matrix (identity during burn-in)
25052477
arma::mat inv_fisher;
@@ -2514,18 +2486,6 @@ void update_indicator_interaction_pair_with_fisher_mala (
25142486
const double scaled_step_size = step_size_pairwise / (trace_inv_fisher / num_pairwise);
25152487
const double sd = std::sqrt(scaled_step_size);
25162488

2517-
// --- Compute full gradient at current state (used in proposals)
2518-
arma::vec full_grad_current;
2519-
if (gradient_valid && cached_interaction_gradient.n_elem == num_pairwise) {
2520-
full_grad_current = cached_interaction_gradient;
2521-
} else {
2522-
full_grad_current = gradient_log_pseudoposterior_interactions(
2523-
pairwise_effects, main_effects, observations, num_categories,
2524-
indicator, is_ordinal_variable, reference_category,
2525-
interaction_scale
2526-
);
2527-
}
2528-
25292489
for (int pair_index = 0; pair_index < num_pairwise; pair_index++) {
25302490
const int interaction_index = index(pair_index, 0) - 1;
25312491
const int var1 = index(pair_index, 1);
@@ -2539,7 +2499,15 @@ void update_indicator_interaction_pair_with_fisher_mala (
25392499

25402500
const arma::rowvec fisher_row = inv_fisher.row(interaction_index);
25412501

2502+
2503+
25422504
if (proposing_addition) {
2505+
// --- Compute full gradient at current state (used in proposals)
2506+
arma::vec full_grad_current = gradient_log_pseudoposterior_interactions(
2507+
pairwise_effects, main_effects, observations, num_categories,
2508+
indicator, is_ordinal_variable, reference_category,
2509+
interaction_scale
2510+
);
25432511
// --- Propose new interaction using preconditioned Langevin step
25442512
const double drift = 0.5 * scaled_step_size * arma::dot(fisher_row, full_grad_current);
25452513
const double forward_mean = current_state + drift;
@@ -2559,10 +2527,10 @@ void update_indicator_interaction_pair_with_fisher_mala (
25592527
proposed_matrix(var2, var1) = proposed_state;
25602528

25612529
// Update only the relevant gradient component
2562-
arma::vec proposed_grad = full_grad_current; // copy
2563-
proposed_grad(interaction_index) = gradient_log_pseudoposterior_interaction_single (
2564-
var1, var2, proposed_matrix, main_effects, observations, num_categories,
2565-
is_ordinal_variable, reference_category, interaction_scale
2530+
arma::vec proposed_grad = gradient_log_pseudoposterior_interactions(
2531+
proposed_matrix, main_effects, observations, num_categories,
2532+
indicator, is_ordinal_variable, reference_category,
2533+
interaction_scale
25662534
);
25672535

25682536
const double drift = 0.5 * scaled_step_size * arma::dot(fisher_row, proposed_grad);
@@ -2593,22 +2561,8 @@ void update_indicator_interaction_pair_with_fisher_mala (
25932561
// --- Update residual matrix
25942562
residual_matrix.col(var1) += arma::conv_to<arma::vec>::from(observations.col(var2)) * delta;
25952563
residual_matrix.col(var2) += arma::conv_to<arma::vec>::from(observations.col(var1)) * delta;
2596-
2597-
// --- Maintain gradient consistency
2598-
if(new_value == 1) {
2599-
full_grad_current(interaction_index) = gradient_log_pseudoposterior_interaction_single (
2600-
var1, var2, pairwise_effects, main_effects, observations, num_categories,
2601-
is_ordinal_variable, reference_category, interaction_scale
2602-
);
2603-
} else {
2604-
full_grad_current(interaction_index) = 0.0;
2605-
}
26062564
}
26072565
}
2608-
2609-
//Updated cached gradient
2610-
cached_interaction_gradient = full_grad_current;
2611-
gradient_valid = true;
26122566
}
26132567

26142568

@@ -2708,10 +2662,7 @@ void gibbs_update_step_for_graphical_model_parameters (
27082662
arma::vec& dual_averaging_pairwise,
27092663
const double initial_step_size_pairwise,
27102664
arma::mat& sqrt_inv_fisher_pairwise,
2711-
const std::string& update_method,
2712-
arma::vec& cached_interaction_gradient,
2713-
bool& gradient_valid,
2714-
arma::vec& posterior_prob
2665+
const std::string& update_method
27152666
) {
27162667
// --- Robbins-Monro weight for adaptive Metropolis updates
27172668
const double exp_neg_log_t_rm_adaptation_rate =
@@ -2726,7 +2677,7 @@ void gibbs_update_step_for_graphical_model_parameters (
27262677
num_categories, step_size_pairwise, interaction_scale, index,
27272678
num_persons, residual_matrix, inclusion_probability, is_ordinal_variable,
27282679
reference_category, sqrt_inv_fisher_pairwise, num_pairwise,
2729-
iteration, total_burnin, cached_interaction_gradient, gradient_valid
2680+
iteration, total_burnin
27302681
);
27312682
} else if (update_method == "adaptive-mala") {
27322683
// Use standard MALA for indicator updates
@@ -2754,8 +2705,7 @@ void gibbs_update_step_for_graphical_model_parameters (
27542705
num_categories, inclusion_indicator, is_ordinal_variable,
27552706
reference_category, interaction_scale, step_size_pairwise,
27562707
initial_step_size_pairwise, iteration, total_burnin,
2757-
dual_averaging_pairwise, sqrt_inv_fisher_pairwise,
2758-
cached_interaction_gradient, gradient_valid
2708+
dual_averaging_pairwise, sqrt_inv_fisher_pairwise
27592709
);
27602710
} else if (update_method == "adaptive-mala") {
27612711
update_interactions_with_mala (
@@ -2783,7 +2733,7 @@ void gibbs_update_step_for_graphical_model_parameters (
27832733
num_obs_categories, sufficient_blume_capel, reference_category,
27842734
is_ordinal_variable, iteration, total_burnin, dual_averaging_main,
27852735
sqrt_inv_fisher_main, threshold_alpha, threshold_beta,
2786-
initial_step_size_main, gradient_valid
2736+
initial_step_size_main
27872737
);
27882738
} else {
27892739
// Metropolis updates (Blume-Capel or ordinal)
@@ -2990,8 +2940,6 @@ List run_gibbs_sampler_for_bgm (
29902940
dual_averaging_main[0] = std::log (step_size_main);
29912941
dual_averaging_pairwise[0] = std::log (step_size_pairwise);
29922942
}
2993-
arma::vec cached_interaction_gradient; // will hold a cached gradient vector
2994-
bool gradient_valid = false; // indicates whether the cache is valid
29952943
arma::vec posterior_prob(num_pairwise);
29962944

29972945
// --- Set up total number of iterations (burn-in + sampling)
@@ -3041,8 +2989,7 @@ List run_gibbs_sampler_for_bgm (
30412989
reference_category, edge_selection, step_size_main, iteration,
30422990
dual_averaging_main, total_burnin, initial_step_size_main,
30432991
sqrt_inv_fisher_main, step_size_pairwise, dual_averaging_pairwise,
3044-
initial_step_size_pairwise, sqrt_inv_fisher_pairwise, update_method,
3045-
cached_interaction_gradient, gradient_valid, posterior_prob
2992+
initial_step_size_pairwise, sqrt_inv_fisher_pairwise, update_method
30462993
);
30472994

30482995
// --- Update edge probabilities under the prior (if edge selection is active)

0 commit comments

Comments
 (0)