Skip to content

Commit 3a70e1b

Browse files
committed
replace einsum by gemm in diag_hsub
1 parent e2dc4a1 commit 3a70e1b

File tree

1 file changed

+32
-16
lines changed

1 file changed

+32
-16
lines changed

source/module_hsolver/diago_bpcg.cpp

Lines changed: 32 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -106,12 +106,12 @@ void DiagoBPCG<T, Device>::orth_cholesky(
106106
this->n_band, //m
107107
this->n_band, //n
108108
this->n_dim, //k
109-
this->one,
109+
this->one, //1.0
110110
psi_out.data<T>(),
111111
this->n_basis, //lda
112112
psi_out.data<T>(),
113113
this->n_basis, //ldb
114-
this->zero,
114+
this->zero, //0.0
115115
hsub_out.data<T>(),
116116
this->n_band); //ldc
117117

@@ -175,12 +175,12 @@ void DiagoBPCG<T, Device>::orth_projection(
175175
this->n_band, //m
176176
this->n_band, //n
177177
this->n_dim, //k
178-
this->one,
178+
this->one, //1.0
179179
psi_in.data<T>(),
180180
this->n_basis, //lda
181181
grad_out.data<T>(),
182182
this->n_basis, //ldb
183-
this->zero,
183+
this->zero, //0.0
184184
hsub_in.data<T>(),
185185
this->n_band); //ldc
186186

@@ -189,21 +189,21 @@ void DiagoBPCG<T, Device>::orth_projection(
189189
/*conj_x=*/false, /*conj_y=*/false, /*alpha=*/-1.0, /*beta=*/1.0, /*Tensor out=*/&grad_out);
190190
// grad_out = ct::op::einsum("ij,jk->ik", hsub_in, psi_in, option);
191191

192-
// grad_out(n_basis x n_band) = psi_in(n_basis x n_band) * hsub_in(n_band x n_band)
192+
// grad_out(n_basis x n_band) = 1.0 * grad_out(n_basis x n_band) - psi_in(n_basis x n_band) * hsub_in(n_band x n_band)
193193
gemm_op()(this->ctx,
194194
'N',
195195
'N',
196-
this->n_dim, //m
196+
this->n_dim, //m
197197
this->n_band, //n
198198
this->n_band, //k
199-
this->neg_one,
199+
this->neg_one, //-1.0
200200
psi_in.data<T>(),
201201
this->n_basis, //lda
202202
hsub_in.data<T>(),
203203
this->n_band, //ldb
204-
this->one,
204+
this->one, //1.0
205205
grad_out.data<T>(),
206-
this->n_basis); //ldc
206+
this->n_basis); //ldc
207207

208208
return;
209209
}
@@ -224,15 +224,15 @@ void DiagoBPCG<T, Device>::rotate_wf(
224224
gemm_op()(this->ctx,
225225
'N',
226226
'N',
227-
this->n_dim, //m
228-
this->n_band, //n
227+
this->n_dim, //m
228+
this->n_band, //n
229229
this->n_band, //k
230-
this->one,
230+
this->one, //1.0
231231
psi_out.data<T>(),
232-
this->n_basis, //lda
232+
this->n_basis, //lda
233233
hsub_in.data<T>(),
234-
this->n_band, //ldb
235-
this->zero,
234+
this->n_band, //ldb
235+
this->zero, //0.0
236236
workspace_in.data<T>(),
237237
this->n_basis); //ldc
238238

@@ -263,7 +263,23 @@ void DiagoBPCG<T, Device>::diag_hsub(
263263
// it controls the ops to use the corresponding device to calculate results
264264
ct::EinsumOption option(
265265
/*conj_x=*/false, /*conj_y=*/true, /*alpha=*/1.0, /*beta=*/0.0, /*Tensor out=*/&hsub_out);
266-
hsub_out = ct::op::einsum("ij,kj->ik", psi_in, hpsi_in, option);
266+
// hsub_out = ct::op::einsum("ij,kj->ik", psi_in, hpsi_in, option);
267+
268+
// gemm: hsub_out(n_band x n_band) = hpsi_in^H(n_band x n_basis) * psi_in(n_basis x n_band), where ^H is the conjugate transpose selected by 'C' (same as ^T for real T)
269+
gemm_op()(this->ctx,
270+
'C',
271+
'N',
272+
this->n_band, //m
273+
this->n_band, //n
274+
this->n_dim, //k
275+
this->one, //1.0
276+
hpsi_in.data<T>(),
277+
this->n_basis, //lda
278+
psi_in.data<T>(),
279+
this->n_basis, //ldb
280+
this->zero, //0.0
281+
hsub_out.data<T>(),
282+
this->n_band); //ldc
267283

268284
ct::kernels::lapack_dnevd<T, ct_Device>()('V', 'U', hsub_out.data<T>(), this->n_band, eigenvalue_out.data<Real>());
269285

0 commit comments

Comments
 (0)