Skip to content

Commit 6d4e946

Browse files
Refactor class DensityMatrix: remove the dependence on K_Vectors and the ambiguity of _nks (#5224)
* remove dependence of DensityMatrix on K_Vectors * rename _nks as _nk to avoid ambiguity * [pre-commit.ci lite] apply automatic fixes * remove map * make nk non-optional * fix after rebase --------- Co-authored-by: pre-commit-ci-lite[bot] <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com>
1 parent 590ea5e commit 6d4e946

File tree

18 files changed

+93
-119
lines changed

18 files changed

+93
-119
lines changed

source/module_elecstate/elecstate_lcao.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,8 @@ void ElecStateLCAO<double>::psiToRho(const psi::Psi<double>& psi)
131131
template <typename TK>
132132
void ElecStateLCAO<TK>::init_DM(const K_Vectors* kv, const Parallel_Orbitals* paraV, const int nspin)
133133
{
134-
this->DM = new DensityMatrix<TK, double>(kv, paraV, nspin);
134+
const int nspin_dm = nspin == 2 ? 2 : 1;
135+
this->DM = new DensityMatrix<TK, double>(paraV, nspin_dm, kv->kvec_d, kv->get_nks() / nspin_dm);
135136
}
136137

137138
template <>

source/module_elecstate/module_dm/density_matrix.cpp

Lines changed: 34 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -24,62 +24,24 @@ DensityMatrix<TK, TR>::~DensityMatrix()
2424
}
2525
}
2626

27-
// constructor for multi-k
2827
template <typename TK, typename TR>
29-
DensityMatrix<TK, TR>::DensityMatrix(const K_Vectors* kv_in, const Parallel_Orbitals* paraV_in, const int nspin)
28+
DensityMatrix<TK, TR>::DensityMatrix(const Parallel_Orbitals* paraV_in, const int nspin, const std::vector<ModuleBase::Vector3<double>>& kvec_d, const int nk)
29+
: _paraV(paraV_in), _nspin(nspin), _kvec_d(kvec_d), _nk((nk > 0 && nk <= _kvec_d.size()) ? nk : _kvec_d.size())
3030
{
3131
ModuleBase::TITLE("DensityMatrix", "DensityMatrix-MK");
32-
this->_kv = kv_in;
33-
this->_paraV = paraV_in;
34-
// set this->_nspin
35-
if (nspin == 1 || nspin == 4)
36-
{
37-
this->_nspin = 1;
38-
}
39-
else if (nspin == 2)
40-
{
41-
this->_nspin = 2;
42-
#ifdef __DEBUG
43-
assert(kv_in->get_nks() % 2 == 0);
44-
#endif
45-
}
46-
else
47-
{
48-
throw std::string("nspin must be 1, 2 or 4");
49-
}
50-
// set this->_nks, which is real number of k-points
51-
this->_nks = kv_in->get_nks() / this->_nspin;
52-
// allocate memory for _DMK
53-
this->_DMK.resize(this->_kv->get_nks());
54-
for (int ik = 0; ik < this->_kv->get_nks(); ik++)
32+
const int nks = _nk * _nspin;
33+
this->_DMK.resize(nks);
34+
for (int ik = 0; ik < nks; ik++)
5535
{
5636
this->_DMK[ik].resize(this->_paraV->get_row_size() * this->_paraV->get_col_size());
5737
}
5838
ModuleBase::Memory::record("DensityMatrix::DMK", this->_DMK.size() * this->_DMK[0].size() * sizeof(TK));
5939
}
6040

61-
// constructor for Gamma-Only
6241
template <typename TK, typename TR>
63-
DensityMatrix<TK, TR>::DensityMatrix(const Parallel_Orbitals* paraV_in, const int nspin)
42+
DensityMatrix<TK, TR>::DensityMatrix(const Parallel_Orbitals* paraV_in, const int nspin) :_paraV(paraV_in), _nspin(nspin), _kvec_d({ ModuleBase::Vector3<double>(0,0,0) }), _nk(1)
6443
{
6544
ModuleBase::TITLE("DensityMatrix", "DensityMatrix-GO");
66-
this->_paraV = paraV_in;
67-
// set this->_nspin
68-
if (nspin == 1 || nspin == 4)
69-
{
70-
this->_nspin = 1;
71-
}
72-
else if (nspin == 2)
73-
{
74-
this->_nspin = 2;
75-
}
76-
else
77-
{
78-
throw std::string("nspin must be 1, 2 or 4");
79-
}
80-
// set this->_nks, which is real number of k-points
81-
this->_nks = 1;
82-
// allocate memory for _DMK
8345
this->_DMK.resize(_nspin);
8446
for (int ik = 0; ik < this->_nspin; ik++)
8547
{
@@ -274,7 +236,7 @@ template <typename TK, typename TR>
274236
TK* DensityMatrix<TK, TR>::get_DMK_pointer(const int ik) const
275237
{
276238
#ifdef __DEBUG
277-
assert(ik < this->_nks * this->_nspin);
239+
assert(ik < this->_nk * this->_nspin);
278240
#endif
279241
return const_cast<TK*>(this->_DMK[ik].data());
280242
}
@@ -284,7 +246,7 @@ template <typename TK, typename TR>
284246
void DensityMatrix<TK, TR>::set_DMK_pointer(const int ik, TK* DMK_in)
285247
{
286248
#ifdef __DEBUG
287-
assert(ik < this->_nks * this->_nspin);
249+
assert(ik < this->_nk * this->_nspin);
288250
#endif
289251
this->_DMK[ik].assign(DMK_in, DMK_in + this->_paraV->nrow * this->_paraV->ncol);
290252
}
@@ -295,17 +257,17 @@ void DensityMatrix<TK, TR>::set_DMK(const int ispin, const int ik, const int i,
295257
{
296258
#ifdef __DEBUG
297259
assert(ispin > 0 && ispin <= this->_nspin);
298-
assert(ik >= 0 && ik < this->_nks);
260+
assert(ik >= 0 && ik < this->_nk);
299261
#endif
300262
// consider transpose col=>row
301-
this->_DMK[ik + this->_nks * (ispin - 1)][i * this->_paraV->nrow + j] = value;
263+
this->_DMK[ik + this->_nk * (ispin - 1)][i * this->_paraV->nrow + j] = value;
302264
}
303265

304266
// set _DMK element
305267
template <typename TK, typename TR>
306268
void DensityMatrix<TK, TR>::set_DMK_zero()
307269
{
308-
for (int ik = 0; ik < _nspin * _nks; ik++)
270+
for (int ik = 0; ik < _nspin * _nk; ik++)
309271
{
310272
ModuleBase::GlobalFunc::ZEROS(this->_DMK[ik].data(),
311273
this->_paraV->get_row_size() * this->_paraV->get_col_size());
@@ -320,18 +282,17 @@ TK DensityMatrix<TK, TR>::get_DMK(const int ispin, const int ik, const int i, co
320282
assert(ispin > 0 && ispin <= this->_nspin);
321283
#endif
322284
// consider transpose col=>row
323-
return this->_DMK[ik + this->_nks * (ispin - 1)][i * this->_paraV->nrow + j];
285+
return this->_DMK[ik + this->_nk * (ispin - 1)][i * this->_paraV->nrow + j];
324286
}
325287

326288
// get _DMK nks, nrow, ncol
327289
template <typename TK, typename TR>
328290
int DensityMatrix<TK, TR>::get_DMK_nks() const
329291
{
330292
#ifdef __DEBUG
331-
assert(this->_DMK.size() != 0);
332-
assert(this->_kv != nullptr);
293+
assert(this->_DMK.size() == _nk * _nspin);
333294
#endif
334-
return this->_kv->get_nks();
295+
return _nk * _nspin;
335296
}
336297

337298
template <typename TK, typename TR>
@@ -403,7 +364,7 @@ void DensityMatrix<TK, TR>::cal_DMR_test()
403364
{
404365
for (int is = 1; is <= this->_nspin; ++is)
405366
{
406-
int ik_begin = this->_nks * (is - 1); // jump this->_nks for spin_down if nspin==2
367+
int ik_begin = this->_nk * (is - 1); // jump this->_nk for spin_down if nspin==2
407368
hamilt::HContainer<TR>* tmp_DMR = this->_DMR[is - 1];
408369
// set zero since this function is called in every scf step
409370
tmp_DMR->set_zero();
@@ -435,12 +396,12 @@ void DensityMatrix<TK, TR>::cal_DMR_test()
435396
#endif
436397
std::complex<TR> tmp_res;
437398
// loop over k-points
438-
for (int ik = 0; ik < this->_nks; ++ik)
399+
for (int ik = 0; ik < this->_nk; ++ik)
439400
{
440401
// cal k_phase
441402
// if TK==std::complex<double>, kphase is e^{ikR}
442403
const ModuleBase::Vector3<double> dR(r_index[0], r_index[1], r_index[2]);
443-
const double arg = (this->_kv->kvec_d[ik] * dR) * ModuleBase::TWO_PI;
404+
const double arg = (this->_kvec_d[ik] * dR) * ModuleBase::TWO_PI;
444405
double sinp, cosp;
445406
ModuleBase::libm::sincos(arg, &sinp, &cosp);
446407
std::complex<double> kphase = std::complex<double>(cosp, sinp);
@@ -477,7 +438,7 @@ void DensityMatrix<std::complex<double>, double>::cal_DMR()
477438
int ld_hk2 = 2 * ld_hk;
478439
for (int is = 1; is <= this->_nspin; ++is)
479440
{
480-
int ik_begin = this->_nks * (is - 1); // jump this->_nks for spin_down if nspin==2
441+
int ik_begin = this->_nk * (is - 1); // jump this->_nk for spin_down if nspin==2
481442
hamilt::HContainer<double>* tmp_DMR = this->_DMR[is - 1];
482443
// set zero since this function is called in every scf step
483444
tmp_DMR->set_zero();
@@ -515,12 +476,12 @@ void DensityMatrix<std::complex<double>, double>::cal_DMR()
515476
// loop over k-points
516477
if (PARAM.inp.nspin != 4)
517478
{
518-
for (int ik = 0; ik < this->_nks; ++ik)
479+
for (int ik = 0; ik < this->_nk; ++ik)
519480
{
520481
// cal k_phase
521482
// if TK==std::complex<double>, kphase is e^{ikR}
522483
const ModuleBase::Vector3<double> dR(r_index[0], r_index[1], r_index[2]);
523-
const double arg = (this->_kv->kvec_d[ik] * dR) * ModuleBase::TWO_PI;
484+
const double arg = (this->_kvec_d[ik] * dR) * ModuleBase::TWO_PI;
524485
double sinp, cosp;
525486
ModuleBase::libm::sincos(arg, &sinp, &cosp);
526487
std::complex<double> kphase = std::complex<double>(cosp, sinp);
@@ -559,12 +520,12 @@ void DensityMatrix<std::complex<double>, double>::cal_DMR()
559520
if (PARAM.inp.nspin == 4)
560521
{
561522
tmp_DMR.assign(tmp_ap.get_size(), std::complex<double>(0.0, 0.0));
562-
for (int ik = 0; ik < this->_nks; ++ik)
523+
for (int ik = 0; ik < this->_nk; ++ik)
563524
{
564525
// cal k_phase
565526
// if TK==std::complex<double>, kphase is e^{ikR}
566527
const ModuleBase::Vector3<double> dR(r_index[0], r_index[1], r_index[2]);
567-
const double arg = (this->_kv->kvec_d[ik] * dR) * ModuleBase::TWO_PI;
528+
const double arg = (this->_kvec_d[ik] * dR) * ModuleBase::TWO_PI;
568529
double sinp, cosp;
569530
ModuleBase::libm::sincos(arg, &sinp, &cosp);
570531
std::complex<double> kphase = std::complex<double>(cosp, sinp);
@@ -644,7 +605,7 @@ void DensityMatrix<std::complex<double>, double>::cal_DMR(const int ik)
644605
int ld_hk2 = 2 * ld_hk;
645606
for (int is = 1; is <= this->_nspin; ++is)
646607
{
647-
int ik_begin = this->_nks * (is - 1); // jump this->_nks for spin_down if nspin==2
608+
int ik_begin = this->_nk * (is - 1); // jump this->_nk for spin_down if nspin==2
648609
hamilt::HContainer<double>* tmp_DMR = this->_DMR[is - 1];
649610
// set zero since this function is called in every scf step
650611
tmp_DMR->set_zero();
@@ -680,7 +641,7 @@ void DensityMatrix<std::complex<double>, double>::cal_DMR(const int ik)
680641
// cal k_phase
681642
// if TK==std::complex<double>, kphase is e^{ikR}
682643
const ModuleBase::Vector3<double> dR(r_index[0], r_index[1], r_index[2]);
683-
const double arg = (this->_kv->kvec_d[ik] * dR) * ModuleBase::TWO_PI;
644+
const double arg = (this->_kvec_d[ik] * dR) * ModuleBase::TWO_PI;
684645
double sinp, cosp;
685646
ModuleBase::libm::sincos(arg, &sinp, &cosp);
686647
std::complex<double> kphase = std::complex<double>(cosp, sinp);
@@ -734,14 +695,14 @@ void DensityMatrix<double, double>::cal_DMR()
734695
int ld_hk = this->_paraV->nrow;
735696
for (int is = 1; is <= this->_nspin; ++is)
736697
{
737-
int ik_begin = this->_nks * (is - 1); // jump this->_nks for spin_down if nspin==2
698+
int ik_begin = this->_nk * (is - 1); // jump this->_nk for spin_down if nspin==2
738699
hamilt::HContainer<double>* tmp_DMR = this->_DMR[is - 1];
739700
// set zero since this function is called in every scf step
740701
tmp_DMR->set_zero();
741702

742703
#ifdef __DEBUG
743704
// assert(tmp_DMR->is_gamma_only() == true);
744-
assert(this->_nks == 1);
705+
assert(this->_nk == 1);
745706
#endif
746707
#ifdef _OPENMP
747708
#pragma omp parallel for
@@ -857,9 +818,9 @@ void DensityMatrix<TK, TR>::read_DMK(const std::string directory, const int ispi
857818
// quit the program or not.
858819
bool quit = false;
859820

860-
ModuleBase::CHECK_DOUBLE(ifs, this->_kv->kvec_d[ik].x, quit);
861-
ModuleBase::CHECK_DOUBLE(ifs, this->_kv->kvec_d[ik].y, quit);
862-
ModuleBase::CHECK_DOUBLE(ifs, this->_kv->kvec_d[ik].z, quit);
821+
ModuleBase::CHECK_DOUBLE(ifs, this->_kvec_d[ik].x, quit);
822+
ModuleBase::CHECK_DOUBLE(ifs, this->_kvec_d[ik].y, quit);
823+
ModuleBase::CHECK_DOUBLE(ifs, this->_kvec_d[ik].z, quit);
863824
ModuleBase::CHECK_INT(ifs, this->_paraV->nrow);
864825
ModuleBase::CHECK_INT(ifs, this->_paraV->ncol);
865826
} // If file exist, read in data.
@@ -869,7 +830,7 @@ void DensityMatrix<TK, TR>::read_DMK(const std::string directory, const int ispi
869830
{
870831
for (int j = 0; j < this->_paraV->ncol; ++j)
871832
{
872-
ifs >> this->_DMK[ik + this->_nks * (ispin - 1)][i * this->_paraV->ncol + j];
833+
ifs >> this->_DMK[ik + this->_nk * (ispin - 1)][i * this->_paraV->ncol + j];
873834
}
874835
}
875836
ifs.close();
@@ -892,7 +853,7 @@ void DensityMatrix<double, double>::write_DMK(const std::string directory, const
892853
{
893854
ModuleBase::WARNING("elecstate::write_dmk", "Can't create DENSITY MATRIX File!");
894855
}
895-
ofs << this->_kv->kvec_d[ik].x << " " << this->_kv->kvec_d[ik].y << " " << this->_kv->kvec_d[ik].z << std::endl;
856+
ofs << this->_kvec_d[ik].x << " " << this->_kvec_d[ik].y << " " << this->_kvec_d[ik].z << std::endl;
896857
ofs << "\n " << this->_paraV->nrow << " " << this->_paraV->ncol << std::endl;
897858

898859
ofs << std::setprecision(3);
@@ -906,7 +867,7 @@ void DensityMatrix<double, double>::write_DMK(const std::string directory, const
906867
{
907868
ofs << "\n";
908869
}
909-
ofs << " " << this->_DMK[ik + this->_nks * (ispin - 1)][i * this->_paraV->ncol + j];
870+
ofs << " " << this->_DMK[ik + this->_nk * (ispin - 1)][i * this->_paraV->ncol + j];
910871
}
911872
}
912873

@@ -929,7 +890,7 @@ void DensityMatrix<std::complex<double>, double>::write_DMK(const std::string di
929890
{
930891
ModuleBase::WARNING("elecstate::write_dmk", "Can't create DENSITY MATRIX File!");
931892
}
932-
ofs << this->_kv->kvec_d[ik].x << " " << this->_kv->kvec_d[ik].y << " " << this->_kv->kvec_d[ik].z << std::endl;
893+
ofs << this->_kvec_d[ik].x << " " << this->_kvec_d[ik].y << " " << this->_kvec_d[ik].z << std::endl;
933894
ofs << "\n " << this->_paraV->nrow << " " << this->_paraV->ncol << std::endl;
934895

935896
ofs << std::setprecision(3);
@@ -943,7 +904,7 @@ void DensityMatrix<std::complex<double>, double>::write_DMK(const std::string di
943904
{
944905
ofs << "\n";
945906
}
946-
ofs << " " << this->_DMK[ik + this->_nks * (ispin - 1)][i * this->_paraV->ncol + j].real();
907+
ofs << " " << this->_DMK[ik + this->_nk * (ispin - 1)][i * this->_paraV->ncol + j].real();
947908
}
948909
}
949910

source/module_elecstate/module_dm/density_matrix.h

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
#include <string>
55

6-
#include "module_cell/klist.h"
76
#include "module_cell/module_neighbor/sltk_grid_driver.h"
87
#include "module_hamilt_lcao/hamilt_lcaodft/record_adj.h"
98
#include "module_hamilt_lcao/module_hcontainer/hcontainer.h"
@@ -45,16 +44,20 @@ class DensityMatrix
4544

4645
/**
4746
* @brief Constructor of class DensityMatrix for multi-k calculation
48-
* @param _kv pointer of K_Vectors object
4947
* @param _paraV pointer of Parallel_Orbitals object
50-
* @param nspin spin setting (1 - none spin; 2 - spin; 4 - SOC)
48+
* @param nspin number of spin of the density matrix, set by user according to global nspin
49+
* (usually {nspin_global -> nspin_dm} = {1->1, 2->2, 4->1}, but sometimes 2->1 like in LR-TDDFT)
50+
* @param kvec_d direct coordinates of kpoints
51+
* @param nk number of k-points, not always equal to K_Vectors::get_nks()/nspin_dm.
52+
* it will be set to kvec_d.size() if the value is invalid
5153
*/
52-
DensityMatrix(const K_Vectors* _kv, const Parallel_Orbitals* _paraV, const int nspin);
54+
DensityMatrix(const Parallel_Orbitals* _paraV, const int nspin, const std::vector<ModuleBase::Vector3<double>>& kvec_d, const int nk);
5355

5456
/**
5557
* @brief Constructor of class DensityMatrix for gamma-only calculation, where kvector is not required
5658
* @param _paraV pointer of Parallel_Orbitals object
57-
* @param nspin spin setting (1 - none spin; 2 - spin; 4 - SOC)
59+
* @param nspin number of spin of the density matrix, set by user according to global nspin
60+
* (usually {nspin_global -> nspin_dm} = {1->1, 2->2, 4->1}, but sometimes 2->1 like in LR-TDDFT)
5861
*/
5962
DensityMatrix(const Parallel_Orbitals* _paraV, const int nspin);
6063

@@ -169,7 +172,7 @@ class DensityMatrix
169172
*/
170173
const Parallel_Orbitals* get_paraV_pointer() const {return this->_paraV;}
171174

172-
const K_Vectors* get_kv_pointer() const {return this->_kv;}
175+
const std::vector<ModuleBase::Vector3<double>>& get_kvec_d() const { return this->_kvec_d; }
173176

174177
/**
175178
* @brief calculate density matrix DMR from dm(k) using blas::axpy
@@ -240,16 +243,16 @@ class DensityMatrix
240243

241244
/**
242245
* @brief density matrix in k space, which is a vector[ik]
243-
* DMK should be a [_nspin][_nks][i][j] matrix,
244-
* whose size is _nspin * _nks * _paraV->get_nrow() * _paraV->get_ncol()
246+
* DMK should be a [_nspin][_nk][i][j] matrix,
247+
* whose size is _nspin * _nk * _paraV->get_nrow() * _paraV->get_ncol()
245248
*/
246249
// std::vector<ModuleBase::ComplexMatrix> _DMK;
247250
std::vector<std::vector<TK>> _DMK;
248251

249252
/**
250253
* @brief K_Vectors object, which is used to get k-point information
251254
*/
252-
const K_Vectors* _kv;
255+
const std::vector<ModuleBase::Vector3<double>> _kvec_d;
253256

254257
/**
255258
* @brief Parallel_Orbitals object, which contain all information of 2D block cyclic distribution
@@ -265,10 +268,10 @@ class DensityMatrix
265268

266269
/**
267270
* @brief real number of k-points
268-
* _nks is not equal to _kv->get_nks() when spin-polarization is considered
269-
* _nks = kv->_nks / nspin
271+
* _nk is not equal to _kv->get_nks() when spin-polarization is considered
272+
* _nk = kv->get_nks() / nspin when nspin=2
270273
*/
271-
int _nks = 0;
274+
int _nk = 0;
272275

273276

274277
};

0 commit comments

Comments
 (0)