@@ -165,45 +165,45 @@ void DiagoBPCG<T, Device>::orth_projection(
165165{
166166 ct::EinsumOption option (
167167 /* conj_x=*/ false , /* conj_y=*/ true , /* alpha=*/ 1.0 , /* beta=*/ 0.0 , /* Tensor out=*/ &hsub_in);
168- hsub_in = ct::op::einsum (" ij,kj->ik" , grad_out, psi_in, option);
168+ // hsub_in = ct::op::einsum("ij,kj->ik", grad_out, psi_in, option);
169169
170170 // this->orth_projection(this->psi, this->hsub, this->grad);
171- // gemm: hsub_in(n_band x n_band) = grad_out ^T(n_band x n_basis) * psi_in (n_basis x n_band)
172- // gemm_op()(this->ctx,
173- // 'C',
174- // 'N',
175- // this->n_band, //m
176- // this->n_band, //n
177- // this->n_dim, //k
178- // this->one,
179- // grad_out .data<T>(),
180- // this->n_basis, //lda
181- // psi_in .data<T>(),
182- // this->n_basis, //ldb
183- // this->zero,
184- // hsub_in.data<T>(),
185- // this->n_band); //ldc
171+ // gemm: hsub_in(n_band x n_band) = psi_in ^T(n_band x n_basis) * grad_out (n_basis x n_band)
172+ gemm_op ()(this ->ctx ,
173+ ' C' ,
174+ ' N' ,
175+ this ->n_band , // m
176+ this ->n_band , // n
177+ this ->n_dim , // k
178+ this ->one ,
179+ psi_in .data <T>(),
180+ this ->n_basis , // lda
181+ grad_out .data <T>(),
182+ this ->n_basis , // ldb
183+ this ->zero ,
184+ hsub_in.data <T>(),
185+ this ->n_band ); // ldc
186186
187187 // set_matrix_op()('L', hsub_in->data<T>(), this->n_band);
188188 option = ct::EinsumOption (
189189 /* conj_x=*/ false , /* conj_y=*/ false , /* alpha=*/ -1.0 , /* beta=*/ 1.0 , /* Tensor out=*/ &grad_out);
190- grad_out = ct::op::einsum (" ij,jk->ik" , hsub_in, psi_in, option);
190+ // grad_out = ct::op::einsum("ij,jk->ik", hsub_in, psi_in, option);
191191
192192 // grad_out(n_basis x n_band) = psi_in(n_basis x n_band) * hsub_in(n_band x n_band)
193- // gemm_op()(this->ctx,
194- // 'N',
195- // 'N',
196- // this->n_basis , //m
197- // this->n_band, //n
198- // this->n_band, //k
199- // this->one ,
200- // psi_in.data<T>(),
201- // this->n_basis, //lda
202- // hsub_in.data<T>(),
203- // this->n_band, //ldb
204- // this->zero ,
205- // grad_out.data<T>(),
206- // this->n_basis); //ldc
193+ gemm_op ()(this ->ctx ,
194+ ' N' ,
195+ ' N' ,
196+ this ->n_dim , // m
197+ this ->n_band , // n
198+ this ->n_band , // k
199+ this ->neg_one ,
200+ psi_in.data <T>(),
201+ this ->n_basis , // lda
202+ hsub_in.data <T>(),
203+ this ->n_band , // ldb
204+ this ->one ,
205+ grad_out.data <T>(),
206+ this ->n_basis ); // ldc
207207
208208 return ;
209209}
0 commit comments