Skip to content

Commit 31baa6b

Browse files
committed
Substitute gemm for einsum in rotate_wf
1 parent 806de4a commit 31baa6b

File tree

1 file changed

+25
-15
lines changed

1 file changed

+25
-15
lines changed

source/module_hsolver/diago_bpcg.cpp

Lines changed: 25 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -115,6 +115,8 @@ void DiagoBPCG<T, Device>::orth_cholesky(
115115
hsub_out.data<T>(),
116116
this->n_band); //ldc
117117

118+
// Parallel_Reduce::reduce_pool(hsub_out.data<T>(), this->n_band * this->n_band);
119+
118120
// set hsub matrix to lower format;
119121
ct::kernels::set_matrix<T, ct_Device>()(
120122
'L', hsub_out.data<T>(), this->n_band);
@@ -184,6 +186,8 @@ void DiagoBPCG<T, Device>::orth_projection(
184186
hsub_in.data<T>(),
185187
this->n_band); //ldc
186188

189+
// Parallel_Reduce::reduce_pool(hsub_in.data<T>(), this->n_band * this->n_band);
190+
187191
// set_matrix_op()('L', hsub_in->data<T>(), this->n_band);
188192
option = ct::EinsumOption(
189193
/*conj_x=*/false, /*conj_y=*/false, /*alpha=*/-1.0, /*beta=*/1.0, /*Tensor out=*/&grad_out);
@@ -205,6 +209,8 @@ void DiagoBPCG<T, Device>::orth_projection(
205209
grad_out.data<T>(),
206210
this->n_basis); //ldc
207211

212+
// Parallel_Reduce::reduce_pool(grad_out.data<T>(), this->n_basis * this->n_band);
213+
208214
return;
209215
}
210216

@@ -216,25 +222,27 @@ void DiagoBPCG<T, Device>::rotate_wf(
216222
{
217223
ct::EinsumOption option(
218224
/*conj_x=*/false, /*conj_y=*/false, /*alpha=*/1.0, /*beta=*/0.0, /*Tensor out=*/&workspace_in);
219-
workspace_in = ct::op::einsum("ij,jk->ik", hsub_in, psi_out, option);
225+
// workspace_in = ct::op::einsum("ij,jk->ik", hsub_in, psi_out, option);
220226

221227
// this->rotate_wf(hsub_out, psi_out, workspace_in);
222228
// this->orth_cholesky(this->work, this->psi, this->hpsi, this->hsub);
223229
// gemm: workspace_in(n_basis x n_band) = psi_out(n_basis x n_band) * hsub_in(n_band x n_band)
224-
// gemm_op()(this->ctx,
225-
// 'N',
226-
// 'N',
227-
// this->n_basis, //m
228-
// this->n_band, //n
229-
// this->n_band, //k
230-
// this->one, //1.0
231-
// psi_out.data<T>(),
232-
// this->n_basis, //lda
233-
// hsub_in.data<T>(),
234-
// this->n_band, //ldb
235-
// this->zero, //0.0
236-
// workspace_in.data<T>(),
237-
// this->n_basis); //ldc
230+
gemm_op()(this->ctx,
231+
'N',
232+
'N',
233+
this->n_basis, //m
234+
this->n_band, //n
235+
this->n_band, //k
236+
this->one, //1.0
237+
psi_out.data<T>(),
238+
this->n_basis, //lda
239+
hsub_in.data<T>(),
240+
this->n_band, //ldb
241+
this->zero, //0.0
242+
workspace_in.data<T>(),
243+
this->n_basis); //ldc
244+
245+
// Parallel_Reduce::reduce_pool(workspace_in.data<T>(), this->n_basis * this->n_band);
238246

239247
syncmem_complex_op()(psi_out.template data<T>(), workspace_in.template data<T>(), this->n_band * this->n_basis);
240248

@@ -281,6 +289,8 @@ void DiagoBPCG<T, Device>::diag_hsub(
281289
hsub_out.data<T>(),
282290
this->n_band); //ldc
283291

292+
// Parallel_Reduce::reduce_pool(hsub_out.data<T>(), this->n_band * this->n_band);
293+
284294
ct::kernels::lapack_dnevd<T, ct_Device>()('V', 'U', hsub_out.data<T>(), this->n_band, eigenvalue_out.data<Real>());
285295

286296
return;

0 commit comments

Comments
 (0)