Skip to content

Commit c3ecca2

Browse files
committed
fix compile error and add UTs for hamilt_sdft
1 parent 2947197 commit c3ecca2

File tree

3 files changed

+149
-1
lines changed

3 files changed

+149
-1
lines changed

source/module_hamilt_pw/hamilt_pwdft/kernels/test/ekinetic_op_test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ TEST_F(TestModuleHamiltEkinetic, ekinetic_pw_op_gpu)
8989
syncmem_d_h2d_op()(gpu_ctx, cpu_ctx, gk2_dev, gk2.data(), gk2.size());
9090
syncmem_cd_h2d_op()(gpu_ctx, cpu_ctx, psi_dev, psi.data(), psi.size());
9191
// ekinetic_cpu_op()(cpu_ctx, band, dim, dim, tpiba2, gk2.data(), hpsi.data(), psi.data());
92-
ekinetic_gpu_op()(gpu_ctx, band, dim, dim, tpiba2, gk2_dev, hpsi_dev, psi_dev);
92+
ekinetic_gpu_op()(gpu_ctx, band, dim, dim, false, tpiba2, gk2_dev, hpsi_dev, psi_dev);
9393
syncmem_cd_d2h_op()(cpu_ctx, gpu_ctx, hpsi.data(), hpsi_dev, hpsi.size());
9494

9595
for (int ii = 0; ii < hpsi.size(); ii++) {

source/module_hamilt_pw/hamilt_stodft/test/CMakeLists.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,10 @@ AddTest(
44
TARGET Sto_Tool_UTs
55
LIBS parameter ${math_libs} psi base device
66
SOURCES ../sto_tool.cpp test_sto_tool.cpp
7+
)
8+
9+
AddTest(
10+
TARGET Sto_Hamilt_UTs
11+
LIBS parameter ${math_libs} psi base device planewave_serial
12+
SOURCES ../hamilt_sdft_pw.cpp test_hamilt_sto.cpp ../../../module_hamilt_general/operator.cpp
713
)
Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
#include "../hamilt_sdft_pw.h"
2+
#include "module_hamilt_general/operator.h"
3+
4+
#include "gtest/gtest.h"
5+
#include <vector>
6+
7+
K_Vectors::K_Vectors(){}
8+
K_Vectors::~K_Vectors(){}
9+
elecstate::Potential::~Potential(){}
10+
void elecstate::Potential::cal_v_eff(Charge const*, UnitCell const*, ModuleBase::matrix&){}
11+
void elecstate::Potential::cal_fixed_v(double*){}
12+
13+
template <typename T, typename Device>
14+
hamilt::HamiltPW<T, Device>::HamiltPW(elecstate::Potential* pot_in, ModulePW::PW_Basis_K* wfc_basis, K_Vectors* p_kv){}
15+
template <typename T, typename Device>
16+
hamilt::HamiltPW<T, Device>::~HamiltPW(){
17+
delete this->ops;
18+
};
19+
template <typename T, typename Device>
20+
void hamilt::HamiltPW<T, Device>::updateHk(int){}
21+
template <typename T, typename Device>
22+
void hamilt::HamiltPW<T, Device>::sPsi(T const*, T*, const int, const int, const int) const{}
23+
24+
template class hamilt::HamiltPW<std::complex<double>, base_device::DEVICE_CPU>;
25+
template class hamilt::HamiltPW<std::complex<float>, base_device::DEVICE_CPU>;
26+
27+
/************************************************
28+
* unit test of hamilt_sto_pw.cpp
29+
* - Tested Functions:
30+
* - void hPsi(const T* psi_in, T* hpsi, const int& nbands)
31+
* - void hPsi_norm(const T* psi_in, T* hpsi, const int& nbands)
32+
***********************************************/
33+
34+
template <typename T, typename Device>
35+
class TestOp : public hamilt::Operator<T, Device>
36+
{
37+
public:
38+
virtual void act(const int nbands,
39+
const int nbasis,
40+
const int npol,
41+
const T* tmpsi_in,
42+
T* tmhpsi,
43+
const int ngk_ik = 0,
44+
const bool is_first_node = false) const override
45+
{
46+
for (int i = 0; i < nbands; i++)
47+
{
48+
for (int j = 0; j < nbasis; j++)
49+
{
50+
tmhpsi[i * nbasis + j] = tmpsi_in[i * nbasis + j];
51+
}
52+
}
53+
}
54+
};
55+
56+
class TestHamiltSto : public ::testing::Test
57+
{
58+
public:
59+
TestHamiltSto()
60+
{
61+
const int nbands = 1;
62+
const int nbasis = 2;
63+
const int npol = 1;
64+
// Initialize the hamilt_sto_pw
65+
pot = new elecstate::Potential();
66+
wfc_basis = new ModulePW::PW_Basis_K();
67+
wfc_basis->npwk_max = 2;
68+
p_kv = new K_Vectors();
69+
std::vector<int> ngk = {2};
70+
p_kv->ngk = ngk;
71+
hamilt_sto = new hamilt::HamiltSdftPW<std::complex<double>, base_device::DEVICE_CPU>(pot, wfc_basis, p_kv, npol, &emin, &emax);
72+
hamilt_sto->ops = new TestOp<std::complex<double>, base_device::DEVICE_CPU>();
73+
}
74+
75+
~TestHamiltSto()
76+
{
77+
delete pot;
78+
delete wfc_basis;
79+
delete p_kv;
80+
delete hamilt_sto;
81+
}
82+
83+
elecstate::Potential* pot;
84+
ModulePW::PW_Basis_K* wfc_basis;
85+
K_Vectors* p_kv;
86+
hamilt::HamiltSdftPW<std::complex<double>, base_device::DEVICE_CPU>* hamilt_sto;
87+
double emin = -2.0;
88+
double emax = 2.0;
89+
};
90+
91+
TEST_F(TestHamiltSto, hPsi)
92+
{
93+
const int nbands = 1;
94+
const int nbasis = 2;
95+
// Prepare the input psi
96+
std::vector<std::complex<double>> psi_in(nbands * nbasis);
97+
std::vector<std::complex<double>> hpsi(nbands * nbasis);
98+
99+
for (int i = 0; i < nbands; i++)
100+
{
101+
for (int j = 0; j < nbasis; j++)
102+
{
103+
psi_in[i * nbasis + j] = i + j;
104+
}
105+
}
106+
hamilt_sto->hPsi(psi_in.data(), hpsi.data(), nbands);
107+
// Check the result
108+
for (int i = 0; i < nbands; i++)
109+
{
110+
for (int j = 0; j < nbasis; j++)
111+
{
112+
EXPECT_EQ(hpsi[i * nbasis + j], psi_in[i * nbasis + j]);
113+
}
114+
}
115+
}
116+
117+
TEST_F(TestHamiltSto, hPsi_norm)
118+
{
119+
int nbands = 1;
120+
int nbasis = 2;
121+
std::vector<std::complex<double>> psi_in(nbands * nbasis);
122+
std::vector<std::complex<double>> hpsi(nbands * nbasis);
123+
124+
for (int i = 0; i < nbands; i++)
125+
{
126+
for (int j = 0; j < nbasis; j++)
127+
{
128+
psi_in[i * nbasis + j] = i + j;
129+
}
130+
}
131+
132+
hamilt_sto->hPsi_norm(psi_in.data(), hpsi.data(), nbands);
133+
134+
// Check the result
135+
for (int i = 0; i < nbands; i++)
136+
{
137+
for (int j = 0; j < nbasis; j++)
138+
{
139+
EXPECT_EQ(hpsi[i * nbasis + j], psi_in[i * nbasis + j] * 0.5);
140+
}
141+
}
142+
}

0 commit comments

Comments
 (0)