@@ -31,36 +31,55 @@ ElecStatePW<T, Device>::ElecStatePW(ModulePW::PW_Basis_K* wfc_basis_in,
// Destructor of ElecStatePW, shown as a diff ('-' = removed line, '+' = added line,
// lines without a marker are unchanged context).
//
// What the new version does:
//  * When PARAM.inp.device / PARAM.inp.precision select the gpu-or-single path, it releases
//    rho_data (delmem_var_op) together with the rho pointer table (delete[]). init_rho_data
//    allocates both under the same gpu-or-single predicate, so allocation and release are now
//    guarded by one condition; the old code released rho_data only when
//    base_device::get_device_type<Device>(ctx) == GpuDevice but deleted rho under the
//    gpu-or-single predicate.
//  * rhog_data / rhog are released only when double_grid or use_uspp is set, which mirrors the
//    condition under which init_rho_data allocates them.
//  * kin_r is now deleted inside the same xc-type/out_elf guard as kin_r_data instead of
//    unconditionally inside the gpu-or-single branch.
//  * becsum is released only when use_uspp is set (previously unconditional).
//  * wfcr and wfcr_another_spin are always released.
//
// NOTE(review): presumably becsum is only allocated when use_uspp is true - confirm at the
// allocation site, which is not visible in this diff.
// NOTE(review): the flags are re-read from PARAM at destruction time; this assumes they have
// not changed since init_rho_data ran - TODO confirm.
3131template <typename T, typename Device>
3232ElecStatePW<T, Device>::~ElecStatePW ()
3333{
34- if (base_device::get_device_type<Device>( this -> ctx ) == base_device::GpuDevice )
34+ if (PARAM. inp . device == " gpu " || PARAM. inp . precision == " single " )
3535 {
3636 delmem_var_op ()(this ->ctx , this ->rho_data );
37+ delete[] this ->rho ;
38+
// Reciprocal-space density buffers exist only for the double-grid / ultrasoft cases.
39+ if (PARAM.globalv .double_grid || PARAM.globalv .use_uspp )
40+ {
41+ delmem_complex_op ()(this ->ctx , this ->rhog_data );
42+ delete[] this ->rhog ;
43+ }
3744 if (get_xc_func_type () == 3 || PARAM.inp .out_elf [0 ] > 0 )
3845 {
3946 delmem_var_op ()(this ->ctx , this ->kin_r_data );
47+ delete[] this ->kin_r ;
4048 }
4149 }
42- if (PARAM.inp . device == " gpu " || PARAM. inp . precision == " single " ) {
43- delete[] this -> rho ;
44- delete[] this ->kin_r ;
// becsum release is now conditional on use_uspp (see NOTE(review) in the header comment).
50+ if (PARAM.globalv . use_uspp )
51+ {
52+ delmem_var_op ()( this -> ctx , this ->becsum ) ;
4553 }
46- delmem_var_op ()(this ->ctx , becsum);
4754 delmem_complex_op ()(this ->ctx , this ->wfcr );
4855 delmem_complex_op ()(this ->ctx , this ->wfcr_another_spin );
4956}
5057
5158template <typename T, typename Device>
5259void ElecStatePW<T, Device>::init_rho_data()
5360{
54- if (this ->init_rho ) {
61+ if (this ->init_rho )
62+ {
5563 return ;
5664 }
57-
58- if (PARAM.inp .device == " gpu" || PARAM.inp .precision == " single" ) {
65+
66+ if (PARAM.inp .device == " gpu" || PARAM.inp .precision == " single" )
67+ {
5968 this ->rho = new Real*[this ->charge ->nspin ];
6069 resmem_var_op ()(this ->ctx , this ->rho_data , this ->charge ->nspin * this ->charge ->nrxx );
61- for (int ii = 0 ; ii < this ->charge ->nspin ; ii++) {
70+ for (int ii = 0 ; ii < this ->charge ->nspin ; ii++)
71+ {
6272 this ->rho [ii] = this ->rho_data + ii * this ->charge ->nrxx ;
6373 }
74+ if (PARAM.globalv .double_grid || PARAM.globalv .use_uspp )
75+ {
76+ this ->rhog = new T*[this ->charge ->nspin ];
77+ resmem_complex_op ()(this ->ctx , this ->rhog_data , this ->charge ->nspin * this ->charge ->rhopw ->npw );
78+ for (int ii = 0 ; ii < this ->charge ->nspin ; ii++)
79+ {
80+ this ->rhog [ii] = this ->rhog_data + ii * this ->charge ->rhopw ->npw ;
81+ }
82+ }
6483 if (get_xc_func_type () == 3 || PARAM.inp .out_elf [0 ] > 0 )
6584 {
6685 this ->kin_r = new Real*[this ->charge ->nspin ];
@@ -70,8 +89,13 @@ void ElecStatePW<T, Device>::init_rho_data()
7089 }
7190 }
7291 }
73- else {
92+ else
93+ {
7494 this ->rho = reinterpret_cast <Real **>(this ->charge ->rho );
95+ if (PARAM.globalv .double_grid || PARAM.globalv .use_uspp )
96+ {
97+ this ->rhog = reinterpret_cast <T**>(this ->charge ->rhog );
98+ }
7599 if (get_xc_func_type () == 3 || PARAM.inp .out_elf [0 ] > 0 )
76100 {
77101 this ->kin_r = reinterpret_cast <Real **>(this ->charge ->kin_r );
@@ -100,19 +124,24 @@ void ElecStatePW<T, Device>::psiToRho(const psi::Psi<T, Device>& psi)
100124 // ModuleBase::GlobalFunc::ZEROS(this->charge->kin_r[is], this->charge->nrxx);
101125 setmem_var_op ()(this ->ctx , this ->kin_r [is], 0 , this ->charge ->nrxx );
102126 }
103- }
127+ if (PARAM.globalv .double_grid || PARAM.globalv .use_uspp )
128+ {
129+ setmem_complex_op ()(this ->ctx , this ->rhog [is], 0 , this ->charge ->rhopw ->npw );
130+ }
131+ }
104132
105133 for (int ik = 0 ; ik < psi.get_nk (); ++ik)
106134 {
107135 psi.fix_k (ik);
108136 this ->updateRhoK (psi);
109137 }
110- if (PARAM.globalv .use_uspp )
138+
139+ this ->add_usrho (psi);
140+
141+ if (PARAM.inp .device == " gpu" || PARAM.inp .precision == " single" )
111142 {
112- this ->add_usrho (psi);
113- }
114- if (PARAM.inp .device == " gpu" || PARAM.inp .precision == " single" ) {
115- for (int ii = 0 ; ii < PARAM.inp .nspin ; ii++) {
143+ for (int ii = 0 ; ii < PARAM.inp .nspin ; ii++)
144+ {
116145 castmem_var_d2h_op ()(cpu_ctx, this ->ctx , this ->charge ->rho [ii], this ->rho [ii], this ->charge ->nrxx );
117146 if (get_xc_func_type () == 3 )
118147 {
@@ -397,32 +426,39 @@ void ElecStatePW<T, Device>::cal_becsum(const psi::Psi<T, Device>& psi)
// add_usrho, shown as a diff ('-' = removed line, '+' = added line).
//
// New behaviour, step by step:
//  1. If use_uspp is set, cal_becsum(psi) is called.
//  2. If double_grid or use_uspp is set, each rho[is] is transformed with
//     rhopw_smooth->real2recip into the member buffer rhog[is].
//  3. If use_uspp is set, addusdens_g(becsum, rhog) adds the augmentation term to rhog.
//  4. If double_grid or use_uspp is set, each rhog[is] is transformed back into rho[is] with
//     charge->rhopw->recip2real.
// When neither flag is set the function does nothing. The guards live here because psiToRho
// now calls add_usrho unconditionally.
//
// The old version allocated, zeroed and freed a local flat rhog buffer on every call and
// indexed it as rhog[is * npw]; the new version uses the member rhog (one pointer per spin),
// which psiToRho zeroes per spin before the k-point loop. addusdens_g takes T** accordingly.
//
// NOTE(review): the loops run over PARAM.inp.nspin while init_rho_data sizes rhog with
// this->charge->nspin - presumably the two are always equal; confirm.
// NOTE(review): with double_grid set and use_uspp unset, steps 2 and 4 still run, i.e. rho is
// round-tripped through reciprocal space without any augmentation - presumably intended as
// the smooth-grid to dense-grid transfer; confirm.
397426template <typename T, typename Device>
398427void ElecStatePW<T, Device>::add_usrho(const psi::Psi<T, Device>& psi)
399428{
400- this ->cal_becsum (psi);
429+ if (PARAM.globalv .use_uspp )
430+ {
431+ this ->cal_becsum (psi);
432+ }
401433
402434 // transform soft charge to recip space using smooth grids
403- T* rhog = nullptr ;
404- resmem_complex_op ()(this ->ctx , rhog, this ->charge ->rhopw ->npw * PARAM.inp .nspin , " ElecState<PW>::rhog" );
405- setmem_complex_op ()(this ->ctx , rhog, 0 , this ->charge ->rhopw ->npw * PARAM.inp .nspin );
406- for (int is = 0 ; is < PARAM.inp .nspin ; is++)
435+ if (PARAM.globalv .double_grid || PARAM.globalv .use_uspp )
407436 {
408- this ->rhopw_smooth ->real2recip (this ->rho [is], &rhog[is * this ->charge ->rhopw ->npw ]);
437+ for (int is = 0 ; is < PARAM.inp .nspin ; is++)
438+ {
439+ this ->rhopw_smooth ->real2recip (this ->rho [is], this ->rhog [is]);
440+ }
409441 }
410442
411443 // \sum_lm Q_lm(r) \sum_i <psi_i|beta_l><beta_m|psi_i> w_i
412444 // add to the charge density in reciprocal space the part which is due to the US augmentation.
413- this ->addusdens_g (becsum, rhog);
445+ if (PARAM.globalv .use_uspp )
446+ {
// With the local buffer removed, the unqualified rhog here names the member pointer table.
447+ this ->addusdens_g (becsum, rhog);
448+ }
414449
415450 // transform back to real space using dense grids
416- for ( int is = 0 ; is < PARAM.inp . nspin ; is++ )
451+ if (PARAM. globalv . double_grid || PARAM.globalv . use_uspp )
417452 {
418- this ->charge ->rhopw ->recip2real (&rhog[is * this ->charge ->rhopw ->npw ], this ->rho [is]);
453+ for (int is = 0 ; is < PARAM.inp .nspin ; is++)
454+ {
455+ this ->charge ->rhopw ->recip2real (this ->rhog [is], this ->rho [is]);
456+ }
419457 }
420-
421- delmem_complex_op ()(this ->ctx , rhog);
422458}
423459
424460template <typename T, typename Device>
425- void ElecStatePW<T, Device>::addusdens_g(const Real* becsum, T* rhog)
461+ void ElecStatePW<T, Device>::addusdens_g(const Real* becsum, T** rhog)
426462{
427463 const T one{1 , 0 };
428464 const T zero{0 , 0 };
@@ -506,7 +542,7 @@ void ElecStatePW<T, Device>::addusdens_g(const Real* becsum, T* rhog)
506542 this ->ppcell ->radial_fft_q (this ->ctx , npw, ih, jh, it, qmod, ylmk0, qgm);
507543 for (int ig = 0 ; ig < npw; ig++)
508544 {
509- rhog[is * npw + ig] += qgm[ig] * aux2[ijh * npw + ig];
545+ rhog[is][ ig] += qgm[ig] * aux2[ijh * npw + ig];
510546 }
511547 ijh++;
512548 }
0 commit comments