@@ -167,11 +167,44 @@ void DiagoBPCG<T, Device>::orth_projection(
167167 /* conj_x=*/ false , /* conj_y=*/ true , /* alpha=*/ 1.0 , /* beta=*/ 0.0 , /* Tensor out=*/ &hsub_in);
168168 hsub_in = ct::op::einsum (" ij,kj->ik" , grad_out, psi_in, option);
169169
170+ // this->orth_projection(this->psi, this->hsub, this->grad);
171+ // gemm: hsub_in(n_band x n_band) = grad_out^T(n_band x n_basis) * psi_in(n_basis x n_band)
172+ // gemm_op()(this->ctx,
173+ // 'C',
174+ // 'N',
175+ // this->n_band, //m
176+ // this->n_band, //n
177+ // this->n_dim, //k
178+ // this->one,
179+ // grad_out.data<T>(),
180+ // this->n_basis, //lda
181+ // psi_in.data<T>(),
182+ // this->n_basis, //ldb
183+ // this->zero,
184+ // hsub_in.data<T>(),
185+ // this->n_band); //ldc
186+
170187 // set_matrix_op()('L', hsub_in->data<T>(), this->n_band);
171188 option = ct::EinsumOption (
172189 /* conj_x=*/ false , /* conj_y=*/ false , /* alpha=*/ -1.0 , /* beta=*/ 1.0 , /* Tensor out=*/ &grad_out);
173190 grad_out = ct::op::einsum (" ij,jk->ik" , hsub_in, psi_in, option);
174191
192+ // grad_out(n_basis x n_band) = psi_in(n_basis x n_band) * hsub_in(n_band x n_band)
193+ // gemm_op()(this->ctx,
194+ // 'N',
195+ // 'N',
196+ // this->n_basis, //m
197+ // this->n_band, //n
198+ // this->n_band, //k
199+ // this->one,
200+ // psi_in.data<T>(),
201+ // this->n_basis, //lda
202+ // hsub_in.data<T>(),
203+ // this->n_band, //ldb
204+ // this->zero,
205+ // grad_out.data<T>(),
206+ // this->n_basis); //ldc
207+
175208 return ;
176209}
177210
0 commit comments