Skip to content

Commit 5d34fb1

Browse files
committed
add npart_sto to avoid using too much memory
1 parent 3bf4613 commit 5d34fb1

File tree

8 files changed

+48
-18
lines changed

8 files changed

+48
-18
lines changed

docs/input-main.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -642,6 +642,12 @@ This part of variables are used to control the parameters of stochastic DFT (SDF
642642
- **Description**: Frequency (once each initsto_freq steps) to generate new stochastic orbitals when running md.
643643
- **Default**:1000
644644
645+
#### npart_sto
646+
647+
- **Type**: Integer
648+
- **Description**: Make memory cost to 1/npart_sto times of previous one when running post process of SDFT like DOS with method_sto = 2.
649+
- **Default**:1
650+
645651
### Geometry relaxation
646652
647653
This part of variables are used to control the geometry relaxation.

source/input.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@ void Input::Default(void)
145145
kpar = 1;
146146
initsto_freq = 1000;
147147
method_sto = 2;
148+
npart_sto = 1;
148149
cal_cond = false;
149150
dos_nche = 100;
150151
cond_nche = 20;
@@ -595,6 +596,10 @@ bool Input::Read(const std::string &fn)
595596
{
596597
read_value(ifs, method_sto);
597598
}
599+
else if (strcmp("npart_sto", word) == 0)
600+
{
601+
read_value(ifs, npart_sto);
602+
}
598603
else if (strcmp("cal_cond", word) == 0)
599604
{
600605
read_value(ifs, cal_cond);
@@ -1978,6 +1983,7 @@ void Input::Bcast()
19781983
Parallel_Common::bcast_double(emin_sto);
19791984
Parallel_Common::bcast_int(initsto_freq);
19801985
Parallel_Common::bcast_int(method_sto);
1986+
Parallel_Common::bcast_int(npart_sto);
19811987
Parallel_Common::bcast_bool(cal_cond);
19821988
Parallel_Common::bcast_int(cond_nche);
19831989
Parallel_Common::bcast_double(cond_dw);

source/input.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ class Input
7272
int bndpar; //parallel for stochastic/deterministic bands
7373
int initsto_freq; //frequency to init stochastic orbitals when running md
7474
int method_sto; //different methods for sdft, 1: slow, less memory 2: fast, more memory
75+
int npart_sto; //for method_sto = 2, reduce memory
7576
bool cal_cond; //calculate electronic conductivities
7677
int cond_nche; //orders of Chebyshev expansions for conductivities
7778
double cond_dw; //d\omega for conductivities

source/module_esolver/esolver_sdft_pw.cpp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -166,17 +166,21 @@ void ESolver_SDFT_PW::postprocess()
166166
if(INPUT.out_dos)
167167
{
168168
double emax, emin;
169-
if(INPUT.dos_setemax) emax = INPUT.dos_emax_ev;
170-
if(INPUT.dos_setemin) emin = INPUT.dos_emin_ev;
169+
if(INPUT.dos_setemax)
170+
emax = INPUT.dos_emax_ev;
171+
else
172+
emax = ((hsolver::HSolverPW_SDFT*)phsol)->stoiter.stohchi.Emax*ModuleBase::Ry_to_eV;
173+
if(INPUT.dos_setemin)
174+
emin = INPUT.dos_emin_ev;
175+
else
176+
emin = ((hsolver::HSolverPW_SDFT*)phsol)->stoiter.stohchi.Emin*ModuleBase::Ry_to_eV;
171177
if(!INPUT.dos_setemax && !INPUT.dos_setemin)
172178
{
173-
emax = ((hsolver::HSolverPW_SDFT*)phsol)->stoiter.stohchi.Emax;
174-
emin = ((hsolver::HSolverPW_SDFT*)phsol)->stoiter.stohchi.Emin;
175179
double delta=(emax-emin)*INPUT.dos_scale;
176180
emax=emax+delta/2.0;
177181
emin=emin-delta/2.0;
178182
}
179-
this->caldos(INPUT.dos_nche, INPUT.b_coef, emin, emax, INPUT.dos_edelta_ev );
183+
this->caldos(INPUT.dos_nche, INPUT.b_coef, emin, emax, INPUT.dos_edelta_ev, INPUT.npart_sto );
180184
}
181185
}
182186

source/module_esolver/esolver_sdft_pw.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ class ESolver_SDFT_PW: public ESolver_KS_PW
3333
const double dw_in, const int times);
3434
//calculate DOS
3535
void caldos(const int nche_dos, const double sigmain,
36-
const double emin, const double emax, const double de);
36+
const double emin, const double emax, const double de, const int npart);
3737

3838
private:
3939
int nche_sto; //norder of Chebyshev

source/module_esolver/esolver_sdft_pw_tool.cpp

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -454,7 +454,7 @@ void ESolver_SDFT_PW::sKG(const int nche_KG, const double fwhmin, const double w
454454
ModuleBase::timer::tick(this->classname,"sKG");
455455
}
456456

457-
void ESolver_SDFT_PW:: caldos( const int nche_dos, const double sigmain, const double emin, const double emax, const double de)
457+
void ESolver_SDFT_PW:: caldos( const int nche_dos, const double sigmain, const double emin, const double emax, const double de, const int npart)
458458
{
459459
cout<<"========================="<<endl;
460460
cout<<"###Calculating Dos....###"<<endl;
@@ -476,7 +476,8 @@ void ESolver_SDFT_PW:: caldos( const int nche_dos, const double sigmain, const d
476476
{
477477
spolyv = new double [nche_dos*nche_dos];
478478
ModuleBase::GlobalFunc::ZEROS(spolyv, nche_dos*nche_dos);
479-
allorderchi = new std::complex<double> [this->stowf.nchip_max * npwx * nche_dos];
479+
int nchip_new = ceil((double)this->stowf.nchip_max / npart);
480+
allorderchi = new std::complex<double> [nchip_new * npwx * nche_dos];
480481
}
481482
cout<<"1. TracepolyA:"<<endl;
482483
for (int ik = 0;ik < nk;ik++)
@@ -490,7 +491,7 @@ void ESolver_SDFT_PW:: caldos( const int nche_dos, const double sigmain, const d
490491
const int npw = GlobalC::kv.ngk[ik];
491492
const int nchipk = this->stowf.nchip[ik];
492493

493-
complex<double> * pchi;
494+
std::complex<double> * pchi;
494495
if(GlobalV::NBANDS > 0)
495496
pchi = stowf.chiortho[ik].c;
496497
else
@@ -505,17 +506,28 @@ void ESolver_SDFT_PW:: caldos( const int nche_dos, const double sigmain, const d
505506
}
506507
else
507508
{
508-
ModuleBase::GlobalFunc::ZEROS(allorderchi, this->stowf.nchip_max * npwx * nche_dos);
509-
che.calpolyvec_complex(&stohchi, &Stochastic_hchi::hchi_norm, pchi, allorderchi, npw, npwx, nchipk);
510-
double* vec_all= (double *) allorderchi;
509+
int N = nche_dos;
510+
double kweight = GlobalC::kv.wk[ik] / 2;
511511
char trans = 'T';
512512
char normal = 'N';
513513
double one = 1;
514-
int LDA = npwx * nchipk * 2;
515-
int M = npwx * nchipk * 2;
516-
int N = nche_dos;
517-
double kweight = GlobalC::kv.wk[ik] / 2;
518-
dgemm_(&trans,&normal, &N,&N,&M,&kweight,vec_all,&LDA,vec_all,&LDA,&one,spolyv,&N);
514+
for(int ipart = 0 ; ipart < npart ; ++ipart)
515+
{
516+
int nchipk_new = nchipk / npart;
517+
int start_nchipk = ipart * nchipk_new + nchipk % npart;
518+
if(ipart < nchipk % npart)
519+
{
520+
nchipk_new++;
521+
start_nchipk = ipart * nchipk_new;
522+
}
523+
ModuleBase::GlobalFunc::ZEROS(allorderchi, nchipk_new * npwx * nche_dos);
524+
std::complex<double> *tmpchi = pchi + start_nchipk * npwx;
525+
che.calpolyvec_complex(&stohchi, &Stochastic_hchi::hchi_norm, tmpchi, allorderchi, npw, npwx, nchipk_new);
526+
double* vec_all= (double *) allorderchi;
527+
int LDA = npwx * nchipk_new * 2;
528+
int M = npwx * nchipk_new * 2;
529+
dgemm_(&trans,&normal, &N,&N,&M,&kweight,vec_all,&LDA,vec_all,&LDA,&one,spolyv,&N);
530+
}
519531
}
520532
}
521533
if(stoiter.method == 2) delete[] allorderchi;

source/src_io/write_input.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ void Input::Print(const std::string &fn) const
111111

112112
ofs << "\n#Parameters (3.Stochastic DFT)" << std::endl;
113113
ModuleBase::GlobalFunc::OUTP(ofs, "method_sto", method_sto, "1: slow and save memory, 2: fast and waste memory");
114+
ModuleBase::GlobalFunc::OUTP(ofs, "npart_sto", npart_sto, "Reduce memory when calculating Stochastic DOS");
114115
ModuleBase::GlobalFunc::OUTP(ofs, "nbands_sto", nbands_sto, "number of stochstic orbitals");
115116
ModuleBase::GlobalFunc::OUTP(ofs, "nche_sto", nche_sto, "Chebyshev expansion orders");
116117
ModuleBase::GlobalFunc::OUTP(ofs, "emin_sto", emin_sto, "trial energy to guess the lower bound of eigen energies of the Hamitonian operator");

tests/integrate/186_PW_SDOS_10D10S/INPUT

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,4 +39,4 @@ dos_emax_ev 100
3939
dos_edelta_ev 0.1
4040
dos_sigma 4
4141
dos_nche 240
42-
42+
npart_sto 2

0 commit comments

Comments
 (0)