Skip to content

Commit 8d549fd

Browse files
committed
Test: change bpcg tests to fit new interface
1 parent 032a150 commit 8d549fd

File tree

1 file changed

+26
-4
lines changed

1 file changed

+26
-4
lines changed

source/module_hsolver/test/diago_bpcg_test.cpp

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -130,10 +130,32 @@ class DiagoBPCGPrepare
130130
psi_local.fix_k(0);
131131
double start, end;
132132
start = MPI_Wtime();
133-
bpcg.init_iter(psi_local);
134-
bpcg.diag(ha,psi_local,en);
135-
bpcg.diag(ha,psi_local,en);
136-
bpcg.diag(ha,psi_local,en);
133+
using T = std::complex<double>;
134+
const int dim = DIAGOTEST::npw;
135+
const std::vector<T> &h_mat = DIAGOTEST::hmatrix;
136+
auto hpsi_func = [h_mat, dim](T *psi_in, T *hpsi_out,
137+
const int ld_psi, const int nvec) {
138+
auto one = std::make_unique<T>(1.0);
139+
auto zero = std::make_unique<T>(0.0);
140+
const T *one_ = one.get();
141+
const T *zero_ = zero.get();
142+
143+
base_device::DEVICE_CPU *ctx = {};
144+
// hpsi_out(dim * nvec) = h_mat(dim * dim) * psi_in(dim * nvec)
145+
hsolver::gemm_op<T, base_device::DEVICE_CPU>()(
146+
ctx, 'N', 'N',
147+
dim, nvec, dim,
148+
one_,
149+
h_mat.data(), dim,
150+
psi_in, ld_psi,
151+
zero_,
152+
hpsi_out, ld_psi);
153+
};
154+
bpcg.init_iter(nband, npw);
155+
bpcg.diag(hpsi_func, psi_local.get_pointer(), en);
156+
// bpcg.diag(ha,psi_local,en);
157+
// bpcg.diag(ha,psi_local,en);
158+
// bpcg.diag(ha,psi_local,en);
137159
end = MPI_Wtime();
138160
//if(mypnum == 0) printf("diago time:%7.3f\n",end-start);
139161
delete [] DIAGOTEST::npw_local;

0 commit comments

Comments
 (0)