11#include " veff_pw.h"
22
33#include " module_base/timer.h"
4- #include " src_pw/global.h"
54#include " module_base/tool_quit.h"
65
76using hamilt::Veff;
@@ -15,14 +14,29 @@ Veff<OperatorPW<FPTYPE, Device>>::Veff(
1514{
1615 this ->cal_type = pw_veff;
1716 this ->isk = isk_in;
18- this ->veff = veff_in;
17+ // this->veff = veff_in;
18+ // TODO: add a GPU veff array
19+ this ->veff = veff_in[0 ].c ;
20+ this ->veff_col = veff_in[0 ].nc ;
1921 this ->wfcpw = wfcpw_in;
20- if ( this ->isk == nullptr || this ->veff == nullptr || this ->wfcpw == nullptr )
21- {
22+ resize_memory_op ()(this ->ctx , this ->porter , this ->wfcpw ->nmaxgr );
23+ if (this ->npol != 1 ) {
24+ resize_memory_op ()(this ->ctx , this ->porter1 , this ->wfcpw ->nmaxgr );
25+ }
26+ if (this ->isk == nullptr || this ->veff == nullptr || this ->wfcpw == nullptr ) {
2227 ModuleBase::WARNING_QUIT (" VeffPW" , " Constuctor of Operator::VeffPW is failed, please check your code!" );
2328 }
2429}
2530
// Destructor: hands the porter work buffers back through delete_memory_op,
// mirroring the resize_memory_op calls made on the same members in the constructor.
// NOTE(review): act() overwrites this->npol from psi_in->npol, so the npol tested
// here may differ from the one the constructor saw -- confirm that porter1 is
// always allocated and released as a pair.
31+ template <typename FPTYPE, typename Device>
32+ Veff<OperatorPW<FPTYPE, Device>>::~Veff ()
33+ {
34+ delete_memory_op ()(this ->ctx , this ->porter ); // first buffer is released unconditionally
35+ if (this ->npol != 1 ) { // second buffer is only touched when npol != 1
36+ delete_memory_op ()(this ->ctx , this ->porter1 );
37+ }
38+ }
39+
2640template <typename FPTYPE, typename Device>
2741void Veff<OperatorPW<FPTYPE, Device>>::act(
2842 const psi::Psi<std::complex <FPTYPE>, Device> *psi_in,
@@ -37,65 +51,64 @@ void Veff<OperatorPW<FPTYPE, Device>>::act(
3751 const int current_spin = this ->isk [this ->ik ];
3852 this ->npol = psi_in->npol ;
3953
40- std::complex <FPTYPE> *porter = new std::complex <FPTYPE>[wfcpw->nmaxgr ];
54+ // std::complex<FPTYPE> *porter = new std::complex<FPTYPE>[wfcpw->nmaxgr];
4155 for (int ib = 0 ; ib < n_npwx; ib += this ->npol )
4256 {
4357 if (this ->npol == 1 )
4458 {
45- wfcpw->recip2real (tmpsi_in, porter, this ->ik );
59+ // wfcpw->recip2real(tmpsi_in, porter, this->ik);
60+ wfcpw->recip_to_real (this ->ctx , tmpsi_in, this ->porter , this ->ik );
4661 // NOTICE: when there are more MPI threads than Z grids,
4762 // veff contains nothing and nothing should be done in real space,
4863 // but the 3D FFT cannot be skipped; skipping it would cause a hang.
49- if (this ->veff -> nc != 0 )
64+ if (this ->veff_col != 0 )
5065 {
51- const FPTYPE* current_veff = &(this ->veff [0 ](current_spin, 0 ));
52- for (int ir = 0 ; ir < this ->veff ->nc ; ++ir)
53- {
54- porter[ir] *= current_veff[ir];
55- }
66+ // const FPTYPE* current_veff = &(this->veff[0](current_spin, 0));
67+ // for (int ir = 0; ir < this->veff->nc; ++ir)
68+ // {
69+ // porter[ir] *= current_veff[ir];
70+ // }
71+ veff_op ()(this ->ctx , this ->veff_col , this ->porter , this ->veff + current_spin * this ->veff_col );
5672 }
57- wfcpw->real2recip (porter, tmhpsi, this ->ik , true );
73+ // wfcpw->real2recip(porter, tmhpsi, this->ik, true);
74+ wfcpw->real_to_recip (this ->ctx , this ->porter , tmhpsi, this ->ik , true );
5875 }
5976 else
6077 {
61- std::complex <FPTYPE> *porter1 = new std::complex <FPTYPE>[wfcpw->nmaxgr ];
78+ // std::complex<FPTYPE> *porter1 = new std::complex<FPTYPE>[wfcpw->nmaxgr];
6279 // FFT both components to real space, then apply veff there.
63- wfcpw->recip2real (tmpsi_in, porter, this ->ik );
64- wfcpw->recip2real (tmpsi_in + this ->max_npw , porter1, this ->ik );
80+ wfcpw->recip2real (tmpsi_in, this -> porter , this ->ik );
81+ wfcpw->recip2real (tmpsi_in + this ->max_npw , this -> porter1 , this ->ik );
6582 std::complex <FPTYPE> sup, sdown;
66- if (this ->veff -> nc != 0 )
83+ if (this ->veff_col != 0 )
6784 {
6885 const FPTYPE* current_veff[4 ];
6986 for (int is=0 ;is<4 ;is++)
7087 {
71- current_veff[is] = &( this ->veff [ 0 ](is, 0 )) ;
88+ current_veff[is] = this ->veff + is * this -> veff_col ;
7289 }
73- for (int ir = 0 ; ir < this ->veff -> nc ; ir++)
90+ for (int ir = 0 ; ir < this ->veff_col ; ir++)
7491 {
75- sup = porter[ir] * (current_veff[0 ][ir] + current_veff[3 ][ir])
76- + porter1[ir]
92+ sup = this -> porter [ir] * (current_veff[0 ][ir] + current_veff[3 ][ir])
93+ + this -> porter1 [ir]
7794 * (current_veff[1 ][ir]
7895 - std::complex <FPTYPE>(0.0 , 1.0 ) * current_veff[2 ][ir]);
79- sdown = porter1[ir] * (current_veff[0 ][ir] - current_veff[3 ][ir])
80- + porter[ir]
96+ sdown = this -> porter1 [ir] * (current_veff[0 ][ir] - current_veff[3 ][ir])
97+ + this -> porter [ir]
8198 * (current_veff[1 ][ir]
8299 + std::complex <FPTYPE>(0.0 , 1.0 ) * current_veff[2 ][ir]);
83- porter[ir] = sup;
84- porter1[ir] = sdown;
100+ this -> porter [ir] = sup;
101+ this -> porter1 [ir] = sdown;
85102 }
86103 }
87104 // (3) fft back to G space.
88- wfcpw->real2recip (porter, tmhpsi, this ->ik , true );
89- wfcpw->real2recip (porter1, tmhpsi + this ->max_npw , this ->ik , true );
90-
91- delete[] porter1;
105+ wfcpw->real2recip (this ->porter , tmhpsi, this ->ik , true );
106+ wfcpw->real2recip (this ->porter1 , tmhpsi + this ->max_npw , this ->ik , true );
92107 }
93108 tmhpsi += this ->max_npw * this ->npol ;
94109 tmpsi_in += this ->max_npw * this ->npol ;
95110 }
96- delete[] porter;
97111 ModuleBase::timer::tick (" Operator" , " VeffPW" );
98- return ;
99112}
100113
101114namespace hamilt {
0 commit comments