Skip to content

Commit a3350b7

Browse files
committed
Revert gemm substitute for einsum
1 parent 830188a commit a3350b7

File tree

1 file changed

+60
-60
lines changed

1 file changed

+60
-60
lines changed

source/module_hsolver/diago_bpcg.cpp

Lines changed: 60 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -97,23 +97,23 @@ void DiagoBPCG<T, Device>::orth_cholesky(
9797
// hsub_out = psi_out * transc(psi_out)
9898
ct::EinsumOption option(
9999
/*conj_x=*/false, /*conj_y=*/true, /*alpha=*/1.0, /*beta=*/0.0, /*Tensor out=*/&hsub_out);
100-
// hsub_out = ct::op::einsum("ij,kj->ik", psi_out, psi_out, option);
100+
hsub_out = ct::op::einsum("ij,kj->ik", psi_out, psi_out, option);
101101

102102
// gemm: hsub_out(n_band x n_band) = psi_out^T(n_band x n_basis) * psi_out(n_basis x n_band)
103-
gemm_op()(this->ctx,
104-
'C',
105-
'N',
106-
this->n_band, //m
107-
this->n_band, //n
108-
this->n_dim, //k
109-
this->one, //1.0
110-
psi_out.data<T>(),
111-
this->n_basis, //lda
112-
psi_out.data<T>(),
113-
this->n_basis, //ldb
114-
this->zero, //0.0
115-
hsub_out.data<T>(),
116-
this->n_band); //ldc
103+
// gemm_op()(this->ctx,
104+
// 'C',
105+
// 'N',
106+
// this->n_band, //m
107+
// this->n_band, //n
108+
// this->n_dim, //k
109+
// this->one, //1.0
110+
// psi_out.data<T>(),
111+
// this->n_basis, //lda
112+
// psi_out.data<T>(),
113+
// this->n_basis, //ldb
114+
// this->zero, //0.0
115+
// hsub_out.data<T>(),
116+
// this->n_band); //ldc
117117

118118
// set hsub matrix to lower format;
119119
ct::kernels::set_matrix<T, ct_Device>()(
@@ -165,45 +165,45 @@ void DiagoBPCG<T, Device>::orth_projection(
165165
{
166166
ct::EinsumOption option(
167167
/*conj_x=*/false, /*conj_y=*/true, /*alpha=*/1.0, /*beta=*/0.0, /*Tensor out=*/&hsub_in);
168-
// hsub_in = ct::op::einsum("ij,kj->ik", grad_out, psi_in, option);
168+
hsub_in = ct::op::einsum("ij,kj->ik", grad_out, psi_in, option);
169169

170170
// this->orth_projection(this->psi, this->hsub, this->grad);
171171
// gemm: hsub_in(n_band x n_band) = psi_in^T(n_band x n_basis) * grad_out(n_basis x n_band)
172-
gemm_op()(this->ctx,
173-
'C',
174-
'N',
175-
this->n_band, //m
176-
this->n_band, //n
177-
this->n_dim, //k
178-
this->one, //1.0
179-
psi_in.data<T>(),
180-
this->n_basis, //lda
181-
grad_out.data<T>(),
182-
this->n_basis, //ldb
183-
this->zero, //0.0
184-
hsub_in.data<T>(),
185-
this->n_band); //ldc
172+
// gemm_op()(this->ctx,
173+
// 'C',
174+
// 'N',
175+
// this->n_band, //m
176+
// this->n_band, //n
177+
// this->n_dim, //k
178+
// this->one, //1.0
179+
// psi_in.data<T>(),
180+
// this->n_basis, //lda
181+
// grad_out.data<T>(),
182+
// this->n_basis, //ldb
183+
// this->zero, //0.0
184+
// hsub_in.data<T>(),
185+
// this->n_band); //ldc
186186

187187
// set_matrix_op()('L', hsub_in->data<T>(), this->n_band);
188188
option = ct::EinsumOption(
189189
/*conj_x=*/false, /*conj_y=*/false, /*alpha=*/-1.0, /*beta=*/1.0, /*Tensor out=*/&grad_out);
190-
// grad_out = ct::op::einsum("ij,jk->ik", hsub_in, psi_in, option);
190+
grad_out = ct::op::einsum("ij,jk->ik", hsub_in, psi_in, option);
191191

192192
// grad_out(n_basis x n_band) = 1.0 * grad_out(n_basis x n_band) - psi_in(n_basis x n_band) * hsub_in(n_band x n_band)
193-
gemm_op()(this->ctx,
194-
'N',
195-
'N',
196-
this->n_dim, //m
197-
this->n_band, //n
198-
this->n_band, //k
199-
this->neg_one, //-1.0
200-
psi_in.data<T>(),
201-
this->n_basis, //lda
202-
hsub_in.data<T>(),
203-
this->n_band, //ldb
204-
this->one, //1.0
205-
grad_out.data<T>(),
206-
this->n_basis); //ldc
193+
// gemm_op()(this->ctx,
194+
// 'N',
195+
// 'N',
196+
// this->n_dim, //m
197+
// this->n_band, //n
198+
// this->n_band, //k
199+
// this->neg_one, //-1.0
200+
// psi_in.data<T>(),
201+
// this->n_basis, //lda
202+
// hsub_in.data<T>(),
203+
// this->n_band, //ldb
204+
// this->one, //1.0
205+
// grad_out.data<T>(),
206+
// this->n_basis); //ldc
207207

208208
return;
209209
}
@@ -263,23 +263,23 @@ void DiagoBPCG<T, Device>::diag_hsub(
263263
// it controls the ops to use the corresponding device to calculate results
264264
ct::EinsumOption option(
265265
/*conj_x=*/false, /*conj_y=*/true, /*alpha=*/1.0, /*beta=*/0.0, /*Tensor out=*/&hsub_out);
266-
// hsub_out = ct::op::einsum("ij,kj->ik", psi_in, hpsi_in, option);
266+
hsub_out = ct::op::einsum("ij,kj->ik", psi_in, hpsi_in, option);
267267

268268
// gemm: hsub_out(n_band x n_band) = hpsi_in^T(n_band x n_basis) * psi_in(n_basis x n_band)
269-
gemm_op()(this->ctx,
270-
'C',
271-
'N',
272-
this->n_band, //m
273-
this->n_band, //n
274-
this->n_dim, //k
275-
this->one, //1.0
276-
hpsi_in.data<T>(),
277-
this->n_basis, //lda
278-
psi_in.data<T>(),
279-
this->n_basis, //ldb
280-
this->zero, //0.0
281-
hsub_out.data<T>(),
282-
this->n_band); //ldc
269+
// gemm_op()(this->ctx,
270+
// 'C',
271+
// 'N',
272+
// this->n_band, //m
273+
// this->n_band, //n
274+
// this->n_dim, //k
275+
// this->one, //1.0
276+
// hpsi_in.data<T>(),
277+
// this->n_basis, //lda
278+
// psi_in.data<T>(),
279+
// this->n_basis, //ldb
280+
// this->zero, //0.0
281+
// hsub_out.data<T>(),
282+
// this->n_band); //ldc
283283

284284
ct::kernels::lapack_dnevd<T, ct_Device>()('V', 'U', hsub_out.data<T>(), this->n_band, eigenvalue_out.data<Real>());
285285

0 commit comments

Comments
 (0)