@@ -97,23 +97,23 @@ void DiagoBPCG<T, Device>::orth_cholesky(
9797 // hsub_out = psi_out * transc(psi_out)
9898 ct::EinsumOption option (
9999 /* conj_x=*/ false , /* conj_y=*/ true , /* alpha=*/ 1.0 , /* beta=*/ 0.0 , /* Tensor out=*/ &hsub_out);
100- hsub_out = ct::op::einsum (" ij,kj->ik" , psi_out, psi_out, option);
100+ // hsub_out = ct::op::einsum("ij,kj->ik", psi_out, psi_out, option);
101101
102102 // gemm: hsub_out(n_band x n_band) = psi_out^T(n_band x n_basis) * psi_out(n_basis x n_band)
103- // gemm_op()(this->ctx,
104- // 'C',
105- // 'N',
106- // this->n_band, //m
107- // this->n_band, //n
108- // this->n_dim, //k
109- // this->one, //1.0
110- // psi_out.data<T>(),
111- // this->n_basis, //lda
112- // psi_out.data<T>(),
113- // this->n_basis, //ldb
114- // this->zero, //0.0
115- // hsub_out.data<T>(),
116- // this->n_band); //ldc
103+ gemm_op ()(this ->ctx ,
104+ ' C' ,
105+ ' N' ,
106+ this ->n_band , // m
107+ this ->n_band , // n
108+ this ->n_dim , // k
109+ this ->one , // 1.0
110+ psi_out.data <T>(),
111+ this ->n_basis , // lda
112+ psi_out.data <T>(),
113+ this ->n_basis , // ldb
114+ this ->zero , // 0.0
115+ hsub_out.data <T>(),
116+ this ->n_band ); // ldc
117117
118118 // set hsub matrix to lower format;
119119 ct::kernels::set_matrix<T, ct_Device>()(
@@ -165,45 +165,45 @@ void DiagoBPCG<T, Device>::orth_projection(
165165{
166166 ct::EinsumOption option (
167167 /* conj_x=*/ false , /* conj_y=*/ true , /* alpha=*/ 1.0 , /* beta=*/ 0.0 , /* Tensor out=*/ &hsub_in);
168- hsub_in = ct::op::einsum (" ij,kj->ik" , grad_out, psi_in, option);
168+ // hsub_in = ct::op::einsum("ij,kj->ik", grad_out, psi_in, option);
169169
170170 // this->orth_projection(this->psi, this->hsub, this->grad);
171171 // gemm: hsub_in(n_band x n_band) = psi_in^T(n_band x n_basis) * grad_out(n_basis x n_band)
172- // gemm_op()(this->ctx,
173- // 'C',
174- // 'N',
175- // this->n_band, //m
176- // this->n_band, //n
177- // this->n_dim, //k
178- // this->one, //1.0
179- // psi_in.data<T>(),
180- // this->n_basis, //lda
181- // grad_out.data<T>(),
182- // this->n_basis, //ldb
183- // this->zero, //0.0
184- // hsub_in.data<T>(),
185- // this->n_band); //ldc
172+ gemm_op ()(this ->ctx ,
173+ ' C' ,
174+ ' N' ,
175+ this ->n_band , // m
176+ this ->n_band , // n
177+ this ->n_dim , // k
178+ this ->one , // 1.0
179+ psi_in.data <T>(),
180+ this ->n_basis , // lda
181+ grad_out.data <T>(),
182+ this ->n_basis , // ldb
183+ this ->zero , // 0.0
184+ hsub_in.data <T>(),
185+ this ->n_band ); // ldc
186186
187187 // set_matrix_op()('L', hsub_in->data<T>(), this->n_band);
188188 option = ct::EinsumOption (
189189 /* conj_x=*/ false , /* conj_y=*/ false , /* alpha=*/ -1.0 , /* beta=*/ 1.0 , /* Tensor out=*/ &grad_out);
190- grad_out = ct::op::einsum (" ij,jk->ik" , hsub_in, psi_in, option);
190+ // grad_out = ct::op::einsum("ij,jk->ik", hsub_in, psi_in, option);
191191
192192 // grad_out(n_basis x n_band) = 1.0 * grad_out(n_basis x n_band) - psi_in(n_basis x n_band) * hsub_in(n_band x n_band)
193- // gemm_op()(this->ctx,
194- // 'N',
195- // 'N',
196- // this->n_dim, //m
197- // this->n_band, //n
198- // this->n_band, //k
199- // this->neg_one, //-1.0
200- // psi_in.data<T>(),
201- // this->n_basis, //lda
202- // hsub_in.data<T>(),
203- // this->n_band, //ldb
204- // this->one, //1.0
205- // grad_out.data<T>(),
206- // this->n_basis); //ldc
193+ gemm_op ()(this ->ctx ,
194+ ' N' ,
195+ ' N' ,
196+ this ->n_dim , // m
197+ this ->n_band , // n
198+ this ->n_band , // k
199+ this ->neg_one , // -1.0
200+ psi_in.data <T>(),
201+ this ->n_basis , // lda
202+ hsub_in.data <T>(),
203+ this ->n_band , // ldb
204+ this ->one , // 1.0
205+ grad_out.data <T>(),
206+ this ->n_basis ); // ldc
207207
208208 return ;
209209}
@@ -263,23 +263,23 @@ void DiagoBPCG<T, Device>::diag_hsub(
263263 // it controls the ops to use the corresponding device to calculate results
264264 ct::EinsumOption option (
265265 /* conj_x=*/ false , /* conj_y=*/ true , /* alpha=*/ 1.0 , /* beta=*/ 0.0 , /* Tensor out=*/ &hsub_out);
266- hsub_out = ct::op::einsum (" ij,kj->ik" , psi_in, hpsi_in, option);
266+ // hsub_out = ct::op::einsum("ij,kj->ik", psi_in, hpsi_in, option);
267267
268268 // gemm: hsub_out(n_band x n_band) = hpsi_in^T(n_band x n_basis) * psi_in(n_basis x n_band)
269- // gemm_op()(this->ctx,
270- // 'C',
271- // 'N',
272- // this->n_band, //m
273- // this->n_band, //n
274- // this->n_dim, //k
275- // this->one, //1.0
276- // hpsi_in.data<T>(),
277- // this->n_basis, //lda
278- // psi_in.data<T>(),
279- // this->n_basis, //ldb
280- // this->zero, //0.0
281- // hsub_out.data<T>(),
282- // this->n_band); //ldc
269+ gemm_op ()(this ->ctx ,
270+ ' C' ,
271+ ' N' ,
272+ this ->n_band , // m
273+ this ->n_band , // n
274+ this ->n_dim , // k
275+ this ->one , // 1.0
276+ hpsi_in.data <T>(),
277+ this ->n_basis , // lda
278+ psi_in.data <T>(),
279+ this ->n_basis , // ldb
280+ this ->zero , // 0.0
281+ hsub_out.data <T>(),
282+ this ->n_band ); // ldc
283283
284284 ct::kernels::lapack_dnevd<T, ct_Device>()(' V' , ' U' , hsub_out.data <T>(), this ->n_band , eigenvalue_out.data <Real>());
285285
0 commit comments