@@ -130,10 +130,32 @@ class DiagoBPCGPrepare
130130 psi_local.fix_k (0 );
131131 double start, end;
132132 start = MPI_Wtime ();
133- bpcg.init_iter (psi_local);
134- bpcg.diag (ha,psi_local,en);
135- bpcg.diag (ha,psi_local,en);
136- bpcg.diag (ha,psi_local,en);
133+ using T = std::complex <double >;
134+ const int dim = DIAGOTEST::npw;
135+ const std::vector<T> &h_mat = DIAGOTEST::hmatrix;
136+ auto hpsi_func = [h_mat, dim](T *psi_in, T *hpsi_out,
137+ const int ld_psi, const int nvec) {
138+ auto one = std::make_unique<T>(1.0 );
139+ auto zero = std::make_unique<T>(0.0 );
140+ const T *one_ = one.get ();
141+ const T *zero_ = zero.get ();
142+
143+ base_device::DEVICE_CPU *ctx = {};
144+ // hpsi_out(dim * nvec) = h_mat(dim * dim) * psi_in(dim * nvec)
145+ hsolver::gemm_op<T, base_device::DEVICE_CPU>()(
146+ ctx, ' N' , ' N' ,
147+ dim, nvec, dim,
148+ one_,
149+ h_mat.data (), dim,
150+ psi_in, ld_psi,
151+ zero_,
152+ hpsi_out, ld_psi);
153+ };
154+ bpcg.init_iter (nband, npw);
155+ bpcg.diag (hpsi_func, psi_local.get_pointer (), en);
156+ // bpcg.diag(ha,psi_local,en);
157+ // bpcg.diag(ha,psi_local,en);
158+ // bpcg.diag(ha,psi_local,en);
137159 end = MPI_Wtime ();
138160 // if(mypnum == 0) printf("diago time:%7.3f\n",end-start);
139161 delete [] DIAGOTEST::npw_local;
0 commit comments