Skip to content

Commit 08d92dd

Browse files
Feature: ACE should work for multi-K, but not for sure
1 parent 2e742a2 commit 08d92dd

File tree

2 files changed

+151
-137
lines changed

2 files changed

+151
-137
lines changed

source/module_hamilt_pw/hamilt_pwdft/operator_pw/op_exx_pw.cpp

Lines changed: 149 additions & 136 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,12 @@
1313
#include <memory>
1414
#include <utility>
1515

16-
extern "C" void ztrtri_(char *uplo, char *diag, int *n, std::complex<double> *a, int *lda, int *info);
16+
extern "C"
17+
{
18+
void ztrtri_(char *uplo, char *diag, int *n, std::complex<double> *a, int *lda, int *info);
19+
void ctrtri_(char *uplo, char *diag, int *n, std::complex<float> *a, int *lda, int *info);
20+
}
21+
1722
//extern "C" void zpotrf_(char* uplo, const int* n, std::complex<double>* A, const int* lda, int* info);
1823
//extern "C" void cpotrf_(char* uplo, const int* n, std::complex<float>* A, const int* lda, int* info);
1924

@@ -265,19 +270,10 @@ void OperatorEXXPW<T, Device>::act_op_ace(const int nbands,
265270
const int ngk_ik,
266271
const bool is_first_node) const
267272
{
268-
if (nbasis != p_exx_helper->psi.get_nbasis())
269-
{
270-
// ModuleBase::ERROR_QUIT("OperatorEXXPW", "nbasis != psi.get_nbasis()");
271-
// std::cout << "nbasis != psi.get_nbasis()" << std::endl;
272-
// std::cout << "nbasis: " << nbasis << std::endl;
273-
// std::cout << "psi.get_nbasis(): " << p_exx_helper->psi.get_nbasis() << std::endl;
274-
// std::cout << "npwk_max: " << wfcpw->npwk_max << std::endl;
275-
// std::cout << "npwk_ik: " << wfcpw->npwk[this->ik] << std::endl;
276-
// throw std::runtime_error("nbasis != psi.get_nbasis()");
277-
}
278273
// std::cout << "act_op_ace" << std::endl;
279274
// hpsi += -Xi^\dagger * Xi * psi
280-
int nbands_tot = p_exx_helper->psi.get_nbands() * p_exx_helper->psi.get_nk();
275+
auto Xi_ace = Xi_ace_k[this->ik];
276+
int nbands_tot = p_exx_helper->psi.get_nbands();
281277
int nbasis_max = p_exx_helper->psi.get_nbasis();
282278
// T* hpsi = nullptr;
283279
// resmem_complex_op()(hpsi, nbands_tot * nbasis);
@@ -334,178 +330,195 @@ void OperatorEXXPW<T, Device>::act_op_ace(const int nbands,
334330
template <typename T, typename Device>
335331
void OperatorEXXPW<T, Device>::construct_ace() const
336332
{
337-
int nbands_tot = p_exx_helper->psi.get_nbands() * p_exx_helper->psi.get_nk();
333+
// int nkb = p_exx_helper->psi.get_nbands() * p_exx_helper->psi.get_nk();
334+
int nbands = p_exx_helper->psi.get_nbands();
338335
int nbasis = p_exx_helper->psi.get_nbasis();
336+
int nk = p_exx_helper->psi.get_nk();
337+
338+
T intermediate_one = 1.0, intermediate_zero = 0.0;
339+
339340
if (h_psi_ace == nullptr)
340341
{
341-
resmem_complex_op()(h_psi_ace, nbands_tot * nbasis);
342-
setmem_complex_op()(h_psi_ace, 0, nbands_tot * nbasis);
342+
resmem_complex_op()(h_psi_ace, nbands * nbasis);
343+
setmem_complex_op()(h_psi_ace, 0, nbands * nbasis);
343344
}
344345

345-
if (Xi_ace == nullptr)
346+
if (Xi_ace_k.size() != nk)
346347
{
347-
resmem_complex_op()(Xi_ace, nbands_tot * nbasis);
348+
Xi_ace_k.resize(nk);
349+
for (int i = 0; i < nk; i++)
350+
{
351+
resmem_complex_op()(Xi_ace_k[i], nbands * nbasis);
352+
}
353+
}
354+
355+
for (int i = 0; i < nk; i++)
356+
{
357+
setmem_complex_op()(Xi_ace_k[i], 0, nbands * nbasis);
348358
}
349359

350360
if (L_ace == nullptr)
351361
{
352-
resmem_complex_op()(L_ace, nbands_tot * nbands_tot);
353-
setmem_complex_op()(L_ace, 0, nbands_tot * nbands_tot);
362+
resmem_complex_op()(L_ace, nbands * nbands);
363+
setmem_complex_op()(L_ace, 0, nbands * nbands);
354364
}
355365

356366
if (psi_h_psi_ace == nullptr)
357367
{
358-
resmem_complex_op()(psi_h_psi_ace, nbands_tot * nbands_tot);
368+
resmem_complex_op()(psi_h_psi_ace, nbands * nbands);
359369
}
360370

361371
// std::ofstream ofs_psi("psi.dat", std::ios::binary);
362372
// p_exx_helper->psi.fix_kb(0, 0);
363-
// ofs_psi.write(reinterpret_cast<char*>(p_exx_helper->psi.get_pointer()), nbands_tot * nbasis * sizeof(T));
373+
// ofs_psi.write(reinterpret_cast<char*>(p_exx_helper->psi.get_pointer()), nkb * nbasis * sizeof(T));
364374
// ofs_psi.close();
365375

366376
for (int ik = 0; ik < p_exx_helper->psi.get_nk(); ik++)
367377
{
378+
auto Xi_ace = Xi_ace_k[ik];
379+
368380
p_exx_helper->psi.fix_kb(ik, 0);
369381
T* p_psi = p_exx_helper->psi.get_pointer();
370382
act_op(
371383
p_exx_helper->psi.get_nbands(),
372384
nbasis,
373385
1,
374386
p_psi,
375-
h_psi_ace + ik * p_exx_helper->psi.get_nbands() * nbasis,
387+
h_psi_ace,
376388
ik,
377389
false
378390
);
379-
}
380-
381-
// // save h_psi_ace to disk
382-
// std::ofstream ofs_hpsi("hpsi.dat", std::ios::binary);
383-
// ofs_hpsi.write(reinterpret_cast<char*>(h_psi_ace), nbands_tot * nbasis * sizeof(T));
384-
// ofs_hpsi.close();
385391

386-
T intermediate_one = 1.0, intermediate_zero = 0.0;
387-
388-
// psi_h_psi_ace = psi^\dagger * h_psi_ace
389-
p_exx_helper->psi.fix_kb(0, 0);
390-
gemm_complex_op()(this->ctx,
391-
'C',
392-
'N',
393-
nbands_tot,
394-
nbands_tot,
395-
nbasis,
396-
&intermediate_one,
397-
p_exx_helper->psi.get_pointer(),
398-
nbasis,
399-
h_psi_ace,
400-
nbasis,
401-
&intermediate_zero,
402-
psi_h_psi_ace,
403-
nbands_tot
404-
);
392+
// psi_h_psi_ace = psi^\dagger * h_psi_ace
393+
p_exx_helper->psi.fix_kb(0, 0);
394+
gemm_complex_op()(this->ctx,
395+
'C',
396+
'N',
397+
nbands,
398+
nbands,
399+
nbasis,
400+
&intermediate_one,
401+
p_exx_helper->psi.get_pointer(),
402+
nbasis,
403+
h_psi_ace,
404+
nbasis,
405+
&intermediate_zero,
406+
psi_h_psi_ace,
407+
nbands);
408+
409+
// reduction of psi_h_psi_ace, due to distributed memory
410+
Parallel_Reduce::reduce_pool(psi_h_psi_ace, nbands * nbands);
411+
412+
// // save psi_h_psi_ace to disk
413+
// std::ofstream ofs_psi_hpsi("psihpsi.dat", std::ios::binary);
414+
// ofs_psi_hpsi.write(reinterpret_cast<char*>(psi_h_psi_ace), nkb * nkb * sizeof(T));
415+
// ofs_psi_hpsi.close();
416+
417+
// L_ace = cholesky(-psi_h_psi_ace)
418+
#ifdef _OPENMP
419+
#pragma omp parallel for schedule(static)
420+
#endif
421+
for (int i = 0; i < nbands; i++)
422+
{
423+
for (int j = 0; j < nbands; j++)
424+
{
425+
L_ace[i * nbands + j] = -psi_h_psi_ace[i * nbands + j];
426+
}
427+
}
405428

406-
// reduction of psi_h_psi_ace, due to distributed memory
407-
Parallel_Reduce::reduce_pool(psi_h_psi_ace, nbands_tot * nbands_tot);
429+
int info = 0;
430+
char up = 'U', lo = 'L';
408431

409-
// // save psi_h_psi_ace to disk
410-
// std::ofstream ofs_psi_hpsi("psihpsi.dat", std::ios::binary);
411-
// ofs_psi_hpsi.write(reinterpret_cast<char*>(psi_h_psi_ace), nbands_tot * nbands_tot * sizeof(T));
412-
// ofs_psi_hpsi.close();
432+
if constexpr (std::is_same<T, std::complex<float>>::value)
433+
{
434+
cpotrf_(&lo, &nbands, L_ace, &nbands, &info);
435+
}
436+
else if constexpr (std::is_same<T, std::complex<double>>::value)
437+
{
438+
zpotrf_(&lo, &nbands, L_ace, &nbands, &info);
439+
}
413440

414-
// L_ace = cholesky(-psi_h_psi_ace)
415-
#ifdef _OPENMP
416-
#pragma omp parallel for schedule(static)
417-
#endif
418-
for (int i = 0; i < nbands_tot; i++)
419-
{
420-
for (int j = 0; j < nbands_tot; j++)
441+
// expand for-loop
442+
#ifdef _OPENMP
443+
#pragma omp parallel for schedule(static)
444+
#endif
445+
for (int i = 0; i < nbands; i++)
421446
{
422-
L_ace[i * nbands_tot + j] = -psi_h_psi_ace[i * nbands_tot + j];
447+
for (int j = 0; j < nbands; j++)
448+
{
449+
if (j < i)
450+
{
451+
// L_ace[j * nkb + i] = std::conj(L_ace[i * nkb + j]);
452+
L_ace[i * nbands + j] = 0.0;
453+
}
454+
}
423455
}
424-
}
425456

426-
int info = 0;
427-
char up = 'U', lo = 'L';
457+
// // save L_ace to disk
458+
// std::ofstream ofs_L("L.dat", std::ios::binary);
459+
// ofs_L.write(reinterpret_cast<char*>(L_ace), nkb * nkb * sizeof(T));
460+
// ofs_L.close();
428461

429-
if constexpr (std::is_same<T, std::complex<float>>::value)
430-
{
431-
cpotrf_(&lo, &nbands_tot, L_ace, &nbands_tot, &info);
432-
}
433-
else if constexpr (std::is_same<T, std::complex<double>>::value)
434-
{
435-
zpotrf_(&lo, &nbands_tot, L_ace, &nbands_tot, &info);
436-
}
462+
// L_ace inv in place
463+
// T == std::complex<float> or std::complex<double>
464+
if constexpr (std::is_same<T, std::complex<float>>::value)
465+
{
466+
char non_unitary = 'N';
437467

438-
// expand for-loop
439-
#ifdef _OPENMP
440-
#pragma omp parallel for schedule(static)
441-
#endif
442-
for (int i = 0; i < nbands_tot; i++)
443-
{
444-
for (int j = 0; j < nbands_tot; j++)
468+
ctrtri_(&lo, &non_unitary, &nbands, L_ace, &nbands, &info);
469+
}
470+
else if constexpr (std::is_same<T, std::complex<double>>::value)
445471
{
446-
if (j < i)
447-
{
448-
// L_ace[j * nbands_tot + i] = std::conj(L_ace[i * nbands_tot + j]);
449-
L_ace[i * nbands_tot + j] = 0.0;
450-
}
472+
char non_unitary = 'N';
473+
474+
ztrtri_(&lo, &non_unitary, &nbands, L_ace, &nbands, &info);
451475
}
452-
}
453476

454-
// // save L_ace to disk
455-
// std::ofstream ofs_L("L.dat", std::ios::binary);
456-
// ofs_L.write(reinterpret_cast<char*>(L_ace), nbands_tot * nbands_tot * sizeof(T));
457-
// ofs_L.close();
477+
// // save L_ace inv to disk
478+
// std::ofstream ofs_L_inv("L_inv.dat", std::ios::binary);
479+
// ofs_L_inv.write(reinterpret_cast<char*>(L_ace), nkb * nkb * sizeof(T));
480+
// ofs_L_inv.close();
481+
482+
// Xi_ace = L_ace^-1 * h_psi_ace^dagger
483+
gemm_complex_op()(this->ctx,
484+
'N',
485+
'C',
486+
nbands,
487+
nbasis,
488+
nbands,
489+
&intermediate_one,
490+
L_ace,
491+
nbands,
492+
h_psi_ace,
493+
nbasis,
494+
&intermediate_zero,
495+
Xi_ace,
496+
nbands);
497+
498+
// // save Xi_ace to disk
499+
// std::ofstream ofs_Xi("Xi.dat", std::ios::binary);
500+
// ofs_Xi.write(reinterpret_cast<char*>(Xi_ace), nkb * nbasis * sizeof(T));
501+
// ofs_Xi.close();
502+
//
503+
// std::cout << "nkb: " << nkb << std::endl;
504+
// std::cout << "nbands: " << p_exx_helper->psi.get_nbands() << std::endl;
505+
// std::cout << "nbasis: " << nbasis << std::endl;
506+
507+
// clear mem
508+
setmem_complex_op()(h_psi_ace, 0, nbands * nbasis);
509+
// setmem_complex_op()(Xi_ace, 0, nkb * nbasis);
510+
setmem_complex_op()(psi_h_psi_ace, 0, nbands * nbands);
511+
setmem_complex_op()(L_ace, 0, nbands * nbands);
458512

459-
// L_ace inv in place
460-
// T == std::complex<float> or std::complex<double>
461-
if constexpr (std::is_same<T, std::complex<float>>::value)
462-
{
463-
// cgetrf_(&p_exx_helper->psi.get_nbands(), &p_exx_helper->psi.get_nbands(), L_ace, &p_exx_helper->psi.get_nbands(), ipiv.data(), &info);
464-
// Todo: implement cgetrf and cgetri
465513
}
466-
else if constexpr (std::is_same<T, std::complex<double>>::value)
467-
{
468-
char non_unitary = 'N';
469514

470-
ztrtri_(&lo, &non_unitary, &nbands_tot, L_ace, &nbands_tot, &info);
471-
}
515+
// // save h_psi_ace to disk
516+
// std::ofstream ofs_hpsi("hpsi.dat", std::ios::binary);
517+
// ofs_hpsi.write(reinterpret_cast<char*>(h_psi_ace), nkb * nbasis * sizeof(T));
518+
// ofs_hpsi.close();
519+
520+
472521

473-
// // save L_ace inv to disk
474-
// std::ofstream ofs_L_inv("L_inv.dat", std::ios::binary);
475-
// ofs_L_inv.write(reinterpret_cast<char*>(L_ace), nbands_tot * nbands_tot * sizeof(T));
476-
// ofs_L_inv.close();
477-
478-
// Xi_ace = L_ace^-1 * h_psi_ace^dagger
479-
gemm_complex_op()(this->ctx,
480-
'N',
481-
'C',
482-
nbands_tot,
483-
nbasis,
484-
nbands_tot,
485-
&intermediate_one,
486-
L_ace,
487-
nbands_tot,
488-
h_psi_ace,
489-
nbasis,
490-
&intermediate_zero,
491-
Xi_ace,
492-
nbands_tot
493-
);
494-
495-
// // save Xi_ace to disk
496-
// std::ofstream ofs_Xi("Xi.dat", std::ios::binary);
497-
// ofs_Xi.write(reinterpret_cast<char*>(Xi_ace), nbands_tot * nbasis * sizeof(T));
498-
// ofs_Xi.close();
499-
//
500-
// std::cout << "nbands_tot: " << nbands_tot << std::endl;
501-
// std::cout << "nbands: " << p_exx_helper->psi.get_nbands() << std::endl;
502-
// std::cout << "nbasis: " << nbasis << std::endl;
503-
504-
// clear mem
505-
setmem_complex_op()(h_psi_ace, 0, nbands_tot * nbasis);
506-
// setmem_complex_op()(Xi_ace, 0, nbands_tot * nbasis);
507-
setmem_complex_op()(psi_h_psi_ace, 0, nbands_tot * nbands_tot);
508-
setmem_complex_op()(L_ace, 0, nbands_tot * nbands_tot);
509522

510523
}
511524

source/module_hamilt_pw/hamilt_pwdft/operator_pw/op_exx_pw.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,8 @@ class OperatorEXXPW : public OperatorPW<T, Device>
117117
mutable T* h_psi_ace = nullptr; // H \Psi, W in the paper
118118
mutable T* psi_h_psi_ace = nullptr; // \Psi^{\dagger} H \Psi, M in the paper
119119
mutable T* L_ace = nullptr; // cholesky(-M).L, L in the paper
120-
mutable T* Xi_ace = nullptr; // L^{-1} (H \Psi)^{\dagger}, \Xi in the paper
120+
mutable std::vector<T*> Xi_ace_k; // L^{-1} (H \Psi)^{\dagger}, \Xi in the paper
121+
// mutable T* Xi_ace = nullptr; // L^{-1} (H \Psi)^{\dagger}, \Xi in the paper
121122

122123
bool ace = true;
123124

0 commit comments

Comments
 (0)