@@ -106,12 +106,12 @@ void DiagoBPCG<T, Device>::orth_cholesky(
106106 this ->n_band , // m
107107 this ->n_band , // n
108108 this ->n_dim , // k
109- this ->one ,
109+ this ->one , // 1.0
110110 psi_out.data <T>(),
111111 this ->n_basis , // lda
112112 psi_out.data <T>(),
113113 this ->n_basis , // ldb
114- this ->zero ,
114+ this ->zero , // 0.0
115115 hsub_out.data <T>(),
116116 this ->n_band ); // ldc
117117
@@ -175,12 +175,12 @@ void DiagoBPCG<T, Device>::orth_projection(
175175 this ->n_band , // m
176176 this ->n_band , // n
177177 this ->n_dim , // k
178- this ->one ,
178+ this ->one , // 1.0
179179 psi_in.data <T>(),
180180 this ->n_basis , // lda
181181 grad_out.data <T>(),
182182 this ->n_basis , // ldb
183- this ->zero ,
183+ this ->zero , // 0.0
184184 hsub_in.data <T>(),
185185 this ->n_band ); // ldc
186186
@@ -189,21 +189,21 @@ void DiagoBPCG<T, Device>::orth_projection(
189189 /* conj_x=*/ false , /* conj_y=*/ false , /* alpha=*/ -1.0 , /* beta=*/ 1.0 , /* Tensor out=*/ &grad_out);
190190 // grad_out = ct::op::einsum("ij,jk->ik", hsub_in, psi_in, option);
191191
192- // grad_out(n_basis x n_band) = psi_in(n_basis x n_band) * hsub_in(n_band x n_band)
192+ // grad_out(n_basis x n_band) = 1.0 * grad_out(n_basis x n_band) - psi_in(n_basis x n_band) * hsub_in(n_band x n_band)
193193 gemm_op ()(this ->ctx ,
194194 ' N' ,
195195 ' N' ,
196- this ->n_dim , // m
196+ this ->n_dim , // m
197197 this ->n_band , // n
198198 this ->n_band , // k
199- this ->neg_one ,
199+ this ->neg_one , // -1.0
200200 psi_in.data <T>(),
201201 this ->n_basis , // lda
202202 hsub_in.data <T>(),
203203 this ->n_band , // ldb
204- this ->one ,
204+ this ->one , // 1.0
205205 grad_out.data <T>(),
206- this ->n_basis ); // ldc
206+ this ->n_basis ); // ldc
207207
208208 return ;
209209}
@@ -224,15 +224,15 @@ void DiagoBPCG<T, Device>::rotate_wf(
224224 gemm_op ()(this ->ctx ,
225225 ' N' ,
226226 ' N' ,
227- this ->n_dim , // m
228- this ->n_band , // n
227+ this ->n_dim , // m
228+ this ->n_band , // n
229229 this ->n_band , // k
230- this ->one ,
230+ this ->one , // 1.0
231231 psi_out.data <T>(),
232- this ->n_basis , // lda
232+ this ->n_basis , // lda
233233 hsub_in.data <T>(),
234- this ->n_band , // ldb
235- this ->zero ,
234+ this ->n_band , // ldb
235+ this ->zero , // 0.0
236236 workspace_in.data <T>(),
237237 this ->n_basis ); // ldc
238238
@@ -263,7 +263,23 @@ void DiagoBPCG<T, Device>::diag_hsub(
263263 // it controls the ops to use the corresponding device to calculate results
264264 ct::EinsumOption option (
265265 /* conj_x=*/ false , /* conj_y=*/ true , /* alpha=*/ 1.0 , /* beta=*/ 0.0 , /* Tensor out=*/ &hsub_out);
266- hsub_out = ct::op::einsum (" ij,kj->ik" , psi_in, hpsi_in, option);
266+ // hsub_out = ct::op::einsum("ij,kj->ik", psi_in, hpsi_in, option);
267+
268+ // gemm: hsub_out(n_band x n_band) = hpsi_in^T(n_band x n_basis) * psi_in(n_basis x n_band)
269+ gemm_op ()(this ->ctx ,
270+ ' C' ,
271+ ' N' ,
272+ this ->n_band , // m
273+ this ->n_band , // n
274+ this ->n_dim , // k
275+ this ->one , // 1.0
276+ hpsi_in.data <T>(),
277+ this ->n_basis , // lda
278+ psi_in.data <T>(),
279+ this ->n_basis , // ldb
280+ this ->zero , // 0.0
281+ hsub_out.data <T>(),
282+ this ->n_band ); // ldc
267283
268284 ct::kernels::lapack_dnevd<T, ct_Device>()(' V' , ' U' , hsub_out.data <T>(), this ->n_band , eigenvalue_out.data <Real>());
269285
0 commit comments