@@ -35,6 +35,7 @@ void DiagoBPCG<T, Device>::init_iter(const psi::Psi<T, Device> &psi_in) {
3535 // Specify the problem size n_basis, n_band, while lda is n_basis
3636 this ->n_band = psi_in.get_nbands ();
3737 this ->n_basis = psi_in.get_nbasis ();
38+ this ->n_dim = psi_in.get_current_nbas ();
3839
3940 // All column major tensors
4041
@@ -93,24 +94,24 @@ void DiagoBPCG<T, Device>::orth_cholesky(
9394 ct::Tensor& hsub_out)
9495{
9596 // hsub_out = psi_out * transc(psi_out)
96- ct::EinsumOption option (
97- /* conj_x=*/ false , /* conj_y=*/ true , /* alpha=*/ 1.0 , /* beta=*/ 0.0 , /* Tensor out=*/ &hsub_out);
98- hsub_out = ct::op::einsum (" ij,kj->ik" , psi_out, psi_out, option);
97+ // ct::EinsumOption option(
98+ // /*conj_x=*/false, /*conj_y=*/true, /*alpha=*/1.0, /*beta=*/0.0, /*Tensor out=*/&hsub_out);
99+ // hsub_out = ct::op::einsum("ij,kj->ik", psi_out, psi_out, option);
99100 // use gemm instead of einsum, since the leading dimension (n_basis) can differ from the contracted dimension (n_dim)
100- // gemm_op<T, Device>()(this->ctx,
101- // 'N',
102- // 'N ',
103- // this->dim ,
104- // notconv ,
105- // nbase ,
106- // this->one,
107- // hphi ,
108- // this->dim ,
109- // vcc ,
110- // this->nbase_x ,
111- // this->zero,
112- // psi_iter + (nbase) * this->dim ,
113- // this->dim );
101+ gemm_op<T, Device>()(this ->ctx ,
102+ ' N' ,
103+ ' C ' ,
104+ this ->n_band ,
105+ this -> n_band ,
106+ this -> n_dim ,
107+ this ->one ,
108+ psi_out. data <T>() ,
109+ this ->n_basis ,
110+ psi_out. data <T>() ,
111+ this ->n_basis ,
112+ this ->zero ,
113+ hsub_out. data <T>() ,
114+ this ->n_band );
114115
115116 // set hsub matrix to lower format;
116117 ct::kernels::set_matrix<T, ct_Device>()(
0 commit comments