|
13 | 13 | #include <memory> |
14 | 14 | #include <utility> |
15 | 15 |
|
16 | | -extern "C" void ztrtri_(char *uplo, char *diag, int *n, std::complex<double> *a, int *lda, int *info); |
| 16 | +extern "C" |
| 17 | +{ |
| 18 | + void ztrtri_(char *uplo, char *diag, int *n, std::complex<double> *a, int *lda, int *info); |
| 19 | + void ctrtri_(char *uplo, char *diag, int *n, std::complex<float> *a, int *lda, int *info); |
| 20 | +} |
| 21 | + |
17 | 22 | //extern "C" void zpotrf_(char* uplo, const int* n, std::complex<double>* A, const int* lda, int* info); |
18 | 23 | //extern "C" void cpotrf_(char* uplo, const int* n, std::complex<float>* A, const int* lda, int* info); |
19 | 24 |
|
@@ -265,19 +270,10 @@ void OperatorEXXPW<T, Device>::act_op_ace(const int nbands, |
265 | 270 | const int ngk_ik, |
266 | 271 | const bool is_first_node) const |
267 | 272 | { |
268 | | - if (nbasis != p_exx_helper->psi.get_nbasis()) |
269 | | - { |
270 | | -// ModuleBase::ERROR_QUIT("OperatorEXXPW", "nbasis != psi.get_nbasis()"); |
271 | | -// std::cout << "nbasis != psi.get_nbasis()" << std::endl; |
272 | | -// std::cout << "nbasis: " << nbasis << std::endl; |
273 | | -// std::cout << "psi.get_nbasis(): " << p_exx_helper->psi.get_nbasis() << std::endl; |
274 | | -// std::cout << "npwk_max: " << wfcpw->npwk_max << std::endl; |
275 | | -// std::cout << "npwk_ik: " << wfcpw->npwk[this->ik] << std::endl; |
276 | | -// throw std::runtime_error("nbasis != psi.get_nbasis()"); |
277 | | - } |
278 | 273 | // std::cout << "act_op_ace" << std::endl; |
279 | 274 | // hpsi += -Xi^\dagger * Xi * psi |
280 | | - int nbands_tot = p_exx_helper->psi.get_nbands() * p_exx_helper->psi.get_nk(); |
| 275 | + auto Xi_ace = Xi_ace_k[this->ik]; |
| 276 | + int nbands_tot = p_exx_helper->psi.get_nbands(); |
281 | 277 | int nbasis_max = p_exx_helper->psi.get_nbasis(); |
282 | 278 | // T* hpsi = nullptr; |
283 | 279 | // resmem_complex_op()(hpsi, nbands_tot * nbasis); |
@@ -334,178 +330,195 @@ void OperatorEXXPW<T, Device>::act_op_ace(const int nbands, |
334 | 330 | template <typename T, typename Device> |
335 | 331 | void OperatorEXXPW<T, Device>::construct_ace() const |
336 | 332 | { |
337 | | - int nbands_tot = p_exx_helper->psi.get_nbands() * p_exx_helper->psi.get_nk(); |
| 333 | +// int nkb = p_exx_helper->psi.get_nbands() * p_exx_helper->psi.get_nk(); |
| 334 | + int nbands = p_exx_helper->psi.get_nbands(); |
338 | 335 | int nbasis = p_exx_helper->psi.get_nbasis(); |
| 336 | + int nk = p_exx_helper->psi.get_nk(); |
| 337 | + |
| 338 | + T intermediate_one = 1.0, intermediate_zero = 0.0; |
| 339 | + |
339 | 340 | if (h_psi_ace == nullptr) |
340 | 341 | { |
341 | | - resmem_complex_op()(h_psi_ace, nbands_tot * nbasis); |
342 | | - setmem_complex_op()(h_psi_ace, 0, nbands_tot * nbasis); |
| 342 | + resmem_complex_op()(h_psi_ace, nbands * nbasis); |
| 343 | + setmem_complex_op()(h_psi_ace, 0, nbands * nbasis); |
343 | 344 | } |
344 | 345 |
|
345 | | - if (Xi_ace == nullptr) |
| 346 | + if (Xi_ace_k.size() != nk) |
346 | 347 | { |
347 | | - resmem_complex_op()(Xi_ace, nbands_tot * nbasis); |
| 348 | + Xi_ace_k.resize(nk); |
| 349 | + for (int i = 0; i < nk; i++) |
| 350 | + { |
| 351 | + resmem_complex_op()(Xi_ace_k[i], nbands * nbasis); |
| 352 | + } |
| 353 | + } |
| 354 | + |
| 355 | + for (int i = 0; i < nk; i++) |
| 356 | + { |
| 357 | + setmem_complex_op()(Xi_ace_k[i], 0, nbands * nbasis); |
348 | 358 | } |
349 | 359 |
|
350 | 360 | if (L_ace == nullptr) |
351 | 361 | { |
352 | | - resmem_complex_op()(L_ace, nbands_tot * nbands_tot); |
353 | | - setmem_complex_op()(L_ace, 0, nbands_tot * nbands_tot); |
| 362 | + resmem_complex_op()(L_ace, nbands * nbands); |
| 363 | + setmem_complex_op()(L_ace, 0, nbands * nbands); |
354 | 364 | } |
355 | 365 |
|
356 | 366 | if (psi_h_psi_ace == nullptr) |
357 | 367 | { |
358 | | - resmem_complex_op()(psi_h_psi_ace, nbands_tot * nbands_tot); |
| 368 | + resmem_complex_op()(psi_h_psi_ace, nbands * nbands); |
359 | 369 | } |
360 | 370 |
|
361 | 371 | // std::ofstream ofs_psi("psi.dat", std::ios::binary); |
362 | 372 | // p_exx_helper->psi.fix_kb(0, 0); |
363 | | -// ofs_psi.write(reinterpret_cast<char*>(p_exx_helper->psi.get_pointer()), nbands_tot * nbasis * sizeof(T)); |
| 373 | +// ofs_psi.write(reinterpret_cast<char*>(p_exx_helper->psi.get_pointer()), nkb * nbasis * sizeof(T)); |
364 | 374 | // ofs_psi.close(); |
365 | 375 |
|
366 | 376 | for (int ik = 0; ik < p_exx_helper->psi.get_nk(); ik++) |
367 | 377 | { |
| 378 | + auto Xi_ace = Xi_ace_k[ik]; |
| 379 | + |
368 | 380 | p_exx_helper->psi.fix_kb(ik, 0); |
369 | 381 | T* p_psi = p_exx_helper->psi.get_pointer(); |
370 | 382 | act_op( |
371 | 383 | p_exx_helper->psi.get_nbands(), |
372 | 384 | nbasis, |
373 | 385 | 1, |
374 | 386 | p_psi, |
375 | | - h_psi_ace + ik * p_exx_helper->psi.get_nbands() * nbasis, |
| 387 | + h_psi_ace, |
376 | 388 | ik, |
377 | 389 | false |
378 | 390 | ); |
379 | | - } |
380 | | - |
381 | | -// // save h_psi_ace to disk |
382 | | -// std::ofstream ofs_hpsi("hpsi.dat", std::ios::binary); |
383 | | -// ofs_hpsi.write(reinterpret_cast<char*>(h_psi_ace), nbands_tot * nbasis * sizeof(T)); |
384 | | -// ofs_hpsi.close(); |
385 | 391 |
|
386 | | - T intermediate_one = 1.0, intermediate_zero = 0.0; |
387 | | - |
388 | | - // psi_h_psi_ace = psi^\dagger * h_psi_ace |
389 | | - p_exx_helper->psi.fix_kb(0, 0); |
390 | | - gemm_complex_op()(this->ctx, |
391 | | - 'C', |
392 | | - 'N', |
393 | | - nbands_tot, |
394 | | - nbands_tot, |
395 | | - nbasis, |
396 | | - &intermediate_one, |
397 | | - p_exx_helper->psi.get_pointer(), |
398 | | - nbasis, |
399 | | - h_psi_ace, |
400 | | - nbasis, |
401 | | - &intermediate_zero, |
402 | | - psi_h_psi_ace, |
403 | | - nbands_tot |
404 | | - ); |
| 392 | + // psi_h_psi_ace = psi^\dagger * h_psi_ace |
| 393 | + p_exx_helper->psi.fix_kb(0, 0); |
| 394 | + gemm_complex_op()(this->ctx, |
| 395 | + 'C', |
| 396 | + 'N', |
| 397 | + nbands, |
| 398 | + nbands, |
| 399 | + nbasis, |
| 400 | + &intermediate_one, |
| 401 | + p_exx_helper->psi.get_pointer(), |
| 402 | + nbasis, |
| 403 | + h_psi_ace, |
| 404 | + nbasis, |
| 405 | + &intermediate_zero, |
| 406 | + psi_h_psi_ace, |
| 407 | + nbands); |
| 408 | + |
| 409 | + // reduction of psi_h_psi_ace, due to distributed memory |
| 410 | + Parallel_Reduce::reduce_pool(psi_h_psi_ace, nbands * nbands); |
| 411 | + |
| 412 | + // // save psi_h_psi_ace to disk |
| 413 | + // std::ofstream ofs_psi_hpsi("psihpsi.dat", std::ios::binary); |
| 414 | + // ofs_psi_hpsi.write(reinterpret_cast<char*>(psi_h_psi_ace), nkb * nkb * sizeof(T)); |
| 415 | + // ofs_psi_hpsi.close(); |
| 416 | + |
| 417 | +// L_ace = cholesky(-psi_h_psi_ace) |
| 418 | +#ifdef _OPENMP |
| 419 | +#pragma omp parallel for schedule(static) |
| 420 | +#endif |
| 421 | + for (int i = 0; i < nbands; i++) |
| 422 | + { |
| 423 | + for (int j = 0; j < nbands; j++) |
| 424 | + { |
| 425 | + L_ace[i * nbands + j] = -psi_h_psi_ace[i * nbands + j]; |
| 426 | + } |
| 427 | + } |
405 | 428 |
|
406 | | - // reduction of psi_h_psi_ace, due to distributed memory |
407 | | - Parallel_Reduce::reduce_pool(psi_h_psi_ace, nbands_tot * nbands_tot); |
| 429 | + int info = 0; |
| 430 | + char up = 'U', lo = 'L'; |
408 | 431 |
|
409 | | -// // save psi_h_psi_ace to disk |
410 | | -// std::ofstream ofs_psi_hpsi("psihpsi.dat", std::ios::binary); |
411 | | -// ofs_psi_hpsi.write(reinterpret_cast<char*>(psi_h_psi_ace), nbands_tot * nbands_tot * sizeof(T)); |
412 | | -// ofs_psi_hpsi.close(); |
| 432 | + if constexpr (std::is_same<T, std::complex<float>>::value) |
| 433 | + { |
| 434 | + cpotrf_(&lo, &nbands, L_ace, &nbands, &info); |
| 435 | + } |
| 436 | + else if constexpr (std::is_same<T, std::complex<double>>::value) |
| 437 | + { |
| 438 | + zpotrf_(&lo, &nbands, L_ace, &nbands, &info); |
| 439 | + } |
413 | 440 |
|
414 | | - // L_ace = cholesky(-psi_h_psi_ace) |
415 | | - #ifdef _OPENMP |
416 | | - #pragma omp parallel for schedule(static) |
417 | | - #endif |
418 | | - for (int i = 0; i < nbands_tot; i++) |
419 | | - { |
420 | | - for (int j = 0; j < nbands_tot; j++) |
| 441 | +// expand for-loop |
| 442 | +#ifdef _OPENMP |
| 443 | +#pragma omp parallel for schedule(static) |
| 444 | +#endif |
| 445 | + for (int i = 0; i < nbands; i++) |
421 | 446 | { |
422 | | - L_ace[i * nbands_tot + j] = -psi_h_psi_ace[i * nbands_tot + j]; |
| 447 | + for (int j = 0; j < nbands; j++) |
| 448 | + { |
| 449 | + if (j < i) |
| 450 | + { |
| 451 | + // L_ace[j * nkb + i] = std::conj(L_ace[i * nkb + j]); |
| 452 | + L_ace[i * nbands + j] = 0.0; |
| 453 | + } |
| 454 | + } |
423 | 455 | } |
424 | | - } |
425 | 456 |
|
426 | | - int info = 0; |
427 | | - char up = 'U', lo = 'L'; |
| 457 | + // // save L_ace to disk |
| 458 | + // std::ofstream ofs_L("L.dat", std::ios::binary); |
| 459 | + // ofs_L.write(reinterpret_cast<char*>(L_ace), nkb * nkb * sizeof(T)); |
| 460 | + // ofs_L.close(); |
428 | 461 |
|
429 | | - if constexpr (std::is_same<T, std::complex<float>>::value) |
430 | | - { |
431 | | - cpotrf_(&lo, &nbands_tot, L_ace, &nbands_tot, &info); |
432 | | - } |
433 | | - else if constexpr (std::is_same<T, std::complex<double>>::value) |
434 | | - { |
435 | | - zpotrf_(&lo, &nbands_tot, L_ace, &nbands_tot, &info); |
436 | | - } |
| 462 | + // L_ace inv in place |
| 463 | + // T == std::complex<float> or std::complex<double> |
| 464 | + if constexpr (std::is_same<T, std::complex<float>>::value) |
| 465 | + { |
| 466 | + char non_unitary = 'N'; |
437 | 467 |
|
438 | | - // expand for-loop |
439 | | - #ifdef _OPENMP |
440 | | - #pragma omp parallel for schedule(static) |
441 | | - #endif |
442 | | - for (int i = 0; i < nbands_tot; i++) |
443 | | - { |
444 | | - for (int j = 0; j < nbands_tot; j++) |
| 468 | + ctrtri_(&lo, &non_unitary, &nbands, L_ace, &nbands, &info); |
| 469 | + } |
| 470 | + else if constexpr (std::is_same<T, std::complex<double>>::value) |
445 | 471 | { |
446 | | - if (j < i) |
447 | | - { |
448 | | -// L_ace[j * nbands_tot + i] = std::conj(L_ace[i * nbands_tot + j]); |
449 | | - L_ace[i * nbands_tot + j] = 0.0; |
450 | | - } |
| 472 | + char non_unitary = 'N'; |
| 473 | + |
| 474 | + ztrtri_(&lo, &non_unitary, &nbands, L_ace, &nbands, &info); |
451 | 475 | } |
452 | | - } |
453 | 476 |
|
454 | | -// // save L_ace to disk |
455 | | -// std::ofstream ofs_L("L.dat", std::ios::binary); |
456 | | -// ofs_L.write(reinterpret_cast<char*>(L_ace), nbands_tot * nbands_tot * sizeof(T)); |
457 | | -// ofs_L.close(); |
| 477 | + // // save L_ace inv to disk |
| 478 | + // std::ofstream ofs_L_inv("L_inv.dat", std::ios::binary); |
| 479 | + // ofs_L_inv.write(reinterpret_cast<char*>(L_ace), nkb * nkb * sizeof(T)); |
| 480 | + // ofs_L_inv.close(); |
| 481 | + |
| 482 | + // Xi_ace = L_ace^-1 * h_psi_ace^dagger |
| 483 | + gemm_complex_op()(this->ctx, |
| 484 | + 'N', |
| 485 | + 'C', |
| 486 | + nbands, |
| 487 | + nbasis, |
| 488 | + nbands, |
| 489 | + &intermediate_one, |
| 490 | + L_ace, |
| 491 | + nbands, |
| 492 | + h_psi_ace, |
| 493 | + nbasis, |
| 494 | + &intermediate_zero, |
| 495 | + Xi_ace, |
| 496 | + nbands); |
| 497 | + |
| 498 | + // // save Xi_ace to disk |
| 499 | + // std::ofstream ofs_Xi("Xi.dat", std::ios::binary); |
| 500 | + // ofs_Xi.write(reinterpret_cast<char*>(Xi_ace), nkb * nbasis * sizeof(T)); |
| 501 | + // ofs_Xi.close(); |
| 502 | + // |
| 503 | + // std::cout << "nkb: " << nkb << std::endl; |
| 504 | + // std::cout << "nbands: " << p_exx_helper->psi.get_nbands() << std::endl; |
| 505 | + // std::cout << "nbasis: " << nbasis << std::endl; |
| 506 | + |
| 507 | + // clear mem |
| 508 | + setmem_complex_op()(h_psi_ace, 0, nbands * nbasis); |
| 509 | + // setmem_complex_op()(Xi_ace, 0, nkb * nbasis); |
| 510 | + setmem_complex_op()(psi_h_psi_ace, 0, nbands * nbands); |
| 511 | + setmem_complex_op()(L_ace, 0, nbands * nbands); |
458 | 512 |
|
459 | | - // L_ace inv in place |
460 | | - // T == std::complex<float> or std::complex<double> |
461 | | - if constexpr (std::is_same<T, std::complex<float>>::value) |
462 | | - { |
463 | | - // cgetrf_(&p_exx_helper->psi.get_nbands(), &p_exx_helper->psi.get_nbands(), L_ace, &p_exx_helper->psi.get_nbands(), ipiv.data(), &info); |
464 | | - // Todo: implement cgetrf and cgetri |
465 | 513 | } |
466 | | - else if constexpr (std::is_same<T, std::complex<double>>::value) |
467 | | - { |
468 | | - char non_unitary = 'N'; |
469 | 514 |
|
470 | | - ztrtri_(&lo, &non_unitary, &nbands_tot, L_ace, &nbands_tot, &info); |
471 | | - } |
| 515 | +// // save h_psi_ace to disk |
| 516 | +// std::ofstream ofs_hpsi("hpsi.dat", std::ios::binary); |
| 517 | +// ofs_hpsi.write(reinterpret_cast<char*>(h_psi_ace), nkb * nbasis * sizeof(T)); |
| 518 | +// ofs_hpsi.close(); |
| 519 | + |
| 520 | + |
472 | 521 |
|
473 | | -// // save L_ace inv to disk |
474 | | -// std::ofstream ofs_L_inv("L_inv.dat", std::ios::binary); |
475 | | -// ofs_L_inv.write(reinterpret_cast<char*>(L_ace), nbands_tot * nbands_tot * sizeof(T)); |
476 | | -// ofs_L_inv.close(); |
477 | | - |
478 | | - // Xi_ace = L_ace^-1 * h_psi_ace^dagger |
479 | | - gemm_complex_op()(this->ctx, |
480 | | - 'N', |
481 | | - 'C', |
482 | | - nbands_tot, |
483 | | - nbasis, |
484 | | - nbands_tot, |
485 | | - &intermediate_one, |
486 | | - L_ace, |
487 | | - nbands_tot, |
488 | | - h_psi_ace, |
489 | | - nbasis, |
490 | | - &intermediate_zero, |
491 | | - Xi_ace, |
492 | | - nbands_tot |
493 | | - ); |
494 | | - |
495 | | -// // save Xi_ace to disk |
496 | | -// std::ofstream ofs_Xi("Xi.dat", std::ios::binary); |
497 | | -// ofs_Xi.write(reinterpret_cast<char*>(Xi_ace), nbands_tot * nbasis * sizeof(T)); |
498 | | -// ofs_Xi.close(); |
499 | | -// |
500 | | -// std::cout << "nbands_tot: " << nbands_tot << std::endl; |
501 | | -// std::cout << "nbands: " << p_exx_helper->psi.get_nbands() << std::endl; |
502 | | -// std::cout << "nbasis: " << nbasis << std::endl; |
503 | | - |
504 | | - // clear mem |
505 | | - setmem_complex_op()(h_psi_ace, 0, nbands_tot * nbasis); |
506 | | -// setmem_complex_op()(Xi_ace, 0, nbands_tot * nbasis); |
507 | | - setmem_complex_op()(psi_h_psi_ace, 0, nbands_tot * nbands_tot); |
508 | | - setmem_complex_op()(L_ace, 0, nbands_tot * nbands_tot); |
509 | 522 |
|
510 | 523 | } |
511 | 524 |
|
|
0 commit comments