@@ -976,12 +976,12 @@ void BVS_dMVP::sampleZ(
976976
977977 // arma::mat Psi = arma::zeros<arma::mat>(L, L);
978978 // updatePsi( SigmaRho, Psi );
979- arma::mat D = arma::sqrt (arma::diagmat (Psi));
979+ arma::vec dinv = 1.0 / arma::sqrt (Psi.diag ());
980+ arma::mat Dinv = arma::diagmat (dinv);
980981
981- // std::cout << "sampleZ(): D=\n" << D << "\n Psi=\n" << Psi << "\n";
982- arma::mat Dinv = arma::inv_sympd ( D );
983982 // std::cout << "...Dinv=\n" << Dinv << "\n";
984983 arma::mat RR = Dinv * Psi * Dinv;
984+ RR = 0.5 * (RR + RR.t ()); // enforce symmetry numerically
985985 arma::mat Rinv;
986986 if ( !arma::inv_sympd ( Rinv, RR ) )
987987 {
@@ -994,38 +994,45 @@ void BVS_dMVP::sampleZ(
994994
995995 arma::mat Z = arma::zeros<arma::mat>(N, L); // reset all entries
996996 arma::uvec singleIdx_k;
997+ arma::uvec all_idx = arma::linspace<arma::uvec>(0 , L-1 , L);
998+
997999
9981000 for ( unsigned int k=0 ; k<L; ++k)
9991001 {
10001002 // Z.col(k) = zbinprobit( dataclass.y.col(k), Mus.col(k) );
10011003
10021004 arma::uvec singleIdx_k = {k};
1003- arma::uvec excludeIdx_k = arma::linspace<arma::uvec>( 0 , L- 1 , L) ;
1005+ arma::uvec excludeIdx_k = all_idx ;
10041006 excludeIdx_k.shed_row (k);
1005- for ( unsigned int i=0 ; i<N; ++i)
1006- {
1007- // if( (dataclass.y(i, k) == 1. && Z0(i, k) > 0) || (dataclass.y(i, k) == 0. && Z0(i, k) < 0) )
1008- // // if( (dataclass.y(i, k) == 1. && Mus(i, k) > 0) || (dataclass.y(i, k) == 0. && Mus(i, k) < 0) )
1009- // {
1010- arma::uvec singleIdx_i = {i};
10111007
1012- // std::cout << "Debug sampleZ042";
1013- arma::mat invR_excludeIdx_k;
1014- if ( !arma::inv_sympd ( invR_excludeIdx_k, RR.submat (excludeIdx_k, excludeIdx_k) ) )
1015- {
1016- arma::inv (invR_excludeIdx_k, RR.submat (excludeIdx_k, excludeIdx_k), arma::inv_opts::allow_approx);
1008+ arma::mat Rmm = RR.submat (excludeIdx_k, excludeIdx_k);
1009+ arma::mat chol_Rmm;
1010+ if (!arma::chol (chol_Rmm, Rmm, " lower" ))
1011+ {
1012+ arma::mat Rmm_jit = Rmm + 1.0e-10 * arma::eye<arma::mat>(Rmm.n_rows , Rmm.n_cols );
1013+ if (!arma::chol (chol_Rmm, Rmm_jit, " lower" )) {
1014+ throw std::runtime_error (" Cholesky failed for R_{-k,-k}" );
10171015 }
1016+ }
1017+
1018+ arma::rowvec Rkm = RR.submat (singleIdx_k, excludeIdx_k);
1019+ arma::vec Rmk = RR.submat (excludeIdx_k, singleIdx_k);
1020+ // Solve R_{-k,-k}^{-1} R_{-k,k}
1021+ arma::vec sol2 = arma::solve (arma::trimatl (chol_Rmm), Rmk, arma::solve_opts::fast);
1022+ sol2 = arma::solve (arma::trimatu (chol_Rmm.t ()), sol2, arma::solve_opts::fast);
1023+
1024+ double var_k = 1.0 - arma::as_scalar (Rkm * sol2);
1025+ var_k = std::max (1.0e-16 , var_k);
1026+ double sd_k = std::sqrt (var_k);
10181027
1019- double mu_ik = Mus (i,k) +
1020- arma::as_scalar (
1021- RR.submat (singleIdx_k, excludeIdx_k) * invR_excludeIdx_k *
1022- (Z0 (singleIdx_i, excludeIdx_k) - Mus (singleIdx_i, excludeIdx_k) ).t ()
1023- );
1028+ for ( unsigned int i=0 ; i<N; ++i)
1029+ {
1030+ arma::uvec singleIdx_i = {i};
1031+ arma::vec rhs = (Z0.submat (singleIdx_i, excludeIdx_k).t () - Mus.submat (singleIdx_i, excludeIdx_k).t ());
1032+ arma::vec sol = arma::solve (arma::trimatl (chol_Rmm), rhs, arma::solve_opts::fast);
1033+ sol = arma::solve (arma::trimatu (chol_Rmm.t ()), sol, arma::solve_opts::fast);
10241034
1025- double sigmaZ_ik = 1.0 - arma::as_scalar (
1026- RR.submat (singleIdx_k, excludeIdx_k) * invR_excludeIdx_k * RR.submat (excludeIdx_k, singleIdx_k)
1027- );
1028- sigmaZ_ik = std::sqrt ( std::max (1.0e-16 , sigmaZ_ik) );
1035+ double mu_ik = Mus (i,k) + arma::as_scalar ( Rkm * sol );
10291036 /*
10301037 if( dataclass.y(i, k) == 1. && mu_ik > 0.)
10311038 {
@@ -1036,7 +1043,7 @@ void BVS_dMVP::sampleZ(
10361043 Z(i, k) = BVS_subfunc::randTruncNorm( mu_ik, sigmaZ_ik, -1.0e+6, 0. );
10371044 }
10381045 */
1039- Z (i,k) = zbinprobit (dataclass.y (i, k), mu_ik, sigmaZ_ik );
1046+ Z (i,k) = zbinprobit (dataclass.y (i, k), mu_ik, sd_k );
10401047
10411048 // }
10421049 }
0 commit comments