1212#include " module_base/global_variable.h"
1313#include " module_hsolver/hsolver_pw.h"
1414#include " module_hsolver/hsolver_pw_sdft.h"
15+ #include " module_elecstate/elecstate_pw.h"
1516#undef private
1617#undef protected
1718
@@ -20,16 +21,20 @@ template <typename REAL>
2021Sto_Func<REAL>::Sto_Func()
2122{
2223}
23-
2424template class Sto_Func <double >;
2525
26- template <typename REAL>
27- StoChe<REAL>::StoChe(const int & nche, const int & method, const REAL& emax_sto, const REAL& emin_sto)
26+ template <>
27+ void elecstate::ElecStatePW<std::complex <double >, base_device::DEVICE_CPU>::init_rho_data()
28+ {
29+ }
30+
31+ template <typename REAL, typename Device>
32+ StoChe<REAL, Device>::StoChe(const int & nche, const int & method, const REAL& emax_sto, const REAL& emin_sto)
2833{
2934 this ->nche = nche;
3035}
31- template <typename REAL>
32- StoChe<REAL>::~StoChe ()
36+ template <typename REAL, typename Device >
37+ StoChe<REAL, Device >::~StoChe ()
3338{
3439}
3540
@@ -51,7 +56,7 @@ template <typename T, typename Device>
5156void Stochastic_Iter<T, Device>::init(K_Vectors* pkv_in,
5257 ModulePW::PW_Basis_K* wfc_basis,
5358 Stochastic_WF<T, Device>& stowf,
54- StoChe<double >& stoche,
59+ StoChe<Real, Device >& stoche,
5560 hamilt::HamiltSdftPW<T, Device>* p_hamilt_sto)
5661{
5762 this ->nchip = stowf.nchip ;
@@ -108,7 +113,7 @@ void Stochastic_Iter<T, Device>::calHsqrtchi(Stochastic_WF<T, Device>& stowf)
108113
109114template <typename T, typename Device>
110115void Stochastic_Iter<T, Device>::sum_stoband(Stochastic_WF<T, Device>& stowf,
111- elecstate::ElecStatePW* pes,
116+ elecstate::ElecStatePW<T, Device> * pes,
112117 hamilt::Hamilt<T, Device>* pHamilt,
113118 ModulePW::PW_Basis_K* wfc_basis)
114119{
@@ -193,7 +198,7 @@ TEST_F(TestHSolverPW_SDFT, solve)
193198 int istep = 0 ;
194199 int iter = 0 ;
195200
196- this ->hs_d .solve (&hamilt_test_d, psi_test_cd, &elecstate_test, &pwbk, stowf, istep, iter, false );
201+ this ->hs_d .solve (&hamilt_test_d, psi_test_cd, psi_test_cd, &elecstate_test, &pwbk, stowf, istep, iter, false );
197202 EXPECT_DOUBLE_EQ (hsolver::DiagoIterAssist<std::complex <double >>::avg_iter, 0.0 );
198203 EXPECT_DOUBLE_EQ (elecstate_test.ekb .c [0 ], 4.0 );
199204 EXPECT_DOUBLE_EQ (elecstate_test.ekb .c [1 ], 7.0 );
@@ -237,7 +242,7 @@ TEST_F(TestHSolverPW_SDFT, solve_noband_skipcharge)
237242 int istep = 0 ;
238243 int iter = 0 ;
239244
240- this ->hs_d .solve (&hamilt_test_d, psi_test_no, &elecstate_test, &pwbk, stowf, istep, iter, false );
245+ this ->hs_d .solve (&hamilt_test_d, psi_test_no, psi_test_no, &elecstate_test, &pwbk, stowf, istep, iter, false );
241246 EXPECT_DOUBLE_EQ (hsolver::DiagoIterAssist<std::complex <double >>::avg_iter, 0.0 );
242247 EXPECT_EQ (stowf.nbands_diag , 2 );
243248 EXPECT_EQ (stowf.nbands_total , 1 );
@@ -251,7 +256,7 @@ TEST_F(TestHSolverPW_SDFT, solve_noband_skipcharge)
251256 std::cout<<__FILE__<<__LINE__<<" "<<elecstate_test.f_en.eband<<std::endl;*/
252257
253258 // test for skip charge
254- this ->hs_d .solve (&hamilt_test_d, psi_test_no, &elecstate_test, &pwbk, stowf, istep, iter, true );
259+ this ->hs_d .solve (&hamilt_test_d, psi_test_no, psi_test_no, &elecstate_test, &pwbk, stowf, istep, iter, true );
255260 EXPECT_EQ (stowf.nbands_diag , 4 );
256261 EXPECT_EQ (stowf.nbands_total , 1 );
257262 EXPECT_EQ (stowf.nchi , 4 );
0 commit comments