Skip to content

Commit e2dc4a1

Browse files
committed
replace einsum by gemm in rotate_wf
1 parent d901d00 commit e2dc4a1

File tree

1 file changed

+19
-1
lines changed

1 file changed

+19
-1
lines changed

source/module_hsolver/diago_bpcg.cpp

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,25 @@ void DiagoBPCG<T, Device>::rotate_wf(
216216
{
217217
ct::EinsumOption option(
218218
/*conj_x=*/false, /*conj_y=*/false, /*alpha=*/1.0, /*beta=*/0.0, /*Tensor out=*/&workspace_in);
219-
workspace_in = ct::op::einsum("ij,jk->ik", hsub_in, psi_out, option);
219+
// workspace_in = ct::op::einsum("ij,jk->ik", hsub_in, psi_out, option);
220+
221+
// this->rotate_wf(hsub_out, psi_out, workspace_in);
222+
// this->orth_cholesky(this->work, this->psi, this->hpsi, this->hsub);
223+
// gemm: workspace_in(n_basis x n_band) = psi_out(n_basis x n_band) * hsub_in(n_band x n_band)
224+
gemm_op()(this->ctx,
225+
'N',
226+
'N',
227+
this->n_dim, //m
228+
this->n_band, //n
229+
this->n_band, //k
230+
this->one,
231+
psi_out.data<T>(),
232+
this->n_basis, //lda
233+
hsub_in.data<T>(),
234+
this->n_band, //ldb
235+
this->zero,
236+
workspace_in.data<T>(),
237+
this->n_basis); //ldc
220238

221239
syncmem_complex_op()(psi_out.template data<T>(), workspace_in.template data<T>(), this->n_band * this->n_basis);
222240

0 commit comments

Comments
 (0)