1+ #include " op_exx_pw.h"
2+
3+ namespace hamilt
4+ {
5+ template <typename T, typename Device>
6+ void OperatorEXXPW<T, Device>::act_op_ace(const int nbands,
7+ const int nbasis,
8+ const int npol,
9+ const T *tmpsi_in,
10+ T *tmhpsi,
11+ const int ngk_ik,
12+ const bool is_first_node) const
13+ {
14+ ModuleBase::timer::tick (" OperatorEXXPW" , " act_op_ace" );
15+ // std::cout << "act_op_ace" << std::endl;
16+ // hpsi += -Xi^\dagger * Xi * psi
17+ T* Xi_ace = Xi_ace_k[this ->ik ];
18+ int nbands_tot = psi.get_nbands ();
19+ int nbasis_max = psi.get_nbasis ();
20+ // T* hpsi = nullptr;
21+ // resmem_complex_op()(hpsi, nbands_tot * nbasis);
22+ // setmem_complex_op()(hpsi, 0, nbands_tot * nbasis);
23+ T* Xi_psi = nullptr ;
24+ resmem_complex_op ()(Xi_psi, nbands_tot * nbands);
25+ setmem_complex_op ()(Xi_psi, 0 , nbands_tot * nbands);
26+
27+ char trans_N = ' N' , trans_T = ' T' , trans_C = ' C' ;
28+ T intermediate_one = 1.0 , intermediate_zero = 0.0 , intermediate_minus_one = -1.0 ;
29+ // Xi * psi
30+ gemm_complex_op ()(trans_N,
31+ trans_N,
32+ nbands_tot,
33+ nbands,
34+ nbasis,
35+ &intermediate_one,
36+ Xi_ace,
37+ nbands_tot,
38+ tmpsi_in,
39+ nbasis,
40+ &intermediate_zero,
41+ Xi_psi,
42+ nbands_tot
43+ );
44+
45+ Parallel_Reduce::reduce_pool (Xi_psi, nbands_tot * nbands);
46+
47+ // Xi^\dagger * (Xi * psi)
48+ gemm_complex_op ()(trans_C,
49+ trans_N,
50+ nbasis,
51+ nbands,
52+ nbands_tot,
53+ &intermediate_minus_one,
54+ Xi_ace,
55+ nbands_tot,
56+ Xi_psi,
57+ nbands_tot,
58+ &intermediate_one,
59+ tmhpsi,
60+ nbasis
61+ );
62+
63+
64+ // // negative sign, add to hpsi
65+ // vec_add_vec_complex_op()(this->ctx, nbands * nbasis, tmhpsi, hpsi, -1, tmhpsi, 1);
66+ // delmem_complex_op()(hpsi);
67+ delmem_complex_op ()(Xi_psi);
68+ ModuleBase::timer::tick (" OperatorEXXPW" , " act_op_ace" );
69+
70+ }
71+
72+ template <typename T, typename Device>
73+ void OperatorEXXPW<T, Device>::construct_ace() const
74+ {
75+ ModuleBase::timer::tick (" OperatorEXXPW" , " construct_ace" );
76+ // int nkb = p_exx_helper->psi.get_nbands() * p_exx_helper->psi.get_nk();
77+ int nbands = psi.get_nbands ();
78+ int nbasis = psi.get_nbasis ();
79+ int nk = psi.get_nk ();
80+
81+ int ik_save = this ->ik ;
82+ int * ik_ = const_cast <int *>(&this ->ik );
83+
84+ T intermediate_one = 1.0 , intermediate_zero = 0.0 ;
85+
86+ if (h_psi_ace == nullptr )
87+ {
88+ resmem_complex_op ()(h_psi_ace, nbands * nbasis);
89+ setmem_complex_op ()(h_psi_ace, 0 , nbands * nbasis);
90+ }
91+
92+ if (Xi_ace_k.size () != nk)
93+ {
94+ Xi_ace_k.resize (nk);
95+ for (int i = 0 ; i < nk; i++)
96+ {
97+ resmem_complex_op ()(Xi_ace_k[i], nbands * nbasis);
98+ }
99+ }
100+
101+ for (int i = 0 ; i < nk; i++)
102+ {
103+ setmem_complex_op ()(Xi_ace_k[i], 0 , nbands * nbasis);
104+ }
105+
106+ if (L_ace == nullptr )
107+ {
108+ resmem_complex_op ()(L_ace, nbands * nbands);
109+ setmem_complex_op ()(L_ace, 0 , nbands * nbands);
110+ }
111+
112+ if (psi_h_psi_ace == nullptr )
113+ {
114+ resmem_complex_op ()(psi_h_psi_ace, nbands * nbands);
115+ }
116+
117+ if (first_iter) return ;
118+
119+ for (int ik = 0 ; ik < nk; ik++)
120+ {
121+ int npwk = wfcpw->npwk [ik];
122+
123+ T* Xi_ace = Xi_ace_k[ik];
124+ psi.fix_kb (ik, 0 );
125+ T* p_psi = psi.get_pointer ();
126+
127+ setmem_complex_op ()(h_psi_ace, 0 , nbands * nbasis);
128+
129+ *ik_ = ik;
130+
131+ act_op (
132+ nbands,
133+ nbasis,
134+ 1 ,
135+ p_psi,
136+ h_psi_ace,
137+ nbasis,
138+ false
139+ );
140+
141+ // psi_h_psi_ace = psi^\dagger * h_psi_ace
142+ // p_exx_helper->psi.fix_kb(0, 0);
143+ gemm_complex_op ()(' C' ,
144+ ' N' ,
145+ nbands,
146+ nbands,
147+ npwk,
148+ &intermediate_one,
149+ p_psi,
150+ nbasis,
151+ h_psi_ace,
152+ nbasis,
153+ &intermediate_zero,
154+ psi_h_psi_ace,
155+ nbands);
156+
157+ // reduction of psi_h_psi_ace, due to distributed memory
158+ Parallel_Reduce::reduce_pool (psi_h_psi_ace, nbands * nbands);
159+
160+ T intermediate_minus_one = -1.0 ;
161+ axpy_complex_op ()(nbands * nbands,
162+ &intermediate_minus_one,
163+ psi_h_psi_ace,
164+ 1 ,
165+ L_ace,
166+ 1 );
167+
168+
169+ int info = 0 ;
170+ char up = ' U' , lo = ' L' ;
171+
172+ lapack_potrf ()(lo, nbands, L_ace, nbands);
173+
174+ // expand for-loop
175+ for (int i = 0 ; i < nbands; ++i) {
176+ setmem_complex_op ()(L_ace + i * nbands, 0 , i);
177+ }
178+
179+ // L_ace inv in place
180+ char non = ' N' ;
181+ lapack_trtri ()(lo, non, nbands, L_ace, nbands);
182+
183+ // Xi_ace = L_ace^-1 * h_psi_ace^dagger
184+ gemm_complex_op ()(' N' ,
185+ ' C' ,
186+ nbands,
187+ npwk,
188+ nbands,
189+ &intermediate_one,
190+ L_ace,
191+ nbands,
192+ h_psi_ace,
193+ nbasis,
194+ &intermediate_zero,
195+ Xi_ace,
196+ nbands);
197+
198+ // clear mem
199+ setmem_complex_op ()(h_psi_ace, 0 , nbands * nbasis);
200+ setmem_complex_op ()(psi_h_psi_ace, 0 , nbands * nbands);
201+ setmem_complex_op ()(L_ace, 0 , nbands * nbands);
202+
203+ }
204+
205+ *ik_ = ik_save;
206+ ModuleBase::timer::tick (" OperatorEXXPW" , " construct_ace" );
207+
208+ }
209+
210+ template <typename T, typename Device>
211+ double OperatorEXXPW<T, Device>::cal_exx_energy_ace(psi::Psi<T, Device>* ppsi_) const
212+ {
213+ double Eexx = 0 ;
214+
215+ psi::Psi<T, Device> psi_ = *ppsi_;
216+ int * ik_ = const_cast <int *>(&this ->ik );
217+ int ik_save = this ->ik ;
218+ for (int i = 0 ; i < wfcpw->nks ; i++)
219+ {
220+ setmem_complex_op ()(h_psi_ace, 0 , psi_.get_nbands () * psi_.get_nbasis ());
221+ *ik_ = i;
222+ psi_.fix_kb (i, 0 );
223+ T* psi_i = psi_.get_pointer ();
224+ act_op_ace (psi_.get_nbands (), psi_.get_nbasis (), 1 , psi_i, h_psi_ace, 0 , true );
225+
226+ for (int nband = 0 ; nband < psi_.get_nbands (); nband++)
227+ {
228+ psi_.fix_kb (i, nband);
229+ T* psi_i_n = psi_.get_pointer ();
230+ T* hpsi_i_n = h_psi_ace + nband * psi_.get_nbasis ();
231+ double wg_i_n = (*wg)(i, nband);
232+ // Eexx += dot(psi_i_n, h_psi_i_n)
233+ Eexx += dot_op ()(psi_.get_nbasis (), psi_i_n, hpsi_i_n, false ) * wg_i_n * 2 ;
234+ }
235+ }
236+
237+ Parallel_Reduce::reduce_pool (Eexx);
238+ *ik_ = ik_save;
239+ return Eexx;
240+ }
241+ template class OperatorEXXPW <std::complex <float >, base_device::DEVICE_CPU>;
242+ template class OperatorEXXPW <std::complex <double >, base_device::DEVICE_CPU>;
243+ #if ((defined __CUDA) || (defined __ROCM))
244+ template class OperatorEXXPW <std::complex <float >, base_device::DEVICE_GPU>;
245+ template class OperatorEXXPW <std::complex <double >, base_device::DEVICE_GPU>;
246+ #endif
247+ }
0 commit comments