@@ -32,16 +32,16 @@ using namespace hsolver;
3232 * @note Auxiliary memory is allocated in the constructor and deallocated in the destructor.
3333 */
3434template <typename T, typename Device>
35- DiagoDavid<T, Device>::DiagoDavid(const Real* precondition_in,
35+ DiagoDavid<T, Device>::DiagoDavid(PreFunc&& precondition_in,
3636 const int nband_in,
3737 const int dim_in,
3838 const int david_ndim_in,
3939 const bool use_paw_in,
4040 const diag_comm_info& diag_comm_in)
41- : nband(nband_in), dim(dim_in), nbase_x(david_ndim_in * nband_in), david_ndim(david_ndim_in), use_paw(use_paw_in), diag_comm(diag_comm_in)
41+ : nband(nband_in), dim(dim_in), nbase_x(david_ndim_in* nband_in), david_ndim(david_ndim_in), use_paw(use_paw_in), diag_comm(diag_comm_in),
42+ precondition (std::forward<PreFunc>(precondition_in))
4243{
4344 this ->device = base_device::get_device_type<Device>(this ->ctx );
44- this ->precondition = precondition_in;
4545
4646 this ->one = &one_;
4747 this ->zero = &zero_;
@@ -110,15 +110,6 @@ DiagoDavid<T, Device>::DiagoDavid(const Real* precondition_in,
110110 // lagrange_matrix(nband, nband); // for orthogonalization
111111 resmem_complex_op ()(this ->ctx , this ->lagrange_matrix , nband * nband);
112112 setmem_complex_op ()(this ->ctx , this ->lagrange_matrix , 0 , nband * nband);
113-
114- #if defined(__CUDA) || defined(__ROCM)
115- // device precondition array
116- if (this ->device == base_device::GpuDevice)
117- {
118- resmem_var_op ()(this ->ctx , this ->d_precondition , dim);
119- syncmem_var_h2d_op ()(this ->ctx , this ->cpu_ctx , this ->d_precondition , this ->precondition , dim);
120- }
121- #endif
122113}
123114
124115/* *
@@ -139,13 +130,6 @@ DiagoDavid<T, Device>::~DiagoDavid()
139130 delmem_complex_op ()(this ->ctx , this ->vcc );
140131 delmem_complex_op ()(this ->ctx , this ->lagrange_matrix );
141132 base_device::memory::delete_memory_op<Real, base_device::DEVICE_CPU>()(this ->cpu_ctx , this ->eigenvalue );
142- // If the device is a GPU device, free the d_precondition array.
143- #if defined(__CUDA) || defined(__ROCM)
144- if (this ->device == base_device::GpuDevice)
145- {
146- delmem_var_op ()(this ->ctx , this ->d_precondition );
147- }
148- #endif
149133}
150134
151135template <typename T, typename Device>
@@ -499,40 +483,14 @@ void DiagoDavid<T, Device>::cal_grad(const HPsiFunc& hpsi_func,
499483 dim // LDC: if(N) max(1, m)
500484 );
501485 // <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
502-
503486 // Preconditioning
504487 // basis[nbase] = T * basis[nbase] = T * (H - lambda * S) * psi
505488 // where T, the preconditioner, is an approximate inverse of H
506489 // T is a diagonal stored in array `precondition`
507490 // to do preconditioning, divide each column of basis by the corresponding element of precondition
508- for (int m = 0 ; m < notconv; m++)
509- {
510- // <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
511- if (this ->device == base_device::GpuDevice)
512- {
513- #if defined(__CUDA) || defined(__ROCM)
514- vector_div_vector_op<T, Device>()(this ->ctx ,
515- dim,
516- basis + dim*(nbase + m),
517- basis + dim*(nbase + m),
518- this ->d_precondition );
519- #endif
520- }
521- else
522- {
523- vector_div_vector_op<T, Device>()(this ->ctx ,
524- dim,
525- basis + dim*(nbase + m),
526- basis + dim*(nbase + m),
527- this ->precondition );
528- }
529- // <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
530- // for (int ig = 0; ig < dim; ig++)
531- // {
532- // ppsi[ig] /= this->precondition[ig];
533- // }
534- // <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
535- }
491+ // <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
492+ this ->precondition (basis + dim * nbase, dim, notconv);
493+ // <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
536494
537495 // there is a nbase to nbase + notconv band orthogonalise
538496 // plan for SchmidtOrth
0 commit comments