Skip to content

Commit 5fb93d2

Browse files
committed
OpenMP optimization for local potential in force calculation (cal_fvl_dphi() --> gamma_force()) under gamma-only line
1 parent 7d26f9c commit 5fb93d2

File tree

1 file changed

+187
-117
lines changed

1 file changed

+187
-117
lines changed

source/src_lcao/gint_gamma_fvl.cpp

Lines changed: 187 additions & 117 deletions
Original file line numberDiff line numberDiff line change
@@ -510,142 +510,212 @@ void Gint_Gamma::gamma_force(const double*const vlocal) const
510510
DGridV_z[i] = &DGridV_pool[i*GlobalC::GridT.lgd+2*DGridV_Size];
511511
}
512512
ModuleBase::Memory::record("Gint_Gamma","DGridV",3*GlobalC::GridT.lgd*GlobalC::GridT.lgd,"double");
513-
//OUT(GlobalV::ofs_running,"DGridV was allocated");
513+
#ifdef _OPENMP
514+
#pragma omp parallel
515+
{
516+
double *DGridV_pool_thread=new double[3*DGridV_Size];
517+
ModuleBase::GlobalFunc::ZEROS(DGridV_pool_thread, 3*DGridV_Size);
518+
519+
double** DGridV_x_thread = new double*[GlobalC::GridT.lgd];
520+
double** DGridV_y_thread = new double*[GlobalC::GridT.lgd];
521+
double** DGridV_z_thread = new double*[GlobalC::GridT.lgd];
522+
double* DGridV_stress_pool_thread;
523+
double** DGridV_11_thread;
524+
double** DGridV_12_thread;
525+
double** DGridV_13_thread;
526+
double** DGridV_22_thread;
527+
double** DGridV_23_thread;
528+
double** DGridV_33_thread;
514529

515-
// it's a uniform grid to save orbital values, so the delta_r is a constant.
516-
const double delta_r = GlobalC::ORB.dr_uniform;
530+
if(GlobalV::STRESS)
531+
{
532+
DGridV_stress_pool_thread = new double[6*DGridV_Size];
533+
ModuleBase::GlobalFunc::ZEROS(DGridV_stress_pool_thread, 6*DGridV_Size);
534+
DGridV_11_thread = new double*[GlobalC::GridT.lgd];
535+
DGridV_12_thread = new double*[GlobalC::GridT.lgd];
536+
DGridV_13_thread = new double*[GlobalC::GridT.lgd];
537+
DGridV_22_thread = new double*[GlobalC::GridT.lgd];
538+
DGridV_23_thread = new double*[GlobalC::GridT.lgd];
539+
DGridV_33_thread = new double*[GlobalC::GridT.lgd];
540+
541+
for(int i=0; i<GlobalC::GridT.lgd; ++i)
542+
{
543+
DGridV_11_thread[i] = &DGridV_stress_pool_thread[i*GlobalC::GridT.lgd];
544+
DGridV_12_thread[i] = &DGridV_stress_pool_thread[i*GlobalC::GridT.lgd+DGridV_Size];
545+
DGridV_13_thread[i] = &DGridV_stress_pool_thread[i*GlobalC::GridT.lgd+2*DGridV_Size];
546+
DGridV_22_thread[i] = &DGridV_stress_pool_thread[i*GlobalC::GridT.lgd+3*DGridV_Size];
547+
DGridV_23_thread[i] = &DGridV_stress_pool_thread[i*GlobalC::GridT.lgd+4*DGridV_Size];
548+
DGridV_33_thread[i] = &DGridV_stress_pool_thread[i*GlobalC::GridT.lgd+5*DGridV_Size];
549+
}
550+
ModuleBase::Memory::record("Gint_Gamma","DGridV_stress_thread",6*GlobalC::GridT.lgd*GlobalC::GridT.lgd,"double");
551+
}
517552

518-
int LD_pool=max_size*GlobalC::ucell.nwmax;
519-
double* dphi_pool;
520-
521-
double** dphix;
522-
double** dphiy;
523-
double** dphiz;
553+
for(int i=0; i<GlobalC::GridT.lgd; ++i)
554+
{
555+
DGridV_x_thread[i] = &DGridV_pool_thread[i*GlobalC::GridT.lgd];
556+
DGridV_y_thread[i] = &DGridV_pool_thread[i*GlobalC::GridT.lgd+DGridV_Size];
557+
DGridV_z_thread[i] = &DGridV_pool_thread[i*GlobalC::GridT.lgd+2*DGridV_Size];
558+
}
559+
ModuleBase::Memory::record("Gint_Gamma","DGridV_thread",3*GlobalC::GridT.lgd*GlobalC::GridT.lgd,"double");
560+
#endif
561+
// it's a uniform grid to save orbital values, so the delta_r is a constant.
562+
const double delta_r = GlobalC::ORB.dr_uniform;
524563

525-
bool** cal_flag;
564+
int LD_pool=max_size*GlobalC::ucell.nwmax;
565+
double* dphi_pool;
526566

527-
const int ncyz=GlobalC::pw.ncy*GlobalC::pw.nczp;
567+
double** dphix;
568+
double** dphiy;
569+
double** dphiz;
528570

529-
/* if(max_size<=0 || GlobalC::GridT.lgd <= 0)
530-
{
531-
//OUT(GlobalV::ofs_running,"max_size", max_size);
532-
//OUT(GlobalV::ofs_running,"GlobalC::GridT.lgd", GlobalC::GridT.lgd);
533-
goto ENDandRETURN;
534-
}*/
535-
if(max_size>0 && GlobalC::GridT.lgd > 0)
536-
{
537-
dphi_pool=new double [3*GlobalC::pw.bxyz*LD_pool];
538-
ModuleBase::GlobalFunc::ZEROS(dphi_pool, 3*GlobalC::pw.bxyz*LD_pool);
539-
dphix = new double*[GlobalC::pw.bxyz];
540-
dphiy = new double*[GlobalC::pw.bxyz];
541-
dphiz = new double*[GlobalC::pw.bxyz];
542-
543-
cal_flag=new bool*[GlobalC::pw.bxyz];
544-
for(int i=0; i<GlobalC::pw.bxyz; i++)
545-
{
546-
dphix[i] = &dphi_pool[i*LD_pool];
547-
dphiy[i] = &dphi_pool[i*LD_pool+GlobalC::pw.bxyz*LD_pool];
548-
dphiz[i] = &dphi_pool[i*LD_pool+2*GlobalC::pw.bxyz*LD_pool];
549-
cal_flag[i] = new bool[max_size];
550-
}
571+
bool** cal_flag;
572+
const int ncyz=GlobalC::pw.ncy*GlobalC::pw.nczp;
551573

552-
ModuleBase::realArray drr;//rewrite drr form by zhengdy-2019-04-02
553-
if(GlobalV::STRESS)
574+
if(max_size>0 && GlobalC::GridT.lgd > 0)
554575
{
555-
drr.create(max_size, GlobalC::pw.bxyz, 3);
556-
drr.zero_out();
557-
}
558-
/* double ***drr;//store dr for stress calculate, added by zhengdy
559-
if(GlobalV::STRESS)//added by zhengdy-stress
560-
{
561-
drr = new double**[max_size];
562-
for(int id=0; id<max_size; id++)
563-
{
564-
drr[id] = new double*[GlobalC::pw.bxyz];
565-
for(int ib=0; ib<GlobalC::pw.bxyz; ib++)
566-
{
567-
drr[id][ib] = new double[3];
568-
ModuleBase::GlobalFunc::ZEROS(drr[id][ib],3);
569-
}
570-
}
571-
}*/
572-
//OUT(GlobalV::ofs_running,"Data were prepared");
573-
//ModuleBase::timer::tick("Gint_Gamma","prepare");
574-
for (int i=0; i< GlobalC::GridT.nbx; i++)
575-
{
576-
const int ibx = i*GlobalC::pw.bx;
577-
578-
for (int j=0; j< GlobalC::GridT.nby; j++)
576+
dphi_pool=new double [3*GlobalC::pw.bxyz*LD_pool];
577+
ModuleBase::GlobalFunc::ZEROS(dphi_pool, 3*GlobalC::pw.bxyz*LD_pool);
578+
dphix = new double*[GlobalC::pw.bxyz];
579+
dphiy = new double*[GlobalC::pw.bxyz];
580+
dphiz = new double*[GlobalC::pw.bxyz];
581+
582+
cal_flag=new bool*[GlobalC::pw.bxyz];
583+
for(int i=0; i<GlobalC::pw.bxyz; i++)
579584
{
580-
const int jby = j*GlobalC::pw.by;
585+
dphix[i] = &dphi_pool[i*LD_pool];
586+
dphiy[i] = &dphi_pool[i*LD_pool+GlobalC::pw.bxyz*LD_pool];
587+
dphiz[i] = &dphi_pool[i*LD_pool+2*GlobalC::pw.bxyz*LD_pool];
588+
cal_flag[i] = new bool[max_size];
589+
}
581590

582-
for (int k= GlobalC::GridT.nbzp_start; k< GlobalC::GridT.nbzp_start+GlobalC::GridT.nbzp; k++)
591+
ModuleBase::realArray drr;//rewrite drr form by zhengdy-2019-04-02
592+
if(GlobalV::STRESS)
593+
{
594+
drr.create(max_size, GlobalC::pw.bxyz, 3);
595+
drr.zero_out();
596+
}
597+
#ifdef _OPENMP
598+
//#pragma omp for schedule(dynamic)
599+
#pragma omp for
600+
#endif
601+
for(int i=0; i< GlobalC::GridT.nbx; i++)
602+
{
603+
const int ibx = i*GlobalC::pw.bx;
604+
for(int j=0; j< GlobalC::GridT.nby; j++)
583605
{
584-
const int kbz = k*GlobalC::pw.bz-GlobalC::pw.nczp_start;
585-
const int grid_index = (k-GlobalC::GridT.nbzp_start) + j * GlobalC::GridT.nbzp + i * GlobalC::GridT.nby * GlobalC::GridT.nbzp;
586-
const int na_grid = GlobalC::GridT.how_many_atoms[ grid_index ];
587-
if(na_grid==0)continue;
588-
589-
//------------------------------------------------------------------
590-
// extract the local potentials.
591-
//------------------------------------------------------------------
592-
double *vldr3 = get_vldr3(vlocal, ncyz, ibx, jby, kbz);
593-
594-
//------------------------------------------------------
595-
// index of wave functions for each block
596-
//------------------------------------------------------
597-
int *block_iw = Gint_Tools::get_block_iw(na_grid, grid_index, this->max_size);
598-
599-
int* block_index = Gint_Tools::get_block_index(na_grid, grid_index);
600-
601-
//------------------------------------------------------
602-
// band size: number of columns of a band
603-
//------------------------------------------------------------------
604-
int* block_size = Gint_Tools::get_block_size(na_grid, grid_index);
605-
606-
Gint_Tools::Array_Pool<double> psir_vlbr3(GlobalC::pw.bxyz, LD_pool);
607-
Gint_Tools::Array_Pool<double> psir_ylm(GlobalC::pw.bxyz, LD_pool);
608-
609-
cal_psir_ylm_dphi(na_grid, grid_index, delta_r,
606+
const int jby = j*GlobalC::pw.by;
607+
for(int k= GlobalC::GridT.nbzp_start; k< GlobalC::GridT.nbzp_start+GlobalC::GridT.nbzp; k++)
608+
{
609+
const int kbz = k*GlobalC::pw.bz-GlobalC::pw.nczp_start;
610+
const int grid_index = (k-GlobalC::GridT.nbzp_start) + j * GlobalC::GridT.nbzp + i * GlobalC::GridT.nby * GlobalC::GridT.nbzp;
611+
const int na_grid = GlobalC::GridT.how_many_atoms[ grid_index ];
612+
if(na_grid==0) continue;
613+
614+
//------------------------------------------------------------------
615+
// extract the local potentials.
616+
//------------------------------------------------------------------
617+
double *vldr3 = get_vldr3(vlocal, ncyz, ibx, jby, kbz);
618+
619+
//------------------------------------------------------
620+
// index of wave functions for each block
621+
//------------------------------------------------------
622+
int *block_iw = Gint_Tools::get_block_iw(na_grid, grid_index, this->max_size);
623+
int* block_index = Gint_Tools::get_block_index(na_grid, grid_index);
624+
625+
//------------------------------------------------------
626+
// band size: number of columns of a band
627+
//------------------------------------------------------
628+
int* block_size = Gint_Tools::get_block_size(na_grid, grid_index);
629+
630+
Gint_Tools::Array_Pool<double> psir_vlbr3(GlobalC::pw.bxyz, LD_pool);
631+
Gint_Tools::Array_Pool<double> psir_ylm(GlobalC::pw.bxyz, LD_pool);
632+
633+
cal_psir_ylm_dphi(na_grid, grid_index, delta_r,
610634
block_index, block_size, cal_flag, psir_ylm.ptr_2D, dphix, dphiy, dphiz, drr);
611-
612-
cal_meshball_DGridV(na_grid, GlobalC::GridT.lgd, LD_pool, block_index, block_iw, block_size, cal_flag, vldr3,
613-
psir_ylm.ptr_2D, psir_vlbr3.ptr_2D, dphix, dphiy, dphiz,
614-
DGridV_x, DGridV_y, DGridV_z,
615-
DGridV_11, DGridV_12, DGridV_13,
616-
DGridV_22, DGridV_23, DGridV_33, drr);
617-
618-
free(vldr3); vldr3=nullptr;
619-
free(block_iw); block_iw=nullptr;
620-
free(block_index); block_index=nullptr;
621-
free(block_size); block_size=nullptr;
622-
}// k
623-
}// j
624-
}// i
625-
626-
//OUT(GlobalV::ofs_running,"DGridV was calculated");
627-
delete[] dphix;
628-
delete[] dphiy;
629-
delete[] dphiz;
630-
delete[] dphi_pool;
631-
for(int ib=0; ib<GlobalC::pw.bxyz; ++ib)
632-
{
633-
delete[] cal_flag[ib];
634-
}
635-
delete[] cal_flag;
636-
//OUT(GlobalV::ofs_running,"temp variables were deleted");
637635

638-
}//end if, replace goto line
639-
//ENDandRETURN:
636+
#ifdef _OPENMP
637+
cal_meshball_DGridV(na_grid, GlobalC::GridT.lgd, LD_pool, block_index, block_iw, block_size, cal_flag, vldr3,
638+
psir_ylm.ptr_2D, psir_vlbr3.ptr_2D, dphix, dphiy, dphiz,
639+
DGridV_x_thread, DGridV_y_thread, DGridV_z_thread,
640+
DGridV_11_thread, DGridV_12_thread, DGridV_13_thread,
641+
DGridV_22_thread, DGridV_23_thread, DGridV_33_thread, drr);
642+
#else
643+
cal_meshball_DGridV(na_grid, GlobalC::GridT.lgd, LD_pool, block_index, block_iw, block_size, cal_flag, vldr3,
644+
psir_ylm.ptr_2D, psir_vlbr3.ptr_2D, dphix, dphiy, dphiz,
645+
DGridV_x, DGridV_y, DGridV_z,
646+
DGridV_11, DGridV_12, DGridV_13,
647+
DGridV_22, DGridV_23, DGridV_33, drr);
648+
#endif
649+
650+
free(vldr3); vldr3=nullptr;
651+
free(block_iw); block_iw=nullptr;
652+
free(block_index); block_index=nullptr;
653+
free(block_size); block_size=nullptr;
654+
}// k
655+
}// j
656+
}// i
657+
658+
delete[] dphix;
659+
delete[] dphiy;
660+
delete[] dphiz;
661+
delete[] dphi_pool;
662+
for(int ib=0; ib<GlobalC::pw.bxyz; ++ib)
663+
{
664+
delete[] cal_flag[ib];
665+
}
666+
delete[] cal_flag;
667+
}
668+
#ifdef _OPENMP
669+
#pragma omp critical(cal_fvl)
670+
{
671+
for(int i=0; i<GlobalC::GridT.lgd; i++)
672+
{
673+
for(int j=0; j<GlobalC::GridT.lgd; j++)
674+
{
675+
DGridV_x[i][j] += DGridV_x_thread[i][j];
676+
DGridV_y[i][j] += DGridV_y_thread[i][j];
677+
DGridV_z[i][j] += DGridV_z_thread[i][j];
678+
679+
if(GlobalV::STRESS)
680+
{
681+
DGridV_11[i][j] += DGridV_11_thread[i][j];
682+
DGridV_12[i][j] += DGridV_12_thread[i][j];
683+
DGridV_13[i][j] += DGridV_13_thread[i][j];
684+
DGridV_22[i][j] += DGridV_22_thread[i][j];
685+
DGridV_23[i][j] += DGridV_23_thread[i][j];
686+
DGridV_33[i][j] += DGridV_33_thread[i][j];
687+
}
688+
}
689+
}
690+
691+
delete [] DGridV_x_thread;
692+
delete [] DGridV_y_thread;
693+
delete [] DGridV_z_thread;
694+
if(GlobalV::STRESS)
695+
{
696+
delete [] DGridV_11_thread;
697+
delete [] DGridV_12_thread;
698+
delete [] DGridV_13_thread;
699+
delete [] DGridV_22_thread;
700+
delete [] DGridV_23_thread;
701+
delete [] DGridV_33_thread;
702+
delete [] DGridV_stress_pool_thread;
703+
}
704+
delete [] DGridV_pool_thread;
705+
}
706+
#endif
707+
#ifdef _OPENMP
708+
}
709+
#endif
710+
640711
ModuleBase::timer::tick("Gint_Gamma","gamma_force");
641712
#ifdef __MPI
642713
ModuleBase::timer::tick("Gint_Gamma","gamma_force_wait");
643-
MPI_Barrier(MPI_COMM_WORLD);
714+
MPI_Barrier(MPI_COMM_WORLD);
644715
ModuleBase::timer::tick("Gint_Gamma","gamma_force_wait");
645716
#endif
646717
ModuleBase::timer::tick("Gint_Gamma","gamma_force2");
647718

648-
649719
//OUT(GlobalV::ofs_running,"Start reduce DGridV");
650720

651721
double* tmpx = new double[GlobalV::NLOCAL];

0 commit comments

Comments
 (0)