Skip to content

Commit ed304a3

Browse files
committed
MALA proposal mean=base+drift
1 parent 758eb98 commit ed304a3

File tree

2 files changed

+36
-35
lines changed

2 files changed

+36
-35
lines changed

src/BVS.cpp

Lines changed: 36 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -370,9 +370,9 @@ void BVS_Sampler::sampleGamma(
370370
// double c = std::exp(a);
371371

372372
// Update proposal ratio with beta part
373-
logProposalRatio -= MALAbetas(proposedBeta, betas_, updateIdx0, singleIdx_k, componentUpdateIdx,
373+
logProposalRatio -= MALAbetas(proposedBeta, betas_, updateIdx0, componentUpdateIdx,
374374
datTheta, datProportion, weibullS, weibullLambda, kappa_, tauSq_[componentUpdateIdx], sigmaMH_beta, dataclass);
375-
logProposalRatio += MALAlogPbetas(betas_, proposedBeta, updateIdx0, singleIdx_k, componentUpdateIdx,
375+
logProposalRatio += MALAlogPbetas(betas_, proposedBeta, updateIdx0, componentUpdateIdx,
376376
datTheta, datProportion, kappa_, tauSq_[componentUpdateIdx], sigmaMH_beta, dataclass);
377377

378378
} else {//if( arma::any(gammas_(updateIdx,singleIdx_k)) ) {
@@ -658,9 +658,9 @@ void BVS_Sampler::sampleEta(
658658
// Update proposal ratio with beta part
659659
// logProposalRatio -= logPDFNormal(proposedZeta(1 + updateIdx0, singleIdx_k), m, Sigma);
660660
// logProposalRatio += logPDFNormal(zetas_(1 + updateIdx0, singleIdx_k), m_mutant, Sigma_mutant);// TODO: use proposedZeta to repeat the above steps (wrap into a func) to obtain m_mutant & Sigma_mutant
661-
logProposalRatio -= MALAzetas(proposedZeta, zetas_, updateIdx0, singleIdx_k, componentUpdateIdx,
661+
logProposalRatio -= MALAzetas(proposedZeta, zetas_, updateIdx0, componentUpdateIdx,
662662
datTheta, weibullS, weibullLambda, kappa_, wSq_[componentUpdateIdx], sigmaMH_zeta, dataclass);
663-
logProposalRatio += MALAlogPzetas(zetas_, proposedZeta, updateIdx0, singleIdx_k, componentUpdateIdx,
663+
logProposalRatio += MALAlogPzetas(zetas_, proposedZeta, updateIdx0, componentUpdateIdx,
664664
datTheta, weibullS, weibullLambda, kappa_, wSq_[componentUpdateIdx], sigmaMH_zeta, dataclass);
665665

666666
} else {//if( arma::any(etas_(updateIdx,singleIdx_k)) ) {
@@ -1102,7 +1102,6 @@ double BVS_Sampler::MALAbetas(
11021102
arma::mat& proposedBeta,
11031103
const arma::mat& betas_,
11041104
const arma::uvec& updateIdx0,
1105-
const arma::uvec& singleIdx_k,
11061105
unsigned int componentUpdateIdx,
11071106

11081107
const arma::vec& datTheta,
@@ -1117,6 +1116,7 @@ double BVS_Sampler::MALAbetas(
11171116
// dimensions
11181117
unsigned int n = dataclass.datX.n_rows;
11191118
unsigned int L = dataclass.datX.n_slices;
1119+
arma::uvec singleIdx_k = { componentUpdateIdx };
11201120

11211121
// Precompute a and powers
11221122
arma::vec a = dataclass.datTime / weibullLambda.col(componentUpdateIdx); // n-vector
@@ -1138,15 +1138,13 @@ double BVS_Sampler::MALAbetas(
11381138
% weibullS.col(l);
11391139
}
11401140

1141-
// Per-observation weights w_i (likelihood-only score part)
1142-
arma::vec w_i(n);
1143-
11441141
// Stabilize division s_l / S
11451142
const double eps_div = 1e-12;
11461143
arma::vec S_safe = S + eps_div;
11471144

1145+
// Per-observation weights w_i (likelihood-only score part)
11481146
// First term: −δ_i * (s_il / S_i) * (1 − z_il^κ)
1149-
w_i = - dataclass.datEvent
1147+
arma::vec w_i = - dataclass.datEvent
11501148
% (s_l / S_safe)
11511149
% (1.0 - a_power_kappa);
11521150

@@ -1185,17 +1183,18 @@ double BVS_Sampler::MALAbetas(
11851183
gradient_betaK -= betas_(1 + updateIdx0, singleIdx_k) / tauSqK;
11861184

11871185
// MALA mean and covariance
1188-
// double eps = sigmaMH_beta;
1189-
arma::vec m = 0.5 * eps * eps * (M * gradient_betaK);
1190-
arma::mat Sigma = eps * eps * M;
1186+
// Posterior gradient at current (forward direction) already correct
1187+
arma::vec m = 0.5 * eps * eps * (M * gradient_betaK);
1188+
arma::mat Sigma = eps * eps * M;
11911189

1192-
// Draw MALA proposal increment u ~ N(m, Sigma)
1190+
// Propose: u ~ N(m, Sigma), beta_new = beta_old + u
11931191
arma::vec u = randMvNormal(m, Sigma);
1194-
1195-
// Update proposed parameters on the selected coordinates
11961192
proposedBeta(1 + updateIdx0, singleIdx_k) += u;
11971193

1198-
double logP = logPDFNormal(proposedBeta(1 + updateIdx0, singleIdx_k), m, Sigma);
1194+
// Forward log-proposal density q(proposed | current)
1195+
arma::vec mu_fwd = betas_(1 + updateIdx0, singleIdx_k) + m;
1196+
double logP = logPDFNormal(proposedBeta(1 + updateIdx0, singleIdx_k), mu_fwd, Sigma);
1197+
11991198

12001199
return logP;
12011200
}
@@ -1206,7 +1205,6 @@ double BVS_Sampler::MALAlogPbetas(
12061205
const arma::mat& betas_,
12071206
const arma::mat& proposedBeta,
12081207
const arma::uvec& updateIdx0,
1209-
const arma::uvec& singleIdx_k,
12101208
unsigned int componentUpdateIdx,
12111209

12121210
const arma::vec& datTheta,
@@ -1220,6 +1218,7 @@ double BVS_Sampler::MALAlogPbetas(
12201218
unsigned int n = dataclass.datX.n_rows;
12211219
unsigned int p = dataclass.datX.n_cols;
12221220
unsigned int L = dataclass.datX.n_slices;
1221+
arma::uvec singleIdx_k = { componentUpdateIdx };
12231222

12241223
// Precompute a and powers
12251224
arma::vec a; // n-vector
@@ -1295,17 +1294,19 @@ double BVS_Sampler::MALAlogPbetas(
12951294
arma::inv(M, Delta_betaK, arma::inv_opts::allow_approx);
12961295
}
12971296

1298-
// Full posterior gradient (sum-likelihood − prior): d-vector
1299-
arma::vec gradient_betaK = arma::sum(gradient_betaK_i0, 0).t(); // sum over i
1300-
gradient_betaK -= betas_(1 + updateIdx0, singleIdx_k) / tauSqK;
1301-
13021297
// MALA mean and covariance
1303-
// double eps = sigmaMH_beta;
1304-
arma::vec m = 0.5 * eps * eps * (M * gradient_betaK);
1305-
arma::mat Sigma = eps * eps * M;
1298+
// Prior gradient at proposed (reverse direction)
1299+
arma::vec gradient_betaK = arma::sum(gradient_betaK_i0, 0).t();
1300+
gradient_betaK -= proposedBeta(1 + updateIdx0, singleIdx_k) / tauSqK;
1301+
1302+
// MALA mean at proposed
1303+
arma::vec m = 0.5 * eps * eps * (M * gradient_betaK);
1304+
arma::mat Sigma = eps * eps * M;
1305+
arma::vec mu_rev = proposedBeta(1 + updateIdx0, singleIdx_k) + m;
1306+
1307+
// Reverse log-proposal density q(current | proposed)
1308+
double logP = logPDFNormal(betas_(1 + updateIdx0, singleIdx_k), mu_rev, Sigma);
13061309

1307-
// Compute proposal density log q(betas| proposedBeta)
1308-
double logP = logPDFNormal(betas_(1 + updateIdx0, singleIdx_k), m, Sigma);
13091310

13101311
return logP;
13111312
}
@@ -1316,7 +1317,6 @@ double BVS_Sampler::MALAzetas(
13161317
arma::mat& proposedZeta,
13171318
const arma::mat& zetas_,
13181319
const arma::uvec& updateIdx0,
1319-
const arma::uvec& singleIdx_k,
13201320
unsigned int componentUpdateIdx,
13211321

13221322
const arma::vec& datTheta,
@@ -1332,6 +1332,7 @@ double BVS_Sampler::MALAzetas(
13321332
unsigned int n = dataclass.datX.n_rows;
13331333
unsigned int p = dataclass.datX.n_cols;
13341334
unsigned int L = dataclass.datX.n_slices;
1335+
arma::uvec singleIdx_k = { componentUpdateIdx };
13351336

13361337
// Compute alphas: n x L
13371338
arma::mat alphas = arma::zeros<arma::mat>(n, L);
@@ -1418,7 +1419,9 @@ double BVS_Sampler::MALAzetas(
14181419
// Update proposed parameters on the selected coordinates
14191420
proposedZeta(1 + updateIdx0, singleIdx_k) += u;
14201421

1421-
double logP = logPDFNormal(proposedZeta(1 + updateIdx0, singleIdx_k), m, Sigma);
1422+
// forward log-density: mean = current + drift
1423+
arma::vec mu_fwd = zetas_(1 + updateIdx0, singleIdx_k) + m;
1424+
double logP = logPDFNormal(proposedZeta(1 + updateIdx0, singleIdx_k), mu_fwd, Sigma);
14221425

14231426
return logP;
14241427
}
@@ -1428,7 +1431,6 @@ double BVS_Sampler::MALAlogPzetas(
14281431
const arma::mat& zetas_,
14291432
const arma::mat& proposedZeta,
14301433
const arma::uvec& updateIdx0,
1431-
const arma::uvec& singleIdx_k,
14321434
unsigned int componentUpdateIdx,
14331435

14341436
const arma::vec& datTheta,
@@ -1444,6 +1446,7 @@ double BVS_Sampler::MALAlogPzetas(
14441446
unsigned int n = dataclass.datX.n_rows;
14451447
unsigned int p = dataclass.datX.n_cols;
14461448
unsigned int L = dataclass.datX.n_slices;
1449+
arma::uvec singleIdx_k = { componentUpdateIdx };
14471450

14481451
// Compute alphas: n x L
14491452
arma::mat alphas = arma::zeros<arma::mat>(n, L);
@@ -1524,8 +1527,10 @@ double BVS_Sampler::MALAlogPzetas(
15241527
arma::vec m = 0.5 * eps * eps * (M * gradient_zetaK);
15251528
arma::mat Sigma = eps * eps * M;
15261529

1527-
// Compute proposal density log q(zetas| proposedZeta)
1528-
double logP = logPDFNormal(zetas_(1 + updateIdx0, singleIdx_k), m, Sigma);
1530+
arma::vec mu_rev = proposedZeta(1 + updateIdx0, singleIdx_k) + m;
1531+
1532+
// reverse log-density: q(current | proposed)
1533+
double logP = logPDFNormal(zetas_(1 + updateIdx0, singleIdx_k), mu_rev, Sigma);
15291534

15301535
return logP;
15311536
}

src/BVS.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,6 @@ class BVS_Sampler
216216
arma::mat& proposedBeta,
217217
const arma::mat& betas_,
218218
const arma::uvec& updateIdx0,
219-
const arma::uvec& singleIdx_k,
220219
unsigned int componentUpdateIdx,
221220

222221
const arma::vec& datTheta,
@@ -233,7 +232,6 @@ class BVS_Sampler
233232
const arma::mat& betas_,
234233
const arma::mat& proposedBeta,
235234
const arma::uvec& updateIdx0,
236-
const arma::uvec& singleIdx_k,
237235
unsigned int componentUpdateIdx,
238236

239237
const arma::vec& datTheta,
@@ -248,7 +246,6 @@ class BVS_Sampler
248246
arma::mat& proposedZeta,
249247
const arma::mat& zetas_,
250248
const arma::uvec& updateIdx0,
251-
const arma::uvec& singleIdx_k,
252249
unsigned int componentUpdateIdx,
253250

254251
const arma::vec& datTheta,
@@ -264,7 +261,6 @@ class BVS_Sampler
264261
const arma::mat& zetas_,
265262
const arma::mat& proposedZeta,
266263
const arma::uvec& updateIdx0,
267-
const arma::uvec& singleIdx_k,
268264
unsigned int componentUpdateIdx,
269265

270266
const arma::vec& datTheta,

0 commit comments

Comments
 (0)