Skip to content

Commit eda2613

Browse files
authored
Test: add UT for hsolver_pw_sdft (#2011)
1 parent 1d6acbd commit eda2613

File tree

6 files changed

+342
-36
lines changed

6 files changed

+342
-36
lines changed

source/module_esolver/esolver_sdft_pw.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,21 @@ void ESolver_SDFT_PW::hamilt2density(int istep, int iter, double ethr)
139139
hsolver::DiagoIterAssist<double>::PW_DIAG_THR = ethr;
140140
hsolver::DiagoIterAssist<double>::PW_DIAG_NMAX = GlobalV::PW_DIAG_NMAX;
141141
this->phsol->solve(this->p_hamilt, this->psi[0], this->pelec,this->stowf, istep, iter, GlobalV::KS_SOLVER);
142+
if(GlobalV::MY_STOGROUP==0)
143+
{
144+
Symmetry_rho srho;
145+
for(int is=0; is < GlobalV::NSPIN; is++)
146+
{
147+
srho.begin(is, *(this->pelec->charge), GlobalC::rhopw, GlobalC::Pgrid, GlobalC::symm);
148+
}
149+
GlobalC::en.deband = GlobalC::en.delta_e(this->pelec);
150+
}
151+
else
152+
{
153+
#ifdef __MPI
154+
if(ModuleSymmetry::Symmetry::symm_flag == 1) MPI_Barrier(MPI_COMM_WORLD);
155+
#endif
156+
}
142157
// transform energy for print
143158
GlobalC::en.eband = this->pelec->eband;
144159
GlobalC::en.demet = this->pelec->demet;

source/module_hsolver/hsolver_pw_sdft.cpp

Lines changed: 2 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,6 @@
55
#include "module_base/timer.h"
66
#include "module_base/tool_title.h"
77
#include <algorithm>
8-
9-
//temporary
10-
#include "module_hamilt_pw/hamilt_pwdft/global.h"
118
namespace hsolver
129
{
1310
void HSolverPW_SDFT::solve(hamilt::Hamilt<double>* pHamilt,
@@ -58,13 +55,6 @@ namespace hsolver
5855
stoiter.orthog(ik,psi,stowf);
5956
stoiter.checkemm(ik,istep, iter, stowf); //check and reset emax & emin
6057
}
61-
// DiagoCG would keep 9*nbasis memory in cache during loop-k
62-
// it should be deleted before calculating charge
63-
if(this->method == "cg")
64-
{
65-
delete pdiagh;
66-
pdiagh = nullptr;
67-
}
6858

6959
this->endDiagh();
7060

@@ -96,7 +86,7 @@ namespace hsolver
9686
{
9787
for(int is=0; is < GlobalV::NSPIN; is++)
9888
{
99-
ModuleBase::GlobalFunc::ZEROS(pes->charge->rho[is], GlobalC::rhopw->nrxx);
89+
ModuleBase::GlobalFunc::ZEROS(pes->charge->rho[is], pes->charge->nrxx);
10090
}
10191
}
10292
// calculate stochastic rho
@@ -108,25 +98,7 @@ namespace hsolver
10898
// mohan add 2009-01-23
10999
//en.calculate_harris();
110100

111-
if(GlobalV::MY_STOGROUP==0)
112-
{
113-
Symmetry_rho srho;
114-
for(int is=0; is < GlobalV::NSPIN; is++)
115-
{
116-
srho.begin(is, *(pes->charge), GlobalC::rhopw, GlobalC::Pgrid, GlobalC::symm);
117-
}
118-
}
119-
else
120-
{
121-
#ifdef __MPI
122-
if(ModuleSymmetry::Symmetry::symm_flag == 1) MPI_Barrier(MPI_COMM_WORLD);
123-
#endif
124-
}
125-
126-
if(GlobalV::MY_STOGROUP == 0)
127-
{
128-
GlobalC::en.deband = GlobalC::en.delta_e(pes);
129-
}
101+
//will do rho symmetry and energy calculation in esolver
130102
ModuleBase::timer::tick(this->classname, "solve");
131103
return;
132104
}

source/module_hsolver/test/CMakeLists.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,12 @@ AddTest(
2929
SOURCES test_hsolver_pw.cpp ../hsolver_pw.cpp
3030
)
3131

32+
AddTest(
33+
TARGET HSolver_sdft
34+
LIBS ${math_libs} psi device base
35+
SOURCES test_hsolver_sdft.cpp ../hsolver_pw_sdft.cpp ../hsolver_pw.cpp
36+
)
37+
3238
if(ENABLE_LCAO)
3339
if(USE_ELPA)
3440
AddTest(

source/module_hsolver/test/hsolver_pw_sup.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -220,4 +220,7 @@ void diago_PAO_in_pw_k2(const psi::DEVICE_CPU* ctx, const int &ik, psi::Psi<std:
220220
}
221221
}
222222

223-
}
223+
}//namespace hsolver
224+
225+
template class hsolver::HSolverPW<float, psi::DEVICE_CPU>;
226+
template class hsolver::HSolverPW<double, psi::DEVICE_CPU>;

source/module_hsolver/test/test_hsolver_pw.cpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,8 @@
1111

1212
#include "module_base/global_variable.h"
1313

14-
template class hsolver::HSolverPW<float, psi::DEVICE_CPU>;
15-
template class hsolver::HSolverPW<double, psi::DEVICE_CPU>;
16-
1714
/************************************************
18-
* unit test of HSolver base class
15+
* unit test of HSolverPW class
1916
***********************************************/
2017

2118
/**
@@ -85,7 +82,7 @@ TEST_F(TestHSolverPW, solve)
8582
EXPECT_DOUBLE_EQ(hsolver::DiagoIterAssist<double>::avg_iter, 0.0);
8683
for(int i=0;i<psi_test_cd.size();i++)
8784
{
88-
EXPECT_DOUBLE_EQ(psi_test_cf.get_pointer()[i].real(), i+3);
85+
EXPECT_DOUBLE_EQ(psi_test_cd.get_pointer()[i].real(), i+3);
8986
}
9087
EXPECT_DOUBLE_EQ(elecstate_test.ekb.c[0], 4.0);
9188
EXPECT_DOUBLE_EQ(elecstate_test.ekb.c[1], 7.0);

0 commit comments

Comments
 (0)