@@ -115,6 +115,8 @@ void DiagoBPCG<T, Device>::orth_cholesky(
115115 hsub_out.data <T>(),
116116 this ->n_band ); // ldc
117117
118+ // Parallel_Reduce::reduce_pool(hsub_out.data<T>(), this->n_band * this->n_band);
119+
118120 // set hsub matrix to lower format;
119121 ct::kernels::set_matrix<T, ct_Device>()(
120122 ' L' , hsub_out.data <T>(), this ->n_band );
@@ -184,6 +186,8 @@ void DiagoBPCG<T, Device>::orth_projection(
184186 hsub_in.data <T>(),
185187 this ->n_band ); // ldc
186188
189+ // Parallel_Reduce::reduce_pool(hsub_in.data<T>(), this->n_band * this->n_band);
190+
187191 // set_matrix_op()('L', hsub_in->data<T>(), this->n_band);
188192 option = ct::EinsumOption (
189193 /* conj_x=*/ false , /* conj_y=*/ false , /* alpha=*/ -1.0 , /* beta=*/ 1.0 , /* Tensor out=*/ &grad_out);
@@ -205,6 +209,8 @@ void DiagoBPCG<T, Device>::orth_projection(
205209 grad_out.data <T>(),
206210 this ->n_basis ); // ldc
207211
212+ // Parallel_Reduce::reduce_pool(grad_out.data<T>(), this->n_basis * this->n_band);
213+
208214 return ;
209215}
210216
@@ -216,25 +222,27 @@ void DiagoBPCG<T, Device>::rotate_wf(
216222{
217223 ct::EinsumOption option (
218224 /* conj_x=*/ false , /* conj_y=*/ false , /* alpha=*/ 1.0 , /* beta=*/ 0.0 , /* Tensor out=*/ &workspace_in);
219- workspace_in = ct::op::einsum (" ij,jk->ik" , hsub_in, psi_out, option);
225+ // workspace_in = ct::op::einsum("ij,jk->ik", hsub_in, psi_out, option);
220226
221227 // this->rotate_wf(hsub_out, psi_out, workspace_in);
222228 // this->orth_cholesky(this->work, this->psi, this->hpsi, this->hsub);
223229 // gemm: workspace_in(n_basis x n_band) = psi_out(n_basis x n_band) * hsub_in(n_band x n_band)
224- // gemm_op()(this->ctx,
225- // 'N',
226- // 'N',
227- // this->n_basis, //m
228- // this->n_band, //n
229- // this->n_band, //k
230- // this->one, //1.0
231- // psi_out.data<T>(),
232- // this->n_basis, //lda
233- // hsub_in.data<T>(),
234- // this->n_band, //ldb
235- // this->zero, //0.0
236- // workspace_in.data<T>(),
237- // this->n_basis); //ldc
230+ gemm_op ()(this ->ctx ,
231+ ' N' ,
232+ ' N' ,
233+ this ->n_basis , // m
234+ this ->n_band , // n
235+ this ->n_band , // k
236+ this ->one , // 1.0
237+ psi_out.data <T>(),
238+ this ->n_basis , // lda
239+ hsub_in.data <T>(),
240+ this ->n_band , // ldb
241+ this ->zero , // 0.0
242+ workspace_in.data <T>(),
243+ this ->n_basis ); // ldc
244+
245+ // Parallel_Reduce::reduce_pool(workspace_in.data<T>(), this->n_basis * this->n_band);
238246
239247 syncmem_complex_op ()(psi_out.template data <T>(), workspace_in.template data <T>(), this ->n_band * this ->n_basis );
240248
@@ -281,6 +289,8 @@ void DiagoBPCG<T, Device>::diag_hsub(
281289 hsub_out.data <T>(),
282290 this ->n_band ); // ldc
283291
292+ // Parallel_Reduce::reduce_pool(hsub_out.data<T>(), this->n_band * this->n_band);
293+
284294 ct::kernels::lapack_dnevd<T, ct_Device>()(' V' , ' U' , hsub_out.data <T>(), this ->n_band , eigenvalue_out.data <Real>());
285295
286296 return ;
0 commit comments