Skip to content

Commit 83fc5c0

Browse files
maki49Fisherd99
authored andcommitted
Refactor: generalize the transition density matrix in module_lr (deepmodeling#5852)
* generalize dm_trans * pass factor instead of bool renorm_k * fix bug * update UT
1 parent e746ebe commit 83fc5c0

File tree

10 files changed

+257
-162
lines changed

10 files changed

+257
-162
lines changed

source/module_lr/ao_to_mo_transformer/ao_to_mo.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,18 @@
77
#endif
88
namespace LR
99
{
10+
#ifndef MO_TYPE_H
11+
#define MO_TYPE_H
1012
enum MO_TYPE { OO, VO, VV };
13+
#endif
1114
template<typename T>
1215
void ao_to_mo_forloop_serial(
1316
const std::vector<container::Tensor>& mat_ao,
1417
const psi::Psi<T>& coeff,
1518
const int& nocc,
1619
const int& nvirt,
1720
T* const mat_mo,
18-
MO_TYPE type = VO);
21+
const MO_TYPE type = VO);
1922
template<typename T>
2023
void ao_to_mo_blas(
2124
const std::vector<container::Tensor>& mat_ao,
@@ -24,7 +27,7 @@ namespace LR
2427
const int& nvirt,
2528
T* const mat_mo,
2629
const bool add_on = true,
27-
MO_TYPE type = VO);
30+
const MO_TYPE type = VO);
2831
#ifdef __MPI
2932
template<typename T>
3033
void ao_to_mo_pblas(
@@ -38,6 +41,6 @@ namespace LR
3841
const Parallel_2D& pmat_mo,
3942
T* const mat_mo,
4043
const bool add_on = true,
41-
MO_TYPE type = VO);
44+
const MO_TYPE type = VO);
4245
#endif
4346
}

source/module_lr/ao_to_mo_transformer/ao_to_mo_parallel.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ namespace LR
2121
const Parallel_2D& pmat_mo,
2222
double* mat_mo,
2323
const bool add_on,
24-
MO_TYPE type)
24+
const MO_TYPE type)
2525
{
2626
ModuleBase::TITLE("hamilt_lrtd", "ao_to_mo_pblas");
2727
assert(pmat_ao.comm() == pcoeff.comm() && pmat_ao.comm() == pmat_mo.comm());
@@ -79,7 +79,7 @@ namespace LR
7979
const Parallel_2D& pmat_mo,
8080
std::complex<double>* const mat_mo,
8181
const bool add_on,
82-
MO_TYPE type)
82+
const MO_TYPE type)
8383
{
8484
ModuleBase::TITLE("hamilt_lrtd", "cal_AX_plas");
8585
assert(pmat_ao.comm() == pcoeff.comm() && pmat_ao.comm() == pmat_mo.comm());

source/module_lr/ao_to_mo_transformer/ao_to_mo_serial.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ namespace LR
1111
const int& nocc,
1212
const int& nvirt,
1313
double* mat_mo,
14-
MO_TYPE type)
14+
const MO_TYPE type)
1515
{
1616
ModuleBase::TITLE("hamilt_lrtd", "ao_to_mo_forloop_serial");
1717
const int nks = mat_ao.size();
@@ -49,7 +49,7 @@ namespace LR
4949
const int& nocc,
5050
const int& nvirt,
5151
std::complex<double>* const mat_mo,
52-
MO_TYPE type)
52+
const MO_TYPE type)
5353
{
5454
ModuleBase::TITLE("hamilt_lrtd", "ao_to_mo_forloop_serial");
5555
const int nks = mat_ao.size();
@@ -88,7 +88,7 @@ namespace LR
8888
const int& nvirt,
8989
double* mat_mo,
9090
const bool add_on,
91-
MO_TYPE type)
91+
const MO_TYPE type)
9292
{
9393
ModuleBase::TITLE("hamilt_lrtd", "ao_to_mo_blas");
9494
const int nks = mat_ao.size();
@@ -129,7 +129,7 @@ namespace LR
129129
const int& nvirt,
130130
std::complex<double>* const mat_mo,
131131
const bool add_on,
132-
MO_TYPE type)
132+
const MO_TYPE type)
133133
{
134134
ModuleBase::TITLE("hamilt_lrtd", "ao_to_mo_blas");
135135
const int nks = mat_ao.size();

source/module_lr/dm_trans/dm_trans.h

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,12 @@
88
#endif
99
namespace LR
1010
{
11-
// use templates in the future.
11+
12+
#ifndef MO_TYPE_H
13+
#define MO_TYPE_H
14+
enum MO_TYPE { OO, VO, VV };
15+
#endif
16+
1217
#ifdef __MPI
1318
/// @brief calculate the 2d-block transition density matrix in AO basis using p?gemm
1419
/// \f[ \tilde{\rho}_{\mu_j\mu_b}=\sum_{jb}c_{j,\mu_j}X_{jb}c^*_{b,\mu_b} \f]
@@ -22,8 +27,8 @@ namespace LR
2227
const int nocc,
2328
const int nvirt,
2429
const Parallel_2D& pmat,
25-
const bool renorm_k = true,
26-
const int nspin = 1);
30+
const T factor = (T)1.0,
31+
const MO_TYPE type = MO_TYPE::VO);
2732
#endif
2833

2934
/// @brief calculate the 2d-block transition density matrix in AO basis using ?gemm
@@ -32,8 +37,8 @@ namespace LR
3237
const T* const X_istate,
3338
const psi::Psi<T>& c,
3439
const int& nocc, const int& nvirt,
35-
const bool renorm_k = true,
36-
const int nspin = 1);
40+
const T factor = (T)1.0,
41+
const MO_TYPE type = MO_TYPE::VO);
3742

3843
// for test
3944
/// @brief calculate the 2d-block transition density matrix in AO basis using for loop (for test)
@@ -42,6 +47,6 @@ namespace LR
4247
const T* const X_istate,
4348
const psi::Psi<T>& c,
4449
const int& nocc, const int& nvirt,
45-
const bool renorm_k = true,
46-
const int nspin = 1);
50+
const T factor = (T)1.0,
51+
const MO_TYPE type = MO_TYPE::VO);
4752
}

source/module_lr/dm_trans/dm_trans_parallel.cpp

Lines changed: 29 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -18,44 +18,49 @@ std::vector<container::Tensor> cal_dm_trans_pblas(const double* const X_istate,
1818
const int nocc,
1919
const int nvirt,
2020
const Parallel_2D& pmat,
21-
const bool renorm_k,
22-
const int nspin)
21+
const double factor,
22+
const MO_TYPE type)
2323
{
2424
ModuleBase::TITLE("hamilt_lrtd", "cal_dm_trans_pblas");
2525
assert(px.comm() == pc.comm() && px.comm() == pmat.comm());
2626
assert(px.blacs_ctxt == pc.blacs_ctxt && px.blacs_ctxt == pmat.blacs_ctxt);
2727
assert(pmat.get_local_size() > 0);
2828

2929
const int nks = c.get_nk();
30+
const int i1 = 1;
31+
const int ivirt = nocc + 1;
32+
const int nmo1 = type == MO_TYPE::VV ? nvirt : nocc;
33+
const int nmo2 = type == MO_TYPE::OO ? nocc : nvirt;
34+
const int imo1 = type == MO_TYPE::VV ? ivirt : i1;
35+
const int imo2 = type == MO_TYPE::OO ? i1 : ivirt;
3036

3137
std::vector<container::Tensor> dm_trans(nks,
3238
container::Tensor(DAT::DT_DOUBLE, DEV::CpuDevice, { pmat.get_col_size(), pmat.get_row_size() }));
3339
for (int isk = 0; isk < nks; ++isk)
3440
{
3541
c.fix_k(isk);
3642
const int x_start = isk * px.get_local_size();
37-
int i1 = 1;
38-
int ivirt = nocc + 1;
43+
3944
char transa = 'N';
4045
char transb = 'T';
4146
const double alpha = 1.0;
4247
const double beta = 0;
4348

4449
// 1. [X*C_occ^T]^T=C_occ*X^T
4550
Parallel_2D pXc; // nvirt*naos
46-
LR_Util::setup_2d_division(pXc, px.get_block_size(), naos, nvirt, px.blacs_ctxt);
51+
LR_Util::setup_2d_division(pXc, px.get_block_size(), naos, nmo2, px.blacs_ctxt);
4752
container::Tensor Xc(DAT::DT_DOUBLE,
4853
DEV::CpuDevice,
4954
{pXc.get_col_size(), pXc.get_row_size()}); // row is "inside"(memory contiguity) for pblas
5055
Xc.zero();
51-
pdgemm_(&transa, &transb, &naos, &nvirt, &nocc,
52-
&alpha, c.get_pointer(), &i1, &i1, pc.desc,
56+
pdgemm_(&transa, &transb, &naos, &nmo2, &nmo1,
57+
&alpha, c.get_pointer(), &i1, &imo1, pc.desc,
5358
X_istate + x_start, &i1, &i1, px.desc,
5459
&beta, Xc.data<double>(), &i1, &i1, pXc.desc);
5560

5661
// 2. C_virt*[X*C_occ^T]
57-
pdgemm_(&transa, &transb, &naos, &naos, &nvirt,
58-
&alpha, c.get_pointer(), &i1, &ivirt, pc.desc,
62+
pdgemm_(&transa, &transb, &naos, &naos, &nmo2,
63+
&factor, c.get_pointer(), &i1, &imo2, pc.desc,
5964
Xc.data<double>(), &i1, &i1, pXc.desc,
6065
&beta, dm_trans[isk].data<double>(), &i1, &i1, pmat.desc);
6166
}
@@ -70,23 +75,27 @@ std::vector<container::Tensor> cal_dm_trans_pblas(const std::complex<double>* co
7075
const int nocc,
7176
const int nvirt,
7277
const Parallel_2D& pmat,
73-
const bool renorm_k,
74-
const int nspin)
78+
const std::complex<double> factor,
79+
const MO_TYPE type)
7580
{
7681
ModuleBase::TITLE("hamilt_lrtd", "cal_dm_trans_pblas");
7782
assert(px.comm() == pc.comm() && px.comm() == pmat.comm());
7883
assert(px.blacs_ctxt == pc.blacs_ctxt && px.blacs_ctxt == pmat.blacs_ctxt);
7984
assert(pmat.get_local_size() > 0);
8085
const int nks = c.get_nk();
86+
const int i1 = 1;
87+
const int ivirt = nocc + 1;
88+
const int nmo1 = type == MO_TYPE::VV ? nvirt : nocc;
89+
const int nmo2 = type == MO_TYPE::OO ? nocc : nvirt;
90+
const int imo1 = type == MO_TYPE::VV ? ivirt : i1;
91+
const int imo2 = type == MO_TYPE::OO ? i1 : ivirt;
8192

8293
std::vector<container::Tensor> dm_trans(nks,
8394
container::Tensor(DAT::DT_COMPLEX_DOUBLE, DEV::CpuDevice, {pmat.get_col_size(), pmat.get_row_size()}));
8495
for (int isk = 0; isk < nks; ++isk)
8596
{
8697
c.fix_k(isk);
8798
const int x_start = isk * px.get_local_size();
88-
int i1 = 1;
89-
int ivirt = nocc + 1;
9099

91100
// ============== C_virt * X * C_occ^\dagger=============
92101
// char transa = 'N';
@@ -114,24 +123,23 @@ std::vector<container::Tensor> cal_dm_trans_pblas(const std::complex<double>* co
114123
char transa = 'N';
115124
char transb = 'C';
116125
Parallel_2D pXc;
117-
LR_Util::setup_2d_division(pXc, px.get_block_size(), nvirt, naos, px.blacs_ctxt);
126+
LR_Util::setup_2d_division(pXc, px.get_block_size(), nmo2, naos, px.blacs_ctxt);
118127
container::Tensor Xc(DAT::DT_COMPLEX_DOUBLE,
119128
DEV::CpuDevice,
120129
{pXc.get_col_size(), pXc.get_row_size()}); // row is "inside"(memory contiguity) for pblas
121130
Xc.zero();
122-
std::complex<double> alpha(1.0, 0.0);
131+
const std::complex<double> alpha(1.0, 0.0);
123132
const std::complex<double> beta(0.0, 0.0);
124-
pzgemm_(&transa, &transb, &nvirt, &naos, &nocc, &alpha,
133+
pzgemm_(&transa, &transb, &nmo2, &naos, &nmo1, &alpha,
125134
X_istate + x_start, &i1, &i1, px.desc,
126-
c.get_pointer(), &i1, &i1, pc.desc,
135+
c.get_pointer(), &i1, &imo1, pc.desc,
127136
&beta, Xc.data<std::complex<double>>(), &i1, &i1, pXc.desc);
128137

129138
// 2. [X*C_occ^\dagger]^TC_virt^T
130-
alpha.real(renorm_k ? 1.0 / static_cast<double>(nks) : 1.0);
131139
transa = transb = 'T';
132-
pzgemm_(&transa, &transb, &naos, &naos, &nvirt,
133-
&alpha, Xc.data<std::complex<double>>(), &i1, &i1, pXc.desc,
134-
c.get_pointer(), &i1, &ivirt, pc.desc,
140+
pzgemm_(&transa, &transb, &naos, &naos, &nmo2,
141+
&factor, Xc.data<std::complex<double>>(), &i1, &i1, pXc.desc,
142+
c.get_pointer(), &i1, &imo2, pc.desc,
135143
&beta, dm_trans[isk].data<std::complex<double>>(), &i1, &i1, pmat.desc);
136144
}
137145
return dm_trans;

0 commit comments

Comments
 (0)