Skip to content

Commit 34a92ad

Browse files
authored
Merge branch 'develop' into init
2 parents bcf89ff + 79e054e commit 34a92ad

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

52 files changed

+2138
-2323
lines changed

docs/advanced/input_files/input-main.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2060,7 +2060,7 @@ Warning: this function is not robust enough for the current version. Please try
20602060
- **Type**: int
20612061
- **Availability**: numerical atomic orbital basis
20622062
- **Description**: Include V_delta label for DeePKS training. When `deepks_out_labels` is true and `deepks_v_delta` > 0, ABACUS will output h_base.npy, v_delta.npy and h_tot.npy(h_tot=h_base+v_delta).
2063-
Meanwhile, when `deepks_v_delta` equals 1, ABACUS will also output v_delta_precalc.npy, which is used to calculate V_delta during DeePKS training. However, when the number of atoms grows, the size of v_delta_precalc.npy will be very large. In this case, it's recommended to set `deepks_v_delta` as 2, and ABACUS will output psialpha.npy and grad_evdm.npy but not v_delta_precalc.npy. These two files are small and can be used to calculate v_delta_precalc in the procedure of training DeePKS.
2063+
Meanwhile, when `deepks_v_delta` equals 1, ABACUS will also output v_delta_precalc.npy, which is used to calculate V_delta during DeePKS training. However, when the number of atoms grows, the size of v_delta_precalc.npy will be very large. In this case, it's recommended to set `deepks_v_delta` as 2, and ABACUS will output phialpha.npy and grad_evdm.npy but not v_delta_precalc.npy. These two files are small and can be used to calculate v_delta_precalc in the procedure of training DeePKS.
20642064
- **Default**: 0
20652065

20662066
### deepks_out_unittest

source/Makefile.Objects

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -189,13 +189,12 @@ OBJS_CELL=atom_pseudo.o\
189189
check_atomic_stru.o\
190190

191191
OBJS_DEEPKS=LCAO_deepks.o\
192-
deepks_fgamma.o\
193-
deepks_fk.o\
192+
deepks_force.o\
194193
LCAO_deepks_odelta.o\
195194
LCAO_deepks_io.o\
196195
LCAO_deepks_mpi.o\
197196
LCAO_deepks_pdm.o\
198-
LCAO_deepks_psialpha.o\
197+
LCAO_deepks_phialpha.o\
199198
LCAO_deepks_torch.o\
200199
LCAO_deepks_vdelta.o\
201200
deepks_hmat.o\

source/module_cell/read_atoms.cpp

Lines changed: 36 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -676,45 +676,44 @@ bool UnitCell::read_atom_positions(std::ifstream &ifpos, std::ofstream &ofs_runn
676676
std::string mags;
677677
//cout<<"mag"<<atoms[it].mag[ia]<<"angle1"<<atoms[it].angle1[ia]<<"angle2"<<atoms[it].angle2[ia]<<'\n';
678678

679-
if(PARAM.inp.nspin==4)
680-
{
681-
if(PARAM.inp.noncolin)
679+
// ----------------------------------------------------------------------------
680+
// recalcualte mag and m_loc_ from read in angle1, angle2 and mag or mx, my, mz
681+
if(input_angle_mag)
682+
{// angle1 or angle2 are given, calculate mx, my, mz from angle1 and angle2 and mag
683+
atoms[it].m_loc_[ia].z = atoms[it].mag[ia] *
684+
cos(atoms[it].angle1[ia]);
685+
if(std::abs(sin(atoms[it].angle1[ia])) > 1e-10 )
682686
{
683-
if(input_angle_mag)
684-
{
685-
atoms[it].m_loc_[ia].z = atoms[it].mag[ia] *
686-
cos(atoms[it].angle1[ia]);
687-
if(std::abs(sin(atoms[it].angle1[ia])) > 1e-10 )
688-
{
689-
atoms[it].m_loc_[ia].x = atoms[it].mag[ia] *
690-
sin(atoms[it].angle1[ia]) * cos(atoms[it].angle2[ia]);
691-
atoms[it].m_loc_[ia].y = atoms[it].mag[ia] *
692-
sin(atoms[it].angle1[ia]) * sin(atoms[it].angle2[ia]);
693-
}
694-
}
695-
else if (input_vec_mag)
696-
{
697-
double mxy=sqrt(pow(atoms[it].m_loc_[ia].x,2)+pow(atoms[it].m_loc_[ia].y,2));
698-
atoms[it].angle1[ia]=atan2(mxy,atoms[it].m_loc_[ia].z);
699-
if(mxy>1e-8)
700-
{
701-
atoms[it].angle2[ia]=atan2(atoms[it].m_loc_[ia].y,atoms[it].m_loc_[ia].x);
702-
}
703-
}
704-
else
705-
{
706-
atoms[it].m_loc_[ia].x = 0;
707-
atoms[it].m_loc_[ia].y = 0;
708-
atoms[it].m_loc_[ia].z = atoms[it].mag[ia];
709-
}
687+
atoms[it].m_loc_[ia].x = atoms[it].mag[ia] *
688+
sin(atoms[it].angle1[ia]) * cos(atoms[it].angle2[ia]);
689+
atoms[it].m_loc_[ia].y = atoms[it].mag[ia] *
690+
sin(atoms[it].angle1[ia]) * sin(atoms[it].angle2[ia]);
710691
}
711-
else
692+
}
693+
else if (input_vec_mag)
694+
{// mx, my, mz are given, calculate angle1 and angle2 from mx, my, mz
695+
double mxy=sqrt(pow(atoms[it].m_loc_[ia].x,2)+pow(atoms[it].m_loc_[ia].y,2));
696+
atoms[it].angle1[ia]=atan2(mxy,atoms[it].m_loc_[ia].z);
697+
if(mxy>1e-8)
712698
{
699+
atoms[it].angle2[ia]=atan2(atoms[it].m_loc_[ia].y,atoms[it].m_loc_[ia].x);
700+
}
701+
}
702+
else// only one mag is given, assume it is z
703+
{
704+
atoms[it].m_loc_[ia].x = 0;
705+
atoms[it].m_loc_[ia].y = 0;
706+
atoms[it].m_loc_[ia].z = atoms[it].mag[ia];
707+
}
708+
709+
if(PARAM.inp.nspin==4)
710+
{
711+
if(!PARAM.inp.noncolin)
712+
{
713+
//collinear case with nspin = 4, only z component is used
713714
atoms[it].m_loc_[ia].x = 0;
714715
atoms[it].m_loc_[ia].y = 0;
715-
atoms[it].m_loc_[ia].z = atoms[it].mag[ia];
716716
}
717-
718717
//print only ia==0 && mag>0 to avoid too much output
719718
//print when ia!=0 && mag[ia] != mag[0] to avoid too much output
720719
// 'A || (!A && B)' is equivalent to 'A || B',so the following
@@ -735,8 +734,8 @@ bool UnitCell::read_atom_positions(std::ifstream &ifpos, std::ofstream &ofs_runn
735734
ModuleBase::GlobalFunc::ZEROS(magnet.ux_ ,3);
736735
}
737736
else if(PARAM.inp.nspin==2)
738-
{
739-
atoms[it].m_loc_[ia].x = atoms[it].mag[ia];
737+
{// collinear case with nspin = 2, only z component is used
738+
atoms[it].mag[ia] = atoms[it].m_loc_[ia].z;
740739
//print only ia==0 && mag>0 to avoid too much output
741740
//print when ia!=0 && mag[ia] != mag[0] to avoid too much output
742741
if(ia==0 || (atoms[it].mag[ia] != atoms[it].mag[0]))
@@ -751,6 +750,8 @@ bool UnitCell::read_atom_positions(std::ifstream &ifpos, std::ofstream &ofs_runn
751750
ModuleBase::GlobalFunc::OUT(ofs_running, ss.str(),atoms[it].mag[ia]);
752751
}
753752
}
753+
// end of calculating initial magnetization of each atom
754+
// ----------------------------------------------------------------------------
754755

755756
if(Coordinate=="Direct")
756757
{

source/module_elecstate/elecstate_pw.cpp

Lines changed: 65 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -31,36 +31,55 @@ ElecStatePW<T, Device>::ElecStatePW(ModulePW::PW_Basis_K* wfc_basis_in,
3131
template<typename T, typename Device>
3232
ElecStatePW<T, Device>::~ElecStatePW()
3333
{
34-
if (base_device::get_device_type<Device>(this->ctx) == base_device::GpuDevice)
34+
if (PARAM.inp.device == "gpu" || PARAM.inp.precision == "single")
3535
{
3636
delmem_var_op()(this->ctx, this->rho_data);
37+
delete[] this->rho;
38+
39+
if (PARAM.globalv.double_grid || PARAM.globalv.use_uspp)
40+
{
41+
delmem_complex_op()(this->ctx, this->rhog_data);
42+
delete[] this->rhog;
43+
}
3744
if (get_xc_func_type() == 3 || PARAM.inp.out_elf[0] > 0)
3845
{
3946
delmem_var_op()(this->ctx, this->kin_r_data);
47+
delete[] this->kin_r;
4048
}
4149
}
42-
if (PARAM.inp.device == "gpu" || PARAM.inp.precision == "single") {
43-
delete[] this->rho;
44-
delete[] this->kin_r;
50+
if (PARAM.globalv.use_uspp)
51+
{
52+
delmem_var_op()(this->ctx, this->becsum);
4553
}
46-
delmem_var_op()(this->ctx, becsum);
4754
delmem_complex_op()(this->ctx, this->wfcr);
4855
delmem_complex_op()(this->ctx, this->wfcr_another_spin);
4956
}
5057

5158
template<typename T, typename Device>
5259
void ElecStatePW<T, Device>::init_rho_data()
5360
{
54-
if(this->init_rho) {
61+
if (this->init_rho)
62+
{
5563
return;
5664
}
57-
58-
if (PARAM.inp.device == "gpu" || PARAM.inp.precision == "single") {
65+
66+
if (PARAM.inp.device == "gpu" || PARAM.inp.precision == "single")
67+
{
5968
this->rho = new Real*[this->charge->nspin];
6069
resmem_var_op()(this->ctx, this->rho_data, this->charge->nspin * this->charge->nrxx);
61-
for (int ii = 0; ii < this->charge->nspin; ii++) {
70+
for (int ii = 0; ii < this->charge->nspin; ii++)
71+
{
6272
this->rho[ii] = this->rho_data + ii * this->charge->nrxx;
6373
}
74+
if (PARAM.globalv.double_grid || PARAM.globalv.use_uspp)
75+
{
76+
this->rhog = new T*[this->charge->nspin];
77+
resmem_complex_op()(this->ctx, this->rhog_data, this->charge->nspin * this->charge->rhopw->npw);
78+
for (int ii = 0; ii < this->charge->nspin; ii++)
79+
{
80+
this->rhog[ii] = this->rhog_data + ii * this->charge->rhopw->npw;
81+
}
82+
}
6483
if (get_xc_func_type() == 3 || PARAM.inp.out_elf[0] > 0)
6584
{
6685
this->kin_r = new Real*[this->charge->nspin];
@@ -70,8 +89,13 @@ void ElecStatePW<T, Device>::init_rho_data()
7089
}
7190
}
7291
}
73-
else {
92+
else
93+
{
7494
this->rho = reinterpret_cast<Real **>(this->charge->rho);
95+
if (PARAM.globalv.double_grid || PARAM.globalv.use_uspp)
96+
{
97+
this->rhog = reinterpret_cast<T**>(this->charge->rhog);
98+
}
7599
if (get_xc_func_type() == 3 || PARAM.inp.out_elf[0] > 0)
76100
{
77101
this->kin_r = reinterpret_cast<Real **>(this->charge->kin_r);
@@ -100,19 +124,24 @@ void ElecStatePW<T, Device>::psiToRho(const psi::Psi<T, Device>& psi)
100124
// ModuleBase::GlobalFunc::ZEROS(this->charge->kin_r[is], this->charge->nrxx);
101125
setmem_var_op()(this->ctx, this->kin_r[is], 0, this->charge->nrxx);
102126
}
103-
}
127+
if (PARAM.globalv.double_grid || PARAM.globalv.use_uspp)
128+
{
129+
setmem_complex_op()(this->ctx, this->rhog[is], 0, this->charge->rhopw->npw);
130+
}
131+
}
104132

105133
for (int ik = 0; ik < psi.get_nk(); ++ik)
106134
{
107135
psi.fix_k(ik);
108136
this->updateRhoK(psi);
109137
}
110-
if (PARAM.globalv.use_uspp)
138+
139+
this->add_usrho(psi);
140+
141+
if (PARAM.inp.device == "gpu" || PARAM.inp.precision == "single")
111142
{
112-
this->add_usrho(psi);
113-
}
114-
if (PARAM.inp.device == "gpu" || PARAM.inp.precision == "single") {
115-
for (int ii = 0; ii < PARAM.inp.nspin; ii++) {
143+
for (int ii = 0; ii < PARAM.inp.nspin; ii++)
144+
{
116145
castmem_var_d2h_op()(cpu_ctx, this->ctx, this->charge->rho[ii], this->rho[ii], this->charge->nrxx);
117146
if (get_xc_func_type() == 3)
118147
{
@@ -397,32 +426,39 @@ void ElecStatePW<T, Device>::cal_becsum(const psi::Psi<T, Device>& psi)
397426
template <typename T, typename Device>
398427
void ElecStatePW<T, Device>::add_usrho(const psi::Psi<T, Device>& psi)
399428
{
400-
this->cal_becsum(psi);
429+
if (PARAM.globalv.use_uspp)
430+
{
431+
this->cal_becsum(psi);
432+
}
401433

402434
// transform soft charge to recip space using smooth grids
403-
T* rhog = nullptr;
404-
resmem_complex_op()(this->ctx, rhog, this->charge->rhopw->npw * PARAM.inp.nspin, "ElecState<PW>::rhog");
405-
setmem_complex_op()(this->ctx, rhog, 0, this->charge->rhopw->npw * PARAM.inp.nspin);
406-
for (int is = 0; is < PARAM.inp.nspin; is++)
435+
if (PARAM.globalv.double_grid || PARAM.globalv.use_uspp)
407436
{
408-
this->rhopw_smooth->real2recip(this->rho[is], &rhog[is * this->charge->rhopw->npw]);
437+
for (int is = 0; is < PARAM.inp.nspin; is++)
438+
{
439+
this->rhopw_smooth->real2recip(this->rho[is], this->rhog[is]);
440+
}
409441
}
410442

411443
// \sum_lm Q_lm(r) \sum_i <psi_i|beta_l><beta_m|psi_i> w_i
412444
// add to the charge density in reciprocal space the part which is due to the US augmentation.
413-
this->addusdens_g(becsum, rhog);
445+
if (PARAM.globalv.use_uspp)
446+
{
447+
this->addusdens_g(becsum, rhog);
448+
}
414449

415450
// transform back to real space using dense grids
416-
for (int is = 0; is < PARAM.inp.nspin; is++)
451+
if (PARAM.globalv.double_grid || PARAM.globalv.use_uspp)
417452
{
418-
this->charge->rhopw->recip2real(&rhog[is * this->charge->rhopw->npw], this->rho[is]);
453+
for (int is = 0; is < PARAM.inp.nspin; is++)
454+
{
455+
this->charge->rhopw->recip2real(this->rhog[is], this->rho[is]);
456+
}
419457
}
420-
421-
delmem_complex_op()(this->ctx, rhog);
422458
}
423459

424460
template <typename T, typename Device>
425-
void ElecStatePW<T, Device>::addusdens_g(const Real* becsum, T* rhog)
461+
void ElecStatePW<T, Device>::addusdens_g(const Real* becsum, T** rhog)
426462
{
427463
const T one{1, 0};
428464
const T zero{0, 0};
@@ -506,7 +542,7 @@ void ElecStatePW<T, Device>::addusdens_g(const Real* becsum, T* rhog)
506542
this->ppcell->radial_fft_q(this->ctx, npw, ih, jh, it, qmod, ylmk0, qgm);
507543
for (int ig = 0; ig < npw; ig++)
508544
{
509-
rhog[is * npw + ig] += qgm[ig] * aux2[ijh * npw + ig];
545+
rhog[is][ig] += qgm[ig] * aux2[ijh * npw + ig];
510546
}
511547
ijh++;
512548
}

source/module_elecstate/elecstate_pw.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,9 @@ class ElecStatePW : public ElecState
4242

4343
//! init rho_data and kin_r_data
4444
void init_rho_data();
45-
Real** rho = nullptr;
46-
Real** kin_r = nullptr; //[Device] [spin][nrxx] rho and kin_r
45+
Real** rho = nullptr; // [Device] [spin][nrxx] rho
46+
T** rhog = nullptr; // [Device] [spin][nrxx] rhog
47+
Real** kin_r = nullptr; // [Device] [spin][nrxx] kin_r
4748

4849
protected:
4950

@@ -70,15 +71,16 @@ class ElecStatePW : public ElecState
7071

7172
//! Non-local pseudopotentials
7273
//! \sum_lm Q_lm(r) \sum_i <psi_i|beta_l><beta_m|psi_i> w_i
73-
void addusdens_g(const Real* becsum, T* rhog);
74+
void addusdens_g(const Real* becsum, T** rhog);
7475

7576
Device * ctx = {};
7677

7778
bool init_rho = false;
7879

7980
mutable T* vkb = nullptr;
8081

81-
Real* rho_data = nullptr;
82+
Real* rho_data = nullptr;
83+
T* rhog_data = nullptr;
8284
Real* kin_r_data = nullptr;
8385
T* wfcr = nullptr;
8486
T* wfcr_another_spin = nullptr;

source/module_esolver/lcao_before_scf.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -205,17 +205,19 @@ void ESolver_KS_LCAO<TK, TR>::before_scf(UnitCell& ucell, const int istep)
205205
}
206206

207207
#ifdef __DEEPKS
208-
// for each ionic step, the overlap <psi|alpha> must be rebuilt
208+
// for each ionic step, the overlap <phi|alpha> must be rebuilt
209209
// since it depends on ionic positions
210210
if (PARAM.globalv.deepks_setorb)
211211
{
212212
const Parallel_Orbitals* pv = &this->pv;
213-
// build and save <psi(0)|alpha(R)> at beginning
214-
GlobalC::ld.build_psialpha(PARAM.inp.cal_force, ucell, orb_, this->gd, *(two_center_bundle_.overlap_orb_alpha));
213+
// allocate <phi(0)|alpha(R)>, phialpha is different every ion step, so it is allocated here
214+
GlobalC::ld.allocate_phialpha(PARAM.inp.cal_force, ucell, orb_, this->gd);
215+
// build and save <phi(0)|alpha(R)> at beginning
216+
GlobalC::ld.build_phialpha(PARAM.inp.cal_force, ucell, orb_, this->gd, *(two_center_bundle_.overlap_orb_alpha));
215217

216218
if (PARAM.inp.deepks_out_unittest)
217219
{
218-
GlobalC::ld.check_psialpha(PARAM.inp.cal_force, ucell, orb_, this->gd);
220+
GlobalC::ld.check_phialpha(PARAM.inp.cal_force, ucell, orb_, this->gd);
219221
}
220222
}
221223
#endif

source/module_esolver/lcao_others.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -211,17 +211,19 @@ void ESolver_KS_LCAO<TK, TR>::others(UnitCell& ucell, const int istep)
211211
}
212212

213213
#ifdef __DEEPKS
214-
// for each ionic step, the overlap <psi|alpha> must be rebuilt
214+
// for each ionic step, the overlap <phi|alpha> must be rebuilt
215215
// since it depends on ionic positions
216216
if (PARAM.globalv.deepks_setorb)
217217
{
218218
const Parallel_Orbitals* pv = &this->pv;
219-
// build and save <psi(0)|alpha(R)> at beginning
220-
GlobalC::ld.build_psialpha(PARAM.inp.cal_force, ucell, orb_, this->gd, *(two_center_bundle_.overlap_orb_alpha));
219+
// allocate <phi(0)|alpha(R)>, phialpha is different every ion step, so it is allocated here
220+
GlobalC::ld.allocate_phialpha(PARAM.inp.cal_force, ucell, orb_, this->gd);
221+
// build and save <phi(0)|alpha(R)> at beginning
222+
GlobalC::ld.build_phialpha(PARAM.inp.cal_force, ucell, orb_, this->gd, *(two_center_bundle_.overlap_orb_alpha));
221223

222224
if (PARAM.inp.deepks_out_unittest)
223225
{
224-
GlobalC::ld.check_psialpha(PARAM.inp.cal_force, ucell, orb_, this->gd);
226+
GlobalC::ld.check_phialpha(PARAM.inp.cal_force, ucell, orb_, this->gd);
225227
}
226228
}
227229
#endif

0 commit comments

Comments
 (0)