Skip to content

Commit b3ec172

Browse files
committed
add AIS proposal beta
1 parent 77ee549 commit b3ec172

File tree

2 files changed

+228
-37
lines changed

2 files changed

+228
-37
lines changed

src/BVS.cpp

Lines changed: 195 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -374,18 +374,184 @@ void BVS_Sampler::sampleGamma(
374374
proposedBeta(1 + off_prop, singleIdx_k).zeros();
375375
}
376376

377+
double logLikelihoodRatio = 0.;
378+
377379
if (rw_mh != "symmetric")
378380
{
379381
// double c = std::exp(a);
380382

381383
// Update proposal ratio with beta part
382-
if (updateIdx0.n_elem > 0) {
383-
logProposalRatio -= MALAbetas(proposedBeta, betas_, updateIdx0, componentUpdateIdx,
384-
datTheta, datProportion, weibullS, weibullLambda, kappa_, tauSq_[componentUpdateIdx], sigmaMH_beta, dataclass);
385-
}
386-
if (updateIdx0_rev.n_elem > 0) {
387-
logProposalRatio += MALAlogPbetas(betas_, proposedBeta, updateIdx0, componentUpdateIdx,
388-
datTheta, datProportion, kappa_, tauSq_[componentUpdateIdx], sigmaMH_beta, dataclass);
384+
if (rw_mh == "mala") {
385+
if (updateIdx0.n_elem > 0) {
386+
logProposalRatio -= MALAbetas(proposedBeta, betas_, updateIdx0, componentUpdateIdx,
387+
datTheta, datProportion, weibullS, weibullLambda, kappa_, tauSq_[componentUpdateIdx], sigmaMH_beta, dataclass);
388+
}
389+
if (updateIdx0_rev.n_elem > 0) {
390+
logProposalRatio += MALAlogPbetas(betas_, proposedBeta, updateIdx0, componentUpdateIdx,
391+
datTheta, datProportion, kappa_, tauSq_[componentUpdateIdx], sigmaMH_beta, dataclass);
392+
}
393+
} else { // AIS sampling for marginal likelihood ratio
394+
395+
// Annealed Importance Sampling (AIS)
396+
397+
// 1. Define Annealing Schedule
398+
unsigned int K = 101; // K+1 temperatures
399+
unsigned int M = 10; // candidate samples
400+
401+
if (updateIdx0.n_elem > 0) {
402+
// 2. Initialize Samples
403+
unsigned int J = updateIdx0.n_elem;
404+
arma::mat betaAIS = arma::zeros<arma::mat>(J, M);
405+
for(unsigned int m=0; m<M; ++m) {
406+
arma::vec u = Rcpp::rnorm(J, 0.0, tauSq_[componentUpdateIdx]);
407+
betaAIS.col(m) = u;
408+
}
409+
410+
// 3. Annealing Procedure
411+
double likelihood_marginal = 0;
412+
for(unsigned int m=0; m<M; ++m){
413+
double log_w = 0.;
414+
for(unsigned int k=0; k<K; ++k){
415+
arma::vec betaAIS_proposal = betaAIS.col(m);
416+
arma::vec u = Rcpp::rnorm(J, 0.0, 1.0);
417+
betaAIS_proposal += u;
418+
419+
double logAccProb_AIS = 0.;
420+
double t_k = (double)(k+1) / (double)(K);
421+
logAccProb_AIS += (1.-t_k)*logPDFNormal(betaAIS_proposal, tauSq_[componentUpdateIdx]);
422+
423+
arma::mat betas_tmpAIS = betas_;
424+
betas_tmpAIS(updateIdx0, singleIdx_k) = betaAIS_proposal;
425+
logAccProb_AIS += t_k*logPbetaK(componentUpdateIdx,
426+
betas_tmpAIS,
427+
tauSq_[componentUpdateIdx],
428+
kappa_,
429+
datTheta,
430+
datProportion,
431+
dataclass
432+
);
433+
434+
double logPbetaK_AIS0 = logPDFNormal(betaAIS.col(m), tauSq_[componentUpdateIdx]);
435+
logAccProb_AIS -= (1.-t_k) * logPbetaK_AIS0;
436+
betas_tmpAIS(updateIdx0, singleIdx_k) = betaAIS.col(m);
437+
double logPbetaK_tmpAIS = logPbetaK(componentUpdateIdx,
438+
betas_tmpAIS,
439+
tauSq_[componentUpdateIdx],
440+
kappa_,
441+
datTheta,
442+
datProportion,
443+
dataclass
444+
);
445+
logAccProb_AIS -= t_k * logPbetaK_tmpAIS;
446+
447+
// no need AIS MH's proposal ratio due to symmetric RW proposal
448+
449+
double t_k0 = (double)(k) / (double)(K);
450+
if( std::log(R::runif(0,1)) < logAccProb_AIS ){
451+
betaAIS.col(m) = betaAIS_proposal;
452+
453+
// 4. Compute Importance Weights
454+
log_w += (t_k - t_k0) * (logPDFNormal(betaAIS.col(m), tauSq_[componentUpdateIdx]) - logPbetaK_tmpAIS);
455+
} else {
456+
logPbetaK_tmpAIS = logPbetaK(componentUpdateIdx,
457+
betas_tmpAIS,
458+
tauSq_[componentUpdateIdx],
459+
kappa_,
460+
datTheta,
461+
datProportion,
462+
dataclass
463+
);
464+
// 4. Compute Importance Weights
465+
log_w += (t_k - t_k0) * (logPbetaK_AIS0 - logPbetaK_tmpAIS);
466+
}
467+
}
468+
469+
//5. Estimate Marginal Likelihood
470+
likelihood_marginal += std::exp(log_w);
471+
}
472+
473+
logLikelihoodRatio -= std::log(likelihood_marginal);
474+
}
475+
476+
//------------------------------------------------------
477+
// To the same as above corresponding to proposedGamma
478+
//------------------------------------------------------
479+
480+
if (updateIdx0_rev.n_elem > 0) {
481+
// 2. Initialize Samples
482+
unsigned int J = updateIdx0_rev.n_elem;
483+
arma::mat betaAIS = arma::zeros<arma::mat>(J, M);
484+
for(unsigned int m=0; m<M; ++m) {
485+
arma::vec u = Rcpp::rnorm(J, 0.0, tauSq_[componentUpdateIdx]);
486+
betaAIS.col(m) = u;
487+
}
488+
489+
// 3. Annealing Procedure
490+
double likelihood_marginal_rev = 0;
491+
for(unsigned int m=0; m<M; ++m){
492+
double log_w = 0.;
493+
for(unsigned int k=0; k<K; ++k){
494+
arma::vec betaAIS_proposal = betaAIS.col(m);
495+
arma::vec u = Rcpp::rnorm(J, 0.0, 1.0);
496+
betaAIS_proposal += u;
497+
498+
double logAccProb_AIS = 0.;
499+
double t_k = (double)(k+1) / (double)(K);
500+
logAccProb_AIS += (1.-t_k)*logPDFNormal(betaAIS_proposal, tauSq_[componentUpdateIdx]);
501+
502+
arma::mat betas_tmpAIS = betas_;
503+
betas_tmpAIS(updateIdx0_rev, singleIdx_k) = betaAIS_proposal;
504+
logAccProb_AIS += t_k*logPbetaK(componentUpdateIdx,
505+
betas_tmpAIS,
506+
tauSq_[componentUpdateIdx],
507+
kappa_,
508+
datTheta,
509+
datProportion,
510+
dataclass
511+
);
512+
513+
double logPbetaK_AIS0 = logPDFNormal(betaAIS.col(m), tauSq_[componentUpdateIdx]);
514+
logAccProb_AIS -= (1.-t_k) * logPbetaK_AIS0;
515+
betas_tmpAIS(updateIdx0_rev, singleIdx_k) = betaAIS.col(m);
516+
double logPbetaK_tmpAIS = logPbetaK(componentUpdateIdx,
517+
betas_tmpAIS,
518+
tauSq_[componentUpdateIdx],
519+
kappa_,
520+
datTheta,
521+
datProportion,
522+
dataclass
523+
);
524+
logAccProb_AIS -= t_k * logPbetaK_tmpAIS;
525+
526+
// no need AIS MH's proposal ratio due to symmetric RW proposal
527+
528+
double t_k0 = (double)(k) / (double)(K);
529+
if( std::log(R::runif(0,1)) < logAccProb_AIS ){
530+
betaAIS.col(m) = betaAIS_proposal;
531+
532+
// 4. Compute Importance Weights
533+
log_w += (t_k - t_k0) * (logPDFNormal(betaAIS.col(m), tauSq_[componentUpdateIdx]) - logPbetaK_tmpAIS);
534+
} else {
535+
logPbetaK_tmpAIS = logPbetaK(componentUpdateIdx,
536+
betas_tmpAIS,
537+
tauSq_[componentUpdateIdx],
538+
kappa_,
539+
datTheta,
540+
datProportion,
541+
dataclass
542+
);
543+
// 4. Compute Importance Weights
544+
log_w += (t_k - t_k0) * (logPbetaK_AIS0 - logPbetaK_tmpAIS);
545+
}
546+
}
547+
548+
//5. Estimate Marginal Likelihood
549+
likelihood_marginal_rev += std::exp(log_w);
550+
}
551+
552+
logLikelihoodRatio += std::log(likelihood_marginal_rev);
553+
}
554+
389555
}
390556

391557
} else {
@@ -424,21 +590,23 @@ void BVS_Sampler::sampleGamma(
424590

425591
// prior ratio of beta
426592
double logPriorBetaRatio = 0.;
427-
logPriorBetaRatio += logPDFNormal(proposedBeta(1+updateIdx,singleIdx_k), tauSq_[componentUpdateIdx]);
428-
logPriorBetaRatio -= logPDFNormal(betas_(1+updateIdx,singleIdx_k), tauSq_[componentUpdateIdx]);
429-
430593
// compute logLikelihoodRatio, i.e. proposedLikelihood - log_likelihood
431594
arma::vec proposedLikelihood = log_likelihood_;
432-
// loglikelihood( xi_, zetas_, betas_, kappa_, proportion_model, dataclass, log_likelihood_ );
433-
loglikelihood( xi_, zetas_, proposedBeta, kappa_, proportion_model, dataclass, proposedLikelihood );
595+
if (rw_mh != "ais") {
596+
logPriorBetaRatio += logPDFNormal(proposedBeta(1+updateIdx,singleIdx_k), tauSq_[componentUpdateIdx]);
597+
logPriorBetaRatio -= logPDFNormal(betas_(1+updateIdx,singleIdx_k), tauSq_[componentUpdateIdx]);
434598

435-
double logLikelihoodRatio = arma::sum(proposedLikelihood - log_likelihood_);
599+
// loglikelihood( xi_, zetas_, betas_, kappa_, proportion_model, dataclass, log_likelihood_ );
600+
loglikelihood( xi_, zetas_, proposedBeta, kappa_, proportion_model, dataclass, proposedLikelihood );
601+
602+
logLikelihoodRatio = arma::sum(proposedLikelihood - log_likelihood_);
603+
}
436604

437605
// Here we need always compute the proposal and original ratios, in particular the likelihood, since betas are updated
438606
//logProposalGammaRatio = arma::accu(proposedGammaPrior - logP_gamma);
439607
double logAccProb = logLikelihoodRatio +
440608
logPriorGammaRatio +
441-
logPriorBetaRatio +
609+
//logPriorBetaRatio +
442610
logProposalRatio;
443611
/*
444612
// std::cout << "...debug logAccProb=" << logAccProb <<
@@ -707,15 +875,18 @@ void BVS_Sampler::sampleEta(
707875
// Update proposal ratio with beta part
708876
// logProposalRatio -= logPDFNormal(proposedZeta(1 + updateIdx0, singleIdx_k), m, Sigma);
709877
// logProposalRatio += logPDFNormal(zetas_(1 + updateIdx0, singleIdx_k), m_mutant, Sigma_mutant);// TODO: use proposedZeta to repeat the above steps (wrap into a func) to obtain m_mutant & Sigma_mutant
710-
if (updateIdx0.n_elem > 0) {
711-
logProposalRatio -= MALAzetas(proposedZeta, zetas_, updateIdx0, componentUpdateIdx,
712-
datTheta, weibullS, weibullLambda, kappa_, wSq_[componentUpdateIdx], sigmaMH_zeta, dataclass);
713-
}
714-
if (updateIdx0_rev.n_elem > 0) {
715-
logProposalRatio += MALAlogPzetas(zetas_, proposedZeta, updateIdx0, componentUpdateIdx,
716-
datTheta, weibullS, weibullLambda, kappa_, wSq_[componentUpdateIdx], sigmaMH_zeta, dataclass);
878+
if (rw_mh == "mala") {
879+
if (updateIdx0.n_elem > 0) {
880+
logProposalRatio -= MALAzetas(proposedZeta, zetas_, updateIdx0, componentUpdateIdx,
881+
datTheta, weibullS, weibullLambda, kappa_, wSq_[componentUpdateIdx], sigmaMH_zeta, dataclass);
882+
}
883+
if (updateIdx0_rev.n_elem > 0) {
884+
logProposalRatio += MALAlogPzetas(zetas_, proposedZeta, updateIdx0, componentUpdateIdx,
885+
datTheta, weibullS, weibullLambda, kappa_, wSq_[componentUpdateIdx], sigmaMH_zeta, dataclass);
886+
}
887+
} else {
888+
logProposalRatio += 0.;
717889
}
718-
719890
} else {
720891
// (symmetric) random-walk Metropolis with optimal standard deviation O(d^{-1/2}, theoretically 2.38*d^{-1/2})
721892

@@ -764,9 +935,9 @@ void BVS_Sampler::sampleEta(
764935
double logLikelihoodRatio = arma::sum(proposedLikelihood - log_likelihood_);
765936

766937
// Here we need always compute the proposal and original ratios, in particular the likelihood, since betas are updated
767-
double logAccProb = logLikelihoodRatio +
938+
double logAccProb = //logLikelihoodRatio +
768939
logPriorEtaRatio +
769-
logPriorZetaRatio +
940+
//logPriorZetaRatio +
770941
logProposalRatio;
771942

772943
if( std::log(R::runif(0,1)) < logAccProb )

src/drive.cpp

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -533,6 +533,17 @@ Rcpp::List run_mcmc(
533533
weibullS
534534
); */
535535
// xi_mcmc.row(1+m) = xi.t();
536+
537+
// update likelihood
538+
BVS_Sampler::loglikelihood(
539+
xi,
540+
zetas,
541+
betas,
542+
kappa,
543+
proportion_model,
544+
dataclass,
545+
log_likelihood
546+
);
536547

537548
// update cure rate based on the new xi
538549
logTheta = dataclass.datX0 * xi;
@@ -651,6 +662,17 @@ Rcpp::List run_mcmc(
651662
dataclass
652663
); // if n>1, here it will be an average
653664
// kappa_mcmc[1+m] = kappa;
665+
666+
// update likelihood
667+
BVS_Sampler::loglikelihood(
668+
xi,
669+
zetas,
670+
betas,
671+
kappa,
672+
proportion_model,
673+
dataclass,
674+
log_likelihood
675+
);
654676

655677
// update Weibull's quantities based on the new kappa
656678
for(unsigned int l=0; l<L; ++l)
@@ -732,6 +754,17 @@ Rcpp::List run_mcmc(
732754
// update \betas' variance tauSq
733755
// hyperpar->tauSq = sampleTau(hyperpar->tauA, hyperpar->tauB, betas);
734756
// tauSq_mcmc[1+m] = hyperpar->tauSq;
757+
758+
// update likelihood
759+
BVS_Sampler::loglikelihood(
760+
xi,
761+
zetas,
762+
betas,
763+
kappa,
764+
proportion_model,
765+
dataclass,
766+
log_likelihood
767+
);
735768

736769
#ifdef _OPENMP
737770
#pragma omp parallel for
@@ -796,19 +829,6 @@ Rcpp::List run_mcmc(
796829
eta_mcmc.row(1+nIter_thin_count) = arma::vectorise(etas).t();
797830
}
798831
}
799-
else
800-
{
801-
// update likelihood
802-
BVS_Sampler::loglikelihood(
803-
xi,
804-
zetas,
805-
betas,
806-
kappa,
807-
proportion_model,
808-
dataclass,
809-
log_likelihood
810-
);
811-
}
812832
// save loglikelihoods
813833
loglikelihood_mcmc.row(1+nIter_thin_count) = log_likelihood.t();
814834

0 commit comments

Comments
 (0)