@@ -27,14 +27,17 @@ DiagoBPCG<T, Device>::DiagoBPCG(const Real* precondition_in)
2727template <typename T, typename Device>
2828DiagoBPCG<T, Device>::~DiagoBPCG () {
2929 // Note, we do not need to free the h_prec and psi pointer as they are refs to the outside data
30- delete this ->grad_wrapper ;
30+ // delete this->grad_wrapper;
3131}
3232
3333template <typename T, typename Device>
34- void DiagoBPCG<T, Device>::init_iter(const psi::Psi<T, Device> &psi_in ) {
34+ void DiagoBPCG<T, Device>::init_iter(/* const T *psi_in, */ const int nband, const int nbasis ) {
3535 // Specify the problem size n_basis, n_band, while lda is n_basis
36- this ->n_band = psi_in.get_nbands ();
37- this ->n_basis = psi_in.get_nbasis ();
36+ // this->n_band = psi_in.get_nbands();
37+ // this->n_basis = psi_in.get_nbasis();
38+ this ->n_band = nband;
39+ this ->n_basis = nbasis;
40+
3841
3942 // All column major tensors
4043
@@ -52,8 +55,9 @@ void DiagoBPCG<T, Device>::init_iter(const psi::Psi<T, Device> &psi_in) {
5255 this ->prec = std::move (ct::Tensor (r_type, device_type, {this ->n_basis }));
5356
5457 // TODO: Remove class Psi, using ct::Tensor instead!
55- this ->grad_wrapper = new psi::Psi<T, Device>(1 , this ->n_band , this ->n_basis , psi_in.get_ngk_pointer ());
56- this ->grad = std::move (ct::TensorMap (grad_wrapper->get_pointer (), t_type, device_type, {this ->n_band , this ->n_basis }));
58+ // this->grad_wrapper = new psi::Psi<T, Device>(1, this->n_band, this->n_basis, psi_in.get_ngk_pointer());
59+ // this->grad = std::move(ct::TensorMap(grad_wrapper->get_pointer(), t_type, device_type, {this->n_band, this->n_basis}));
60+ this ->grad = std::move (ct::Tensor (t_type, device_type, {this ->n_band , this ->n_basis }));
5761}
5862
5963template <typename T, typename Device>
@@ -174,16 +178,19 @@ void DiagoBPCG<T, Device>::rotate_wf(
174178
175179template <typename T, typename Device>
176180void DiagoBPCG<T, Device>::calc_hpsi_with_block(
177- hamilt::Hamilt<T, Device>* hamilt_in,
178- const psi::Psi<T, Device>& psi_in,
181+ // hamilt::Hamilt<T, Device>* hamilt_in,
182+ const HPsiFunc& hpsi_func,
183+ // const psi::Psi<T, Device>& psi_in,
184+ T *psi_in,
179185 ct::Tensor& hpsi_out)
180186{
181187 // calculate all-band hpsi
182- psi::Range all_bands_range (1 , psi_in.get_current_k (), 0 , psi_in.get_nbands () - 1 );
183- hpsi_info info (&psi_in, all_bands_range, hpsi_out.data <T>());
184- hamilt_in->ops ->hPsi (info);
188+ // psi::Range all_bands_range(1, psi_in.get_current_k(), 0, psi_in.get_nbands() - 1);
189+ // hpsi_info info(&psi_in, all_bands_range, hpsi_out.data<T>());
190+ // hamilt_in->ops->hPsi(info);
191+ hpsi_func (psi_in, hpsi_out.data <T>(), this ->n_basis , this ->n_band );
185192
186- return ;
193+ // return;
187194}
188195
189196template <typename T, typename Device>
@@ -207,16 +214,18 @@ void DiagoBPCG<T, Device>::diag_hsub(
207214
208215template <typename T, typename Device>
209216void DiagoBPCG<T, Device>::calc_hsub_with_block(
210- hamilt::Hamilt<T, Device> *hamilt_in,
211- const psi::Psi<T, Device> &psi_in,
217+ // hamilt::Hamilt<T, Device>* hamilt_in,
218+ const HPsiFunc& hpsi_func,
219+ // const psi::Psi<T, Device>& psi_in,
220+ T *psi_in,
212221 ct::Tensor& psi_out,
213222 ct::Tensor& hpsi_out,
214223 ct::Tensor& hsub_out,
215224 ct::Tensor& workspace_in,
216225 ct::Tensor& eigenvalue_out)
217226{
218227 // Apply the H operator to psi and obtain the hpsi matrix.
219- this ->calc_hpsi_with_block (hamilt_in , psi_in, hpsi_out);
228+ this ->calc_hpsi_with_block (hpsi_func , psi_in, hpsi_out);
220229
221230 // Diagonalization of the subspace matrix.
222231 this ->diag_hsub (psi_out,hpsi_out, hsub_out, eigenvalue_out);
@@ -250,19 +259,21 @@ void DiagoBPCG<T, Device>::calc_hsub_with_block_exit(
250259
251260template <typename T, typename Device>
252261void DiagoBPCG<T, Device>::diag(
253- hamilt::Hamilt<T, Device>* hamilt_in,
254- psi::Psi<T, Device>& psi_in,
262+ // hamilt::Hamilt<T, Device>* hamilt_in,
263+ const HPsiFunc& hpsi_func,
264+ // psi::Psi<T, Device>& psi_in,
265+ T *psi_in,
255266 Real* eigenvalue_in)
256267{
257268 const int current_scf_iter = hsolver::DiagoIterAssist<T, Device>::SCF_ITER;
258269 // Get the pointer of the input psi
259- this ->psi = std::move (ct::TensorMap (psi_in.get_pointer (), t_type, device_type, {this ->n_band , this ->n_basis }));
270+ this ->psi = std::move (ct::TensorMap (psi_in /* psi_in .get_pointer()*/ , t_type, device_type, {this ->n_band , this ->n_basis }));
260271
261272 // Update the precondition array
262273 this ->calc_prec ();
263274
264275 // Improving the initial guess of the wave function psi through a subspace diagonalization.
265- this ->calc_hsub_with_block (hamilt_in , psi_in, this ->psi , this ->hpsi , this ->hsub , this ->work , this ->eigen );
276+ this ->calc_hsub_with_block (hpsi_func , psi_in, this ->psi , this ->hpsi , this ->hsub , this ->work , this ->eigen );
266277
267278 setmem_complex_op ()(this ->grad_old .template data <T>(), 0 , this ->n_basis * this ->n_band );
268279
@@ -293,7 +304,7 @@ void DiagoBPCG<T, Device>::diag(
293304 syncmem_complex_op ()(this ->grad_old .template data <T>(), this ->grad .template data <T>(), n_basis * n_band);
294305
295306 // Calculate H|grad> matrix
296- this ->calc_hpsi_with_block (hamilt_in , this ->grad_wrapper [0 ], this ->hgrad );
307+ this ->calc_hpsi_with_block (hpsi_func , this ->grad . data <T>(), /* this-> grad_wrapper[0],*/ this ->hgrad );
297308
298309 // optimize psi as well as the hpsi
299310 // 1. normalize grad
@@ -305,7 +316,7 @@ void DiagoBPCG<T, Device>::diag(
305316 this ->orth_cholesky (this ->work , this ->psi , this ->hpsi , this ->hsub );
306317
307318 if (current_scf_iter == 1 && ntry % this ->nline == 0 ) {
308- this ->calc_hsub_with_block (hamilt_in , psi_in, this ->psi , this ->hpsi , this ->hsub , this ->work , this ->eigen );
319+ this ->calc_hsub_with_block (hpsi_func , psi_in, this ->psi , this ->hpsi , this ->hsub , this ->work , this ->eigen );
309320 }
310321 } while (ntry < max_iter && this ->test_error (this ->err_st , this ->all_band_cg_thr ));
311322
0 commit comments