@@ -22,6 +22,10 @@ DiagoBPCG<T, Device>::DiagoBPCG(const Real* precondition_in)
2222 this ->device_type = ct::DeviceTypeToEnum<Device>::value;
2323
2424 this ->h_prec = std::move (ct::TensorMap ((void *) precondition_in, r_type, device_type, {this ->n_basis }));
25+
26+ this ->one = &one_;
27+ this ->zero = &zero_;
28+ this ->neg_one = &neg_one_;
2529}
2630
2731template <typename T, typename Device>
@@ -30,11 +34,11 @@ DiagoBPCG<T, Device>::~DiagoBPCG() {
3034}
3135
3236template <typename T, typename Device>
33- void DiagoBPCG<T, Device>::init_iter(const int nband, const int nbasis) {
37+ void DiagoBPCG<T, Device>::init_iter(const int nband, const int nbasis, const int ndim ) {
3438 // Specify the problem size n_basis, n_band, while lda is n_basis
3539 this ->n_band = nband;
3640 this ->n_basis = nbasis;
37-
41+ this -> n_dim = ndim;
3842
3943 // All column major tensors
4044
@@ -93,7 +97,23 @@ void DiagoBPCG<T, Device>::orth_cholesky(
9397 // hsub_out = psi_out * transc(psi_out)
9498 ct::EinsumOption option (
9599 /* conj_x=*/ false , /* conj_y=*/ true , /* alpha=*/ 1.0 , /* beta=*/ 0.0 , /* Tensor out=*/ &hsub_out);
96- hsub_out = ct::op::einsum (" ij,kj->ik" , psi_out, psi_out, option);
100+ // hsub_out = ct::op::einsum("ij,kj->ik", psi_out, psi_out, option);
101+
102+ // gemm: hsub_out(n_band x n_band) = psi_out^T(n_band x n_basis) * psi_out(n_basis x n_band)
103+ gemm_op ()(this ->ctx ,
104+ ' C' ,
105+ ' N' ,
106+ this ->n_band , // m
107+ this ->n_band , // n
108+ this ->n_dim , // k
109+ this ->one , // 1.0
110+ psi_out.data <T>(),
111+ this ->n_basis , // lda
112+ psi_out.data <T>(),
113+ this ->n_basis , // ldb
114+ this ->zero , // 0.0
115+ hsub_out.data <T>(),
116+ this ->n_band ); // ldc
97117
98118 // set hsub matrix to lower format;
99119 ct::kernels::set_matrix<T, ct_Device>()(
@@ -145,12 +165,45 @@ void DiagoBPCG<T, Device>::orth_projection(
145165{
146166 ct::EinsumOption option (
147167 /* conj_x=*/ false , /* conj_y=*/ true , /* alpha=*/ 1.0 , /* beta=*/ 0.0 , /* Tensor out=*/ &hsub_in);
148- hsub_in = ct::op::einsum (" ij,kj->ik" , grad_out, psi_in, option);
168+ // hsub_in = ct::op::einsum("ij,kj->ik", grad_out, psi_in, option);
169+
170+ // this->orth_projection(this->psi, this->hsub, this->grad);
171+ // gemm: hsub_in(n_band x n_band) = psi_in^T(n_band x n_basis) * grad_out(n_basis x n_band)
172+ gemm_op ()(this ->ctx ,
173+ ' C' ,
174+ ' N' ,
175+ this ->n_band , // m
176+ this ->n_band , // n
177+ this ->n_dim , // k
178+ this ->one , // 1.0
179+ psi_in.data <T>(),
180+ this ->n_basis , // lda
181+ grad_out.data <T>(),
182+ this ->n_basis , // ldb
183+ this ->zero , // 0.0
184+ hsub_in.data <T>(),
185+ this ->n_band ); // ldc
149186
150187 // set_matrix_op()('L', hsub_in->data<T>(), this->n_band);
151188 option = ct::EinsumOption (
152189 /* conj_x=*/ false , /* conj_y=*/ false , /* alpha=*/ -1.0 , /* beta=*/ 1.0 , /* Tensor out=*/ &grad_out);
153- grad_out = ct::op::einsum (" ij,jk->ik" , hsub_in, psi_in, option);
190+ // grad_out = ct::op::einsum("ij,jk->ik", hsub_in, psi_in, option);
191+
192+ // grad_out(n_basis x n_band) = 1.0 * grad_out(n_basis x n_band) - psi_in(n_basis x n_band) * hsub_in(n_band x n_band)
193+ gemm_op ()(this ->ctx ,
194+ ' N' ,
195+ ' N' ,
196+ this ->n_dim , // m
197+ this ->n_band , // n
198+ this ->n_band , // k
199+ this ->neg_one , // -1.0
200+ psi_in.data <T>(),
201+ this ->n_basis , // lda
202+ hsub_in.data <T>(),
203+ this ->n_band , // ldb
204+ this ->one , // 1.0
205+ grad_out.data <T>(),
206+ this ->n_basis ); // ldc
154207
155208 return ;
156209}
@@ -165,6 +218,24 @@ void DiagoBPCG<T, Device>::rotate_wf(
165218 /* conj_x=*/ false , /* conj_y=*/ false , /* alpha=*/ 1.0 , /* beta=*/ 0.0 , /* Tensor out=*/ &workspace_in);
166219 workspace_in = ct::op::einsum (" ij,jk->ik" , hsub_in, psi_out, option);
167220
221+ // this->rotate_wf(hsub_out, psi_out, workspace_in);
222+ // this->orth_cholesky(this->work, this->psi, this->hpsi, this->hsub);
223+ // gemm: workspace_in(n_basis x n_band) = psi_out(n_basis x n_band) * hsub_in(n_band x n_band)
224+ // gemm_op()(this->ctx,
225+ // 'N',
226+ // 'N',
227+ // this->n_basis, //m
228+ // this->n_band, //n
229+ // this->n_band, //k
230+ // this->one, //1.0
231+ // psi_out.data<T>(),
232+ // this->n_basis, //lda
233+ // hsub_in.data<T>(),
234+ // this->n_band, //ldb
235+ // this->zero, //0.0
236+ // workspace_in.data<T>(),
237+ // this->n_basis); //ldc
238+
168239 syncmem_complex_op ()(psi_out.template data <T>(), workspace_in.template data <T>(), this ->n_band * this ->n_basis );
169240
170241 return ;
@@ -192,7 +263,23 @@ void DiagoBPCG<T, Device>::diag_hsub(
192263 // it controls the ops to use the corresponding device to calculate results
193264 ct::EinsumOption option (
194265 /* conj_x=*/ false , /* conj_y=*/ true , /* alpha=*/ 1.0 , /* beta=*/ 0.0 , /* Tensor out=*/ &hsub_out);
195- hsub_out = ct::op::einsum (" ij,kj->ik" , psi_in, hpsi_in, option);
266+ // hsub_out = ct::op::einsum("ij,kj->ik", psi_in, hpsi_in, option);
267+
268+ // gemm: hsub_out(n_band x n_band) = hpsi_in^T(n_band x n_basis) * psi_in(n_basis x n_band)
269+ gemm_op ()(this ->ctx ,
270+ ' C' ,
271+ ' N' ,
272+ this ->n_band , // m
273+ this ->n_band , // n
274+ this ->n_dim , // k
275+ this ->one , // 1.0
276+ hpsi_in.data <T>(),
277+ this ->n_basis , // lda
278+ psi_in.data <T>(),
279+ this ->n_basis , // ldb
280+ this ->zero , // 0.0
281+ hsub_out.data <T>(),
282+ this ->n_band ); // ldc
196283
197284 ct::kernels::lapack_dnevd<T, ct_Device>()(' V' , ' U' , hsub_out.data <T>(), this ->n_band , eigenvalue_out.data <Real>());
198285
0 commit comments