Skip to content

Commit 5ae8987

Browse files
committed
optimize BVS_dMVP::sampleZ()
1 parent 7d87baa commit 5ae8987

File tree

2 files changed

+34
-25
lines changed

2 files changed

+34
-25
lines changed

src/BVS_dMVP.cpp

Lines changed: 32 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -976,12 +976,12 @@ void BVS_dMVP::sampleZ(
976976

977977
// arma::mat Psi = arma::zeros<arma::mat>(L, L);
978978
// updatePsi( SigmaRho, Psi );
979-
arma::mat D = arma::sqrt(arma::diagmat(Psi));
979+
arma::vec dinv = 1.0 / arma::sqrt(Psi.diag());
980+
arma::mat Dinv = arma::diagmat(dinv);
980981

981-
// std::cout << "sampleZ(): D=\n" << D << "\n Psi=\n" << Psi << "\n";
982-
arma::mat Dinv = arma::inv_sympd( D );
983982
// std::cout << "...Dinv=\n" << Dinv << "\n";
984983
arma::mat RR = Dinv * Psi * Dinv;
984+
RR = 0.5 * (RR + RR.t()); // enforce symmetry numerically
985985
arma::mat Rinv;
986986
if( !arma::inv_sympd( Rinv, RR ) )
987987
{
@@ -994,38 +994,45 @@ void BVS_dMVP::sampleZ(
994994

995995
arma::mat Z = arma::zeros<arma::mat>(N, L); // reset all entries
996996
arma::uvec singleIdx_k;
997+
arma::uvec all_idx = arma::linspace<arma::uvec>(0, L-1, L);
998+
997999

9981000
for( unsigned int k=0; k<L; ++k)
9991001
{
10001002
// Z.col(k) = zbinprobit( dataclass.y.col(k), Mus.col(k) );
10011003

10021004
arma::uvec singleIdx_k = {k};
1003-
arma::uvec excludeIdx_k = arma::linspace<arma::uvec>(0, L-1, L);
1005+
arma::uvec excludeIdx_k = all_idx;
10041006
excludeIdx_k.shed_row(k);
1005-
for( unsigned int i=0; i<N; ++i)
1006-
{
1007-
// if( (dataclass.y(i, k) == 1. && Z0(i, k) > 0) || (dataclass.y(i, k) == 0. && Z0(i, k) < 0) )
1008-
// // if( (dataclass.y(i, k) == 1. && Mus(i, k) > 0) || (dataclass.y(i, k) == 0. && Mus(i, k) < 0) )
1009-
// {
1010-
arma::uvec singleIdx_i = {i};
10111007

1012-
//std::cout << "Debug sampleZ042";
1013-
arma::mat invR_excludeIdx_k;
1014-
if( !arma::inv_sympd( invR_excludeIdx_k, RR.submat(excludeIdx_k, excludeIdx_k) ) )
1015-
{
1016-
arma::inv(invR_excludeIdx_k, RR.submat(excludeIdx_k, excludeIdx_k), arma::inv_opts::allow_approx);
1008+
arma::mat Rmm = RR.submat(excludeIdx_k, excludeIdx_k);
1009+
arma::mat chol_Rmm;
1010+
if(!arma::chol(chol_Rmm, Rmm, "lower"))
1011+
{
1012+
arma::mat Rmm_jit = Rmm + 1.0e-10 * arma::eye<arma::mat>(Rmm.n_rows, Rmm.n_cols);
1013+
if (!arma::chol(chol_Rmm, Rmm_jit, "lower")) {
1014+
throw std::runtime_error("Cholesky failed for R_{-k,-k}");
10171015
}
1016+
}
1017+
1018+
arma::rowvec Rkm = RR.submat(singleIdx_k, excludeIdx_k);
1019+
arma::vec Rmk = RR.submat(excludeIdx_k, singleIdx_k);
1020+
// Solve R_{-k,-k}^{-1} R_{-k,k}
1021+
arma::vec sol2 = arma::solve(arma::trimatl(chol_Rmm), Rmk, arma::solve_opts::fast);
1022+
sol2 = arma::solve(arma::trimatu(chol_Rmm.t()), sol2, arma::solve_opts::fast);
1023+
1024+
double var_k = 1.0 - arma::as_scalar(Rkm * sol2);
1025+
var_k = std::max(1.0e-16, var_k);
1026+
double sd_k = std::sqrt(var_k);
10181027

1019-
double mu_ik = Mus(i,k) +
1020-
arma::as_scalar(
1021-
RR.submat(singleIdx_k, excludeIdx_k) * invR_excludeIdx_k *
1022-
(Z0(singleIdx_i, excludeIdx_k) - Mus(singleIdx_i, excludeIdx_k) ).t()
1023-
);
1028+
for( unsigned int i=0; i<N; ++i)
1029+
{
1030+
arma::uvec singleIdx_i = {i};
1031+
arma::vec rhs = (Z0.submat(singleIdx_i, excludeIdx_k).t() - Mus.submat(singleIdx_i, excludeIdx_k).t());
1032+
arma::vec sol = arma::solve(arma::trimatl(chol_Rmm), rhs, arma::solve_opts::fast);
1033+
sol = arma::solve(arma::trimatu(chol_Rmm.t()), sol, arma::solve_opts::fast);
10241034

1025-
double sigmaZ_ik = 1.0 - arma::as_scalar(
1026-
RR.submat(singleIdx_k, excludeIdx_k) * invR_excludeIdx_k * RR.submat(excludeIdx_k, singleIdx_k)
1027-
);
1028-
sigmaZ_ik = std::sqrt( std::max(1.0e-16, sigmaZ_ik) );
1035+
double mu_ik = Mus(i,k) + arma::as_scalar( Rkm * sol );
10291036
/*
10301037
if( dataclass.y(i, k) == 1. && mu_ik > 0.)
10311038
{
@@ -1036,7 +1043,7 @@ void BVS_dMVP::sampleZ(
10361043
Z(i, k) = BVS_subfunc::randTruncNorm( mu_ik, sigmaZ_ik, -1.0e+6, 0. );
10371044
}
10381045
*/
1039-
Z(i,k) = zbinprobit(dataclass.y(i, k), mu_ik, sigmaZ_ik);
1046+
Z(i,k) = zbinprobit(dataclass.y(i, k), mu_ik, sd_k);
10401047

10411048
// }
10421049
}

src/BVS_iMVP.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -406,6 +406,8 @@ double BVS_iMVP::gibbs_betaK(
406406
arma::uvec singleIdx_k = {k};
407407
betas(VS_IN_k, singleIdx_k) = beta_mask;
408408

409+
return logP;
410+
409411
}
410412

411413
double BVS_iMVP::logP_gibbs_betaK(

0 commit comments

Comments
 (0)