Skip to content

Commit aa52a32

Browse files
authored
Test: add test for pw ldos (#6109)
* Refactor: add class Cal_ldos * Test: add tests for pw ldos * fix complie bug * reduce size of LDOS.cube.ref
1 parent a338c25 commit aa52a32

File tree

15 files changed

+332
-36
lines changed

15 files changed

+332
-36
lines changed

source/module_esolver/esolver_ks_lcao.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include "module_hamilt_lcao/module_deltaspin/spin_constrain.h"
99
#include "module_hamilt_lcao/module_dftu/dftu.h"
1010
#include "module_io/berryphase.h"
11+
#include "module_io/cal_ldos.h"
1112
#include "module_io/cube_io.h"
1213
#include "module_io/dos_nao.h"
1314
#include "module_io/io_dmk.h"
@@ -419,6 +420,15 @@ void ESolver_KS_LCAO<TK, TR>::after_all_runners(UnitCell& ucell)
419420
this->p_hamilt);
420421
}
421422

423+
// out ldos
424+
if (PARAM.inp.out_ldos[0])
425+
{
426+
ModuleIO::Cal_ldos<TK>::cal_ldos_lcao(reinterpret_cast<elecstate::ElecStateLCAO<TK>*>(this->pelec),
427+
this->psi[0],
428+
this->Pgrid,
429+
ucell);
430+
}
431+
422432
// 6) print out exchange-correlation potential
423433
if (PARAM.inp.out_mat_xc)
424434
{

source/module_esolver/esolver_ks_pw.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -853,12 +853,13 @@ void ESolver_KS_PW<T, Device>::after_all_runners(UnitCell& ucell)
853853
}
854854

855855
// out ldos
856-
if (PARAM.inp.out_ldos)
856+
if (PARAM.inp.out_ldos[0])
857857
{
858-
ModuleIO::cal_ldos(reinterpret_cast<elecstate::ElecStatePW<std::complex<double>>*>(this->pelec),
859-
this->psi[0],
860-
this->Pgrid,
861-
ucell);
858+
ModuleIO::Cal_ldos<std::complex<double>>::cal_ldos_pw(
859+
reinterpret_cast<elecstate::ElecStatePW<std::complex<double>>*>(this->pelec),
860+
this->psi[0],
861+
this->Pgrid,
862+
ucell);
862863
}
863864

864865
//! 5) Calculate the spillage value, used to generate numerical atomic orbitals

source/module_io/cal_ldos.cpp

Lines changed: 93 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,18 @@
11
#include "cal_ldos.h"
22

33
#include "cube_io.h"
4+
#include "module_base/blas_connector.h"
5+
#include "module_base/scalapack_connector.h"
6+
7+
#include <type_traits>
48

59
namespace ModuleIO
610
{
7-
void cal_ldos(const elecstate::ElecStatePW<std::complex<double>>* pelec,
8-
const psi::Psi<std::complex<double>>& psi,
9-
const Parallel_Grid& pgrid,
10-
const UnitCell& ucell)
11+
template <typename T>
12+
void Cal_ldos<T>::cal_ldos_pw(const elecstate::ElecStatePW<std::complex<double>>* pelec,
13+
const psi::Psi<std::complex<double>>& psi,
14+
const Parallel_Grid& pgrid,
15+
const UnitCell& ucell)
1116
{
1217
// energy range for ldos (efermi as reference)
1318
const double emin = PARAM.inp.stm_bias < 0 ? PARAM.inp.stm_bias : 0;
@@ -19,13 +24,13 @@ void cal_ldos(const elecstate::ElecStatePW<std::complex<double>>* pelec,
1924
for (int ik = 0; ik < pelec->klist->get_nks(); ++ik)
2025
{
2126
psi.fix_k(ik);
22-
double efermi = pelec->eferm.get_efval(pelec->klist->isk[ik]);
27+
const double efermi = pelec->eferm.get_efval(pelec->klist->isk[ik]);
2328
int nbands = psi.get_nbands();
2429

2530
for (int ib = 0; ib < nbands; ib++)
2631
{
2732
pelec->basis->recip2real(&psi(ib, 0), wfcr.data(), ik);
28-
double eigenval = (pelec->ekb(ik, ib) - efermi) * ModuleBase::Ry_to_eV;
33+
const double eigenval = (pelec->ekb(ik, ib) - efermi) * ModuleBase::Ry_to_eV;
2934
if (eigenval >= emin && eigenval <= emax)
3035
{
3136
for (int ir = 0; ir < pelec->basis->nrxx; ir++)
@@ -38,18 +43,90 @@ void cal_ldos(const elecstate::ElecStatePW<std::complex<double>>* pelec,
3843
fn << PARAM.globalv.global_out_dir << "LDOS_" << PARAM.inp.stm_bias << "eV"
3944
<< ".cube";
4045

41-
ModuleIO::write_vdata_palgrid(pgrid, ldos.data(), 0, PARAM.inp.nspin, 0, fn.str(), 0, &ucell, 11, 0);
46+
const int precision = PARAM.inp.out_ldos[1];
47+
ModuleIO::write_vdata_palgrid(pgrid, ldos.data(), 0, PARAM.inp.nspin, 0, fn.str(), 0, &ucell, precision, 0);
4248
}
4349

4450
#ifdef __LCAO
45-
// lcao multi-k case
46-
// void cal_ldos(elecstate::ElecState* pelec, const psi::Psi<std::complex<double>>& psi, std::vector<double>& ldos)
47-
// {
48-
// }
49-
50-
// // lcao Gamma_only case
51-
// void cal_ldos(elecstate::ElecState* pelec, const psi::Psi<double>& psi, std::vector<double>& ldos)
52-
// {
53-
// }
51+
template <typename T>
52+
void Cal_ldos<T>::cal_ldos_lcao(const elecstate::ElecStateLCAO<T>* pelec,
53+
const psi::Psi<T>& psi,
54+
const Parallel_Grid& pgrid,
55+
const UnitCell& ucell)
56+
{
57+
// energy range for ldos (efermi as reference)
58+
const double emin = PARAM.inp.stm_bias < 0 ? PARAM.inp.stm_bias : 0;
59+
const double emax = PARAM.inp.stm_bias > 0 ? PARAM.inp.stm_bias : 0;
60+
61+
// calulate dm-like
62+
const int nbands_local = psi.get_nbands();
63+
const int nbasis_local = psi.get_nbasis();
64+
65+
// psi.T * wk * psi.conj()
66+
// result[ik](iw1,iw2) = \sum_{ib} psi[ik](ib,iw1).T * wk(k) * psi[ik](ib,iw2).conj()
67+
for (int ik = 0; ik < psi.get_nk(); ++ik)
68+
{
69+
psi.fix_k(ik);
70+
const double efermi = pelec->eferm.get_efval(pelec->klist->isk[ik]);
71+
72+
// T* dmk_pointer = DM.get_DMK_pointer(ik);
73+
74+
psi::Psi<T> wk_psi(1, psi.get_nbands(), psi.get_nbasis(), psi.get_nbasis(), true);
75+
const T* ppsi = psi.get_pointer();
76+
T* pwk_psi = wk_psi.get_pointer();
77+
78+
// #ifdef _OPENMP
79+
// #pragma omp parallel for schedule(static, 1024)
80+
// #endif
81+
// for (int i = 0; i < wk_psi.size(); ++i)
82+
// {
83+
// pwk_psi[i] = my_conj(ppsi[i]);
84+
// }
85+
86+
// int ib_global = 0;
87+
// for (int ib_local = 0; ib_local < nbands_local; ++ib_local)
88+
// {
89+
// while (ib_local != ParaV->global2local_col(ib_global))
90+
// {
91+
// ++ib_global;
92+
// if (ib_global >= wg.nc)
93+
// {
94+
// ModuleBase::WARNING_QUIT("cal_ldos", "please check global2local_col!");
95+
// }
96+
// }
97+
98+
// const double eigenval = (pelec->ekb(ik, ib_global) - efermi) * ModuleBase::Ry_to_eV;
99+
// if (eigenval >= emin && eigenval <= emax)
100+
// {
101+
// for (int ir = 0; ir < pelec->basis->nrxx; ir++)
102+
// ldos[ir] += pelec->klist->wk[ik] * norm(wfcr[ir]);
103+
// }
104+
105+
// double* wg_wfc_pointer = &(wk_psi(0, ib_local, 0));
106+
// BlasConnector::scal(nbasis_local, pelec->klist->wk[ik], wg_wfc_pointer, 1);
107+
// }
108+
109+
// // C++: dm(iw1,iw2) = psi(ib,iw1).T * wk_psi(ib,iw2)
110+
// #ifdef __MPI
111+
// psiMulPsiMpi(wk_psi, psi, dmk_pointer, ParaV->desc_wfc, ParaV->desc);
112+
// #else
113+
// psiMulPsi(wk_psi, psi, dmk_pointer);
114+
// #endif
115+
}
116+
}
117+
118+
double my_conj(double x)
119+
{
120+
return x;
121+
}
122+
123+
std::complex<double> my_conj(const std::complex<double>& z)
124+
{
125+
return {z.real(), -z.imag()};
126+
}
127+
54128
#endif
129+
130+
template class Cal_ldos<double>; // Gamma_only case
131+
template class Cal_ldos<std::complex<double>>; // multi-k case
55132
} // namespace elecstate

source/module_io/cal_ldos.h

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,29 @@
11
#ifndef CAL_LDOS_H
22
#define CAL_LDOS_H
33

4+
#include "module_elecstate/elecstate_lcao.h"
45
#include "module_elecstate/elecstate_pw.h"
56

67
namespace ModuleIO
78
{
9+
template <typename T>
10+
class Cal_ldos
11+
{
12+
public:
13+
Cal_ldos(){};
14+
~Cal_ldos(){};
15+
16+
static void cal_ldos_pw(const elecstate::ElecStatePW<std::complex<double>>* pelec,
17+
const psi::Psi<std::complex<double>>& psi,
18+
const Parallel_Grid& pgrid,
19+
const UnitCell& ucell);
820

9-
void cal_ldos(const elecstate::ElecStatePW<std::complex<double>>* pelec,
10-
const psi::Psi<std::complex<double>>& psi,
11-
const Parallel_Grid& pgrid,
12-
const UnitCell& ucell);
21+
static void cal_ldos_lcao(const elecstate::ElecStateLCAO<T>* pelec,
22+
const psi::Psi<T>& psi,
23+
const Parallel_Grid& pgrid,
24+
const UnitCell& ucell);
1325

26+
}; // namespace Cal_ldos
1427
} // namespace ModuleIO
1528

1629
#endif // CAL_LDOS_H

source/module_io/read_input_item_output.cpp

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,8 +145,17 @@ void ReadInput::item_output()
145145
}
146146
{
147147
Input_Item item("out_ldos");
148-
item.annotation = "output local density of states";
149-
read_sync_bool(input.out_ldos);
148+
item.annotation = "output local density of states, second parameter controls the precision";
149+
item.read_value = [](const Input_Item& item, Parameter& para) {
150+
const size_t count = item.get_size();
151+
if (count != 1 && count != 2)
152+
{
153+
ModuleBase::WARNING_QUIT("ReadInput", "out_ldos should have 1 or 2 values");
154+
}
155+
para.input.out_ldos[0] = assume_as_boolean(item.str_values[0]);
156+
para.input.out_ldos[1] = (count == 2) ? std::stoi(item.str_values[1]) : 3;
157+
};
158+
sync_intvec(input.out_ldos, 2, 0);
150159
this->add_item(item);
151160
}
152161
{

source/module_io/read_input_item_postprocess.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ void ReadInput::item_postprocess()
5858
item.annotation = "bias voltage used to calculate ldos";
5959
read_sync_double(input.stm_bias);
6060
item.check_value = [](const Input_Item& item, const Parameter& para) {
61-
if (para.input.out_ldos && para.input.stm_bias == 0.0)
61+
if (para.input.out_ldos[0] && para.input.stm_bias == 0.0)
6262
{
6363
ModuleBase::WARNING_QUIT("ReadInput", "a nonzero stm_bias is required for ldos calculation");
6464
}

source/module_io/test/read_input_ptest.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,8 @@ TEST_F(InputParaTest, ParaRead)
201201
EXPECT_EQ(param.inp.out_wfc_pw, 0);
202202
EXPECT_EQ(param.inp.out_wfc_r, 0);
203203
EXPECT_EQ(param.inp.out_dos, 0);
204-
EXPECT_EQ(param.inp.out_ldos, true);
204+
EXPECT_EQ(param.inp.out_ldos[0], 1);
205+
EXPECT_EQ(param.inp.out_ldos[1], 3);
205206
EXPECT_EQ(param.inp.out_band[0], 0);
206207
EXPECT_EQ(param.inp.out_band[1], 8);
207208
EXPECT_EQ(param.inp.out_proj_band, 0);

source/module_io/test/support/INPUT

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ out_pot 2 #output realspace potential
6565
out_wfc_pw 0 #output wave functions
6666
out_wfc_r 0 #output wave functions in realspace
6767
out_dos 0 #output energy and dos
68-
out_ldos True #output local density of states
68+
out_ldos 1 #output local density of states, second parameter controls the precision
6969
out_band 0 #output energy and band structure
7070
out_proj_band FaLse #output projected band structure
7171
restart_save f #print to disk every step for restart

source/module_io/test_serial/read_input_item_test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -393,7 +393,7 @@ TEST_F(InputTest, Item_test)
393393
}
394394
{ // stm_bias
395395
auto it = find_label("stm_bias", readinput.input_lists);
396-
param.input.out_ldos = true;
396+
param.input.out_ldos[0] = 1;
397397
param.input.stm_bias = 0.0;
398398

399399
testing::internal::CaptureStdout();

source/module_parameter/input_parameter.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -363,7 +363,7 @@ struct Input_para
363363
int printe = 0; ///< Print out energy for each band for every printe step, default is scf_nmax
364364
std::vector<int> out_band = {0, 8}; ///< band calculation pengfei 2014-10-13
365365
int out_dos = 0; ///< dos calculation. mohan add 20090909
366-
bool out_ldos = false; ///< ldos calculation
366+
std::vector<int> out_ldos = {0, 3}; ///< ldos calculation
367367
bool out_mul = false; ///< qifeng add 2019-9-10
368368
bool out_proj_band = false; ///< projected band structure calculation jiyy add 2022-05-11
369369
std::string out_level = "ie"; ///< control the output information.

0 commit comments

Comments
 (0)