@@ -832,16 +832,11 @@ inline void update_fisher_preconditioner (
832832 * The step size is trace-scaled to match the average proposal magnitude across
833833 * dimensions, following Proposition 1 from Titsias (2023).
834834 *
835- * If the proposal is accepted, the main_effects matrix is updated in-place and
836- * the gradient cache for the interaction update is invalidated by setting
837- * `gradient_valid = false`.
838- *
839835 * Modifies:
840836 * - main_effects
841837 * - step_size
842838 * - dual_averaging_state
843839 * - sqrt_inv_fisher
844- * - gradient_valid
845840 */
846841void update_thresholds_with_fisher_mala (
847842 arma::mat& main_effects,
@@ -858,8 +853,7 @@ void update_thresholds_with_fisher_mala (
858853 arma::mat& sqrt_inv_fisher,
859854 const double threshold_alpha,
860855 const double threshold_beta,
861- const double initial_step_size,
862- bool & gradient_valid
856+ const double initial_step_size
863857) {
864858 // --- Compute current parameter vector and its gradient ---
865859 const arma::vec current_state = vectorize_thresholds (
@@ -938,7 +932,6 @@ void update_thresholds_with_fisher_mala (
938932 // --- Accept or reject proposed move ---
939933 if (std::log (R::unif_rand ()) < log_accept) {
940934 main_effects = proposed_main_effects;
941- gradient_valid = false ;
942935 }
943936
944937 // --- Update step size and Fisher matrix ---
@@ -2258,8 +2251,6 @@ void update_indicator_interaction_pair_with_mala (
22582251 * - total_burnin: Total burn-in length.
22592252 * - dual_averaging_state: Dual averaging parameters (updated during burn-in).
22602253 * - sqrt_inv_fisher_pairwise: Square root of inverse Fisher information matrix (updated post-burn-in).
2261- * - cached_interaction_gradient: Cached gradient vector (read/write).
2262- * - gradient_valid: Flag indicating whether the cached gradient is valid (updated in-place).
22632254 *
22642255 * Modifies:
22652256 * - pairwise_effects (on accept)
@@ -2285,9 +2276,7 @@ void update_interactions_with_fisher_mala (
22852276 const int iteration,
22862277 const int total_burnin,
22872278 arma::vec& dual_averaging_state,
2288- arma::mat& sqrt_inv_fisher_pairwise,
2289- arma::vec& cached_interaction_gradient,
2290- bool & gradient_valid
2279+ arma::mat& sqrt_inv_fisher_pairwise
22912280) {
22922281 const int num_variables = pairwise_effects.n_rows ;
22932282 const int num_interactions = (num_variables * (num_variables - 1 )) / 2 ;
@@ -2305,16 +2294,12 @@ void update_interactions_with_fisher_mala (
23052294 }
23062295
23072296 // --- Compute gradient and log-posterior at current state
2308- arma::vec current_grad;
2309- if (gradient_valid && cached_interaction_gradient.n_elem == num_interactions) {
2310- current_grad = cached_interaction_gradient;
2311- } else {
2312- current_grad = gradient_log_pseudoposterior_interactions (
2297+ arma::vec current_grad = gradient_log_pseudoposterior_interactions (
23132298 pairwise_effects, main_effects, observations, num_categories,
23142299 inclusion_indicator, is_ordinal_variable, reference_category,
23152300 interaction_scale
23162301 );
2317- }
2302+
23182303
23192304 double current_log_post = log_pseudoposterior_interactions (
23202305 pairwise_effects, main_effects, observations, num_categories,
@@ -2398,13 +2383,6 @@ void update_interactions_with_fisher_mala (
23982383 }
23992384 }
24002385 }
2401- cached_interaction_gradient = proposed_grad;
2402- gradient_valid = true ;
2403- } else {
2404- if (!gradient_valid) {
2405- cached_interaction_gradient = current_grad;
2406- gradient_valid = true ;
2407- }
24082386 }
24092387
24102388 // --- Step size and Fisher adaptation
@@ -2470,15 +2448,11 @@ void update_interactions_with_fisher_mala (
24702448 * - num_pairwise: Total number of candidate interactions.
24712449 * - iteration: Current MCMC iteration.
24722450 * - total_burnin: Number of warm-up iterations.
2473- * - cached_interaction_gradient: Gradient vector reused across proposals (updated in-place).
2474- * - gradient_valid: Flag indicating whether the gradient cache is valid (set true on exit).
24752451 *
24762452 * Modifies:
24772453 * - indicator
24782454 * - pairwise_effects (on accept)
24792455 * - residual_matrix (on accept)
2480- * - cached_interaction_gradient (entry-by-entry if accepted)
2481- * - gradient_valid (set to true)
24822456 */
24832457void update_indicator_interaction_pair_with_fisher_mala (
24842458 arma::mat& pairwise_effects,
@@ -2497,9 +2471,7 @@ void update_indicator_interaction_pair_with_fisher_mala (
24972471 const arma::mat& sqrt_inv_fisher_pairwise,
24982472 const int num_pairwise,
24992473 const int iteration,
2500- const int total_burnin,
2501- arma::vec& cached_interaction_gradient,
2502- bool & gradient_valid
2474+ const int total_burnin
25032475) {
25042476 // --- Set inverse Fisher matrix (identity during burn-in)
25052477 arma::mat inv_fisher;
@@ -2514,18 +2486,6 @@ void update_indicator_interaction_pair_with_fisher_mala (
25142486 const double scaled_step_size = step_size_pairwise / (trace_inv_fisher / num_pairwise);
25152487 const double sd = std::sqrt (scaled_step_size);
25162488
2517- // --- Compute full gradient at current state (used in proposals)
2518- arma::vec full_grad_current;
2519- if (gradient_valid && cached_interaction_gradient.n_elem == num_pairwise) {
2520- full_grad_current = cached_interaction_gradient;
2521- } else {
2522- full_grad_current = gradient_log_pseudoposterior_interactions (
2523- pairwise_effects, main_effects, observations, num_categories,
2524- indicator, is_ordinal_variable, reference_category,
2525- interaction_scale
2526- );
2527- }
2528-
25292489 for (int pair_index = 0 ; pair_index < num_pairwise; pair_index++) {
25302490 const int interaction_index = index (pair_index, 0 ) - 1 ;
25312491 const int var1 = index (pair_index, 1 );
@@ -2539,7 +2499,15 @@ void update_indicator_interaction_pair_with_fisher_mala (
25392499
25402500 const arma::rowvec fisher_row = inv_fisher.row (interaction_index);
25412501
2502+
2503+
25422504 if (proposing_addition) {
2505+ // --- Compute full gradient at current state (used in proposals)
2506+ arma::vec full_grad_current = gradient_log_pseudoposterior_interactions (
2507+ pairwise_effects, main_effects, observations, num_categories,
2508+ indicator, is_ordinal_variable, reference_category,
2509+ interaction_scale
2510+ );
25432511 // --- Propose new interaction using preconditioned Langevin step
25442512 const double drift = 0.5 * scaled_step_size * arma::dot (fisher_row, full_grad_current);
25452513 const double forward_mean = current_state + drift;
@@ -2559,10 +2527,10 @@ void update_indicator_interaction_pair_with_fisher_mala (
25592527 proposed_matrix (var2, var1) = proposed_state;
25602528
25612529 // Update only the relevant gradient component
2562- arma::vec proposed_grad = full_grad_current; // copy
2563- proposed_grad (interaction_index) = gradient_log_pseudoposterior_interaction_single (
2564- var1, var2, proposed_matrix, main_effects, observations, num_categories ,
2565- is_ordinal_variable, reference_category, interaction_scale
2530+ arma::vec proposed_grad = gradient_log_pseudoposterior_interactions (
2531+ proposed_matrix, main_effects, observations, num_categories,
2532+ indicator, is_ordinal_variable, reference_category ,
2533+ interaction_scale
25662534 );
25672535
25682536 const double drift = 0.5 * scaled_step_size * arma::dot (fisher_row, proposed_grad);
@@ -2593,22 +2561,8 @@ void update_indicator_interaction_pair_with_fisher_mala (
25932561 // --- Update residual matrix
25942562 residual_matrix.col (var1) += arma::conv_to<arma::vec>::from (observations.col (var2)) * delta;
25952563 residual_matrix.col (var2) += arma::conv_to<arma::vec>::from (observations.col (var1)) * delta;
2596-
2597- // --- Maintain gradient consistency
2598- if (new_value == 1 ) {
2599- full_grad_current (interaction_index) = gradient_log_pseudoposterior_interaction_single (
2600- var1, var2, pairwise_effects, main_effects, observations, num_categories,
2601- is_ordinal_variable, reference_category, interaction_scale
2602- );
2603- } else {
2604- full_grad_current (interaction_index) = 0.0 ;
2605- }
26062564 }
26072565 }
2608-
2609- // Updated cached gradient
2610- cached_interaction_gradient = full_grad_current;
2611- gradient_valid = true ;
26122566}
26132567
26142568
@@ -2708,10 +2662,7 @@ void gibbs_update_step_for_graphical_model_parameters (
27082662 arma::vec& dual_averaging_pairwise,
27092663 const double initial_step_size_pairwise,
27102664 arma::mat& sqrt_inv_fisher_pairwise,
2711- const std::string& update_method,
2712- arma::vec& cached_interaction_gradient,
2713- bool & gradient_valid,
2714- arma::vec& posterior_prob
2665+ const std::string& update_method
27152666) {
27162667 // --- Robbins-Monro weight for adaptive Metropolis updates
27172668 const double exp_neg_log_t_rm_adaptation_rate =
@@ -2726,7 +2677,7 @@ void gibbs_update_step_for_graphical_model_parameters (
27262677 num_categories, step_size_pairwise, interaction_scale, index,
27272678 num_persons, residual_matrix, inclusion_probability, is_ordinal_variable,
27282679 reference_category, sqrt_inv_fisher_pairwise, num_pairwise,
2729- iteration, total_burnin, cached_interaction_gradient, gradient_valid
2680+ iteration, total_burnin
27302681 );
27312682 } else if (update_method == " adaptive-mala" ) {
27322683 // Use standard MALA for indicator updates
@@ -2754,8 +2705,7 @@ void gibbs_update_step_for_graphical_model_parameters (
27542705 num_categories, inclusion_indicator, is_ordinal_variable,
27552706 reference_category, interaction_scale, step_size_pairwise,
27562707 initial_step_size_pairwise, iteration, total_burnin,
2757- dual_averaging_pairwise, sqrt_inv_fisher_pairwise,
2758- cached_interaction_gradient, gradient_valid
2708+ dual_averaging_pairwise, sqrt_inv_fisher_pairwise
27592709 );
27602710 } else if (update_method == " adaptive-mala" ) {
27612711 update_interactions_with_mala (
@@ -2783,7 +2733,7 @@ void gibbs_update_step_for_graphical_model_parameters (
27832733 num_obs_categories, sufficient_blume_capel, reference_category,
27842734 is_ordinal_variable, iteration, total_burnin, dual_averaging_main,
27852735 sqrt_inv_fisher_main, threshold_alpha, threshold_beta,
2786- initial_step_size_main, gradient_valid
2736+ initial_step_size_main
27872737 );
27882738 } else {
27892739 // Metropolis updates (Blume-Capel or ordinal)
@@ -2990,8 +2940,6 @@ List run_gibbs_sampler_for_bgm (
29902940 dual_averaging_main[0 ] = std::log (step_size_main);
29912941 dual_averaging_pairwise[0 ] = std::log (step_size_pairwise);
29922942 }
2993- arma::vec cached_interaction_gradient; // will hold a cached gradient vector
2994- bool gradient_valid = false ; // indicates whether the cache is valid
29952943 arma::vec posterior_prob (num_pairwise);
29962944
29972945 // --- Set up total number of iterations (burn-in + sampling)
@@ -3041,8 +2989,7 @@ List run_gibbs_sampler_for_bgm (
30412989 reference_category, edge_selection, step_size_main, iteration,
30422990 dual_averaging_main, total_burnin, initial_step_size_main,
30432991 sqrt_inv_fisher_main, step_size_pairwise, dual_averaging_pairwise,
3044- initial_step_size_pairwise, sqrt_inv_fisher_pairwise, update_method,
3045- cached_interaction_gradient, gradient_valid, posterior_prob
2992+ initial_step_size_pairwise, sqrt_inv_fisher_pairwise, update_method
30462993 );
30472994
30482995 // --- Update edge probabilities under the prior (if edge selection is active)
0 commit comments