@@ -18,44 +18,49 @@ std::vector<container::Tensor> cal_dm_trans_pblas(const double* const X_istate,
1818 const int nocc,
1919 const int nvirt,
2020 const Parallel_2D& pmat,
21- const bool renorm_k ,
22- const int nspin )
21+ const double factor ,
22+ const MO_TYPE type )
2323{
2424 ModuleBase::TITLE (" hamilt_lrtd" , " cal_dm_trans_pblas" );
2525 assert (px.comm () == pc.comm () && px.comm () == pmat.comm ());
2626 assert (px.blacs_ctxt == pc.blacs_ctxt && px.blacs_ctxt == pmat.blacs_ctxt );
2727 assert (pmat.get_local_size () > 0 );
2828
2929 const int nks = c.get_nk ();
30+ const int i1 = 1 ;
31+ const int ivirt = nocc + 1 ;
32+ const int nmo1 = type == MO_TYPE::VV ? nvirt : nocc;
33+ const int nmo2 = type == MO_TYPE::OO ? nocc : nvirt;
34+ const int imo1 = type == MO_TYPE::VV ? ivirt : i1;
35+ const int imo2 = type == MO_TYPE::OO ? i1 : ivirt;
3036
3137 std::vector<container::Tensor> dm_trans (nks,
3238 container::Tensor (DAT::DT_DOUBLE, DEV::CpuDevice, { pmat.get_col_size (), pmat.get_row_size () }));
3339 for (int isk = 0 ; isk < nks; ++isk)
3440 {
3541 c.fix_k (isk);
3642 const int x_start = isk * px.get_local_size ();
37- int i1 = 1 ;
38- int ivirt = nocc + 1 ;
43+
3944 char transa = ' N' ;
4045 char transb = ' T' ;
4146 const double alpha = 1.0 ;
4247 const double beta = 0 ;
4348
4449 // 1. [X*C_occ^T]^T=C_occ*X^T
4550 Parallel_2D pXc; // nvirt*naos
46- LR_Util::setup_2d_division (pXc, px.get_block_size (), naos, nvirt , px.blacs_ctxt );
51+ LR_Util::setup_2d_division (pXc, px.get_block_size (), naos, nmo2 , px.blacs_ctxt );
4752 container::Tensor Xc (DAT::DT_DOUBLE,
4853 DEV::CpuDevice,
4954 {pXc.get_col_size (), pXc.get_row_size ()}); // row is "inside"(memory contiguity) for pblas
5055 Xc.zero ();
51- pdgemm_ (&transa, &transb, &naos, &nvirt , &nocc ,
52- &alpha, c.get_pointer (), &i1, &i1 , pc.desc ,
56+ pdgemm_ (&transa, &transb, &naos, &nmo2 , &nmo1 ,
57+ &alpha, c.get_pointer (), &i1, &imo1 , pc.desc ,
5358 X_istate + x_start, &i1, &i1, px.desc ,
5459 &beta, Xc.data <double >(), &i1, &i1, pXc.desc );
5560
5661 // 2. C_virt*[X*C_occ^T]
57- pdgemm_ (&transa, &transb, &naos, &naos, &nvirt ,
58- &alpha , c.get_pointer (), &i1, &ivirt , pc.desc ,
62+ pdgemm_ (&transa, &transb, &naos, &naos, &nmo2 ,
63+ &factor , c.get_pointer (), &i1, &imo2 , pc.desc ,
5964 Xc.data <double >(), &i1, &i1, pXc.desc ,
6065 &beta, dm_trans[isk].data <double >(), &i1, &i1, pmat.desc );
6166 }
@@ -70,23 +75,27 @@ std::vector<container::Tensor> cal_dm_trans_pblas(const std::complex<double>* co
7075 const int nocc,
7176 const int nvirt,
7277 const Parallel_2D& pmat,
73- const bool renorm_k ,
74- const int nspin )
78+ const std:: complex < double > factor ,
79+ const MO_TYPE type )
7580{
7681 ModuleBase::TITLE (" hamilt_lrtd" , " cal_dm_trans_pblas" );
7782 assert (px.comm () == pc.comm () && px.comm () == pmat.comm ());
7883 assert (px.blacs_ctxt == pc.blacs_ctxt && px.blacs_ctxt == pmat.blacs_ctxt );
7984 assert (pmat.get_local_size () > 0 );
8085 const int nks = c.get_nk ();
86+ const int i1 = 1 ;
87+ const int ivirt = nocc + 1 ;
88+ const int nmo1 = type == MO_TYPE::VV ? nvirt : nocc;
89+ const int nmo2 = type == MO_TYPE::OO ? nocc : nvirt;
90+ const int imo1 = type == MO_TYPE::VV ? ivirt : i1;
91+ const int imo2 = type == MO_TYPE::OO ? i1 : ivirt;
8192
8293 std::vector<container::Tensor> dm_trans (nks,
8394 container::Tensor (DAT::DT_COMPLEX_DOUBLE, DEV::CpuDevice, {pmat.get_col_size (), pmat.get_row_size ()}));
8495 for (int isk = 0 ; isk < nks; ++isk)
8596 {
8697 c.fix_k (isk);
8798 const int x_start = isk * px.get_local_size ();
88- int i1 = 1 ;
89- int ivirt = nocc + 1 ;
9099
91100 // ============== C_virt * X * C_occ^\dagger=============
92101 // char transa = 'N';
@@ -114,24 +123,23 @@ std::vector<container::Tensor> cal_dm_trans_pblas(const std::complex<double>* co
114123 char transa = ' N' ;
115124 char transb = ' C' ;
116125 Parallel_2D pXc;
117- LR_Util::setup_2d_division (pXc, px.get_block_size (), nvirt , naos, px.blacs_ctxt );
126+ LR_Util::setup_2d_division (pXc, px.get_block_size (), nmo2 , naos, px.blacs_ctxt );
118127 container::Tensor Xc (DAT::DT_COMPLEX_DOUBLE,
119128 DEV::CpuDevice,
120129 {pXc.get_col_size (), pXc.get_row_size ()}); // row is "inside"(memory contiguity) for pblas
121130 Xc.zero ();
122- std::complex <double > alpha (1.0 , 0.0 );
131+ const std::complex <double > alpha (1.0 , 0.0 );
123132 const std::complex <double > beta (0.0 , 0.0 );
124- pzgemm_ (&transa, &transb, &nvirt , &naos, &nocc , &alpha,
133+ pzgemm_ (&transa, &transb, &nmo2 , &naos, &nmo1 , &alpha,
125134 X_istate + x_start, &i1, &i1, px.desc ,
126- c.get_pointer (), &i1, &i1 , pc.desc ,
135+ c.get_pointer (), &i1, &imo1 , pc.desc ,
127136 &beta, Xc.data <std::complex <double >>(), &i1, &i1, pXc.desc );
128137
129138 // 2. [X*C_occ^\dagger]^TC_virt^T
130- alpha.real (renorm_k ? 1.0 / static_cast <double >(nks) : 1.0 );
131139 transa = transb = ' T' ;
132- pzgemm_ (&transa, &transb, &naos, &naos, &nvirt ,
133- &alpha , Xc.data <std::complex <double >>(), &i1, &i1, pXc.desc ,
134- c.get_pointer (), &i1, &ivirt , pc.desc ,
140+ pzgemm_ (&transa, &transb, &naos, &naos, &nmo2 ,
141+ &factor , Xc.data <std::complex <double >>(), &i1, &i1, pXc.desc ,
142+ c.get_pointer (), &i1, &imo2 , pc.desc ,
135143 &beta, dm_trans[isk].data <std::complex <double >>(), &i1, &i1, pmat.desc );
136144 }
137145 return dm_trans;
0 commit comments