Skip to content

Commit 6f45435

Browse files
committed
Merge branch 'develop' of https://github.com/deepmodeling/abacus-develop into hotfix
2 parents 85b1951 + e25db6e commit 6f45435

File tree

102 files changed

+4634
-3419
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

102 files changed

+4634
-3419
lines changed

source/Makefile.Objects

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -192,21 +192,21 @@ OBJS_CELL=atom_pseudo.o\
192192

193193
OBJS_DEEPKS=LCAO_deepks.o\
194194
deepks_force.o\
195+
deepks_descriptor.o\
195196
deepks_orbital.o\
197+
deepks_orbpre.o\
198+
deepks_vdpre.o\
199+
deepks_hmat.o\
196200
LCAO_deepks_io.o\
197-
LCAO_deepks_mpi.o\
198201
LCAO_deepks_pdm.o\
199202
LCAO_deepks_phialpha.o\
200203
LCAO_deepks_torch.o\
201204
LCAO_deepks_vdelta.o\
202-
deepks_hmat.o\
203205
LCAO_deepks_interface.o\
204-
deepks_orbpre.o\
205206
cal_gdmx.o\
207+
cal_gdmepsl.o\
206208
cal_gedm.o\
207209
cal_gvx.o\
208-
cal_descriptor.o\
209-
v_delta_precalc.o\
210210

211211

212212
OBJS_ELECSTAT=elecstate.o\

source/module_cell/atom_spec.h

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ class Atom
2222
std::vector<bool> iw2_new;
2323
int nw = 0; // number of local orbitals (l,n,m) of this type
2424

25-
void set_index(void);
25+
void set_index();
2626

2727
int type = 0; // Index of atom type
2828
int na = 0; // Number of atoms in this type.
@@ -34,8 +34,7 @@ class Atom
3434

3535
std::string label = "\0"; // atomic symbol
3636
std::vector<ModuleBase::Vector3<double>> tau; // Cartesian coordinates of each atom in this type.
37-
std::vector<ModuleBase::Vector3<double>>
38-
dis; // direct displacements of each atom in this type in current step liuyu modift 2023-03-22
37+
std::vector<ModuleBase::Vector3<double>> dis; // direct displacements of each atom in this type in current step liuyu modift 2023-03-22
3938
std::vector<ModuleBase::Vector3<double>> taud; // Direct coordinates of each atom in this type.
4039
std::vector<ModuleBase::Vector3<double>> vel; // velocities of each atom in this type.
4140
std::vector<ModuleBase::Vector3<double>> force; // force acting on each atom in this type.
@@ -54,8 +53,8 @@ class Atom
5453
void print_Atom(std::ofstream& ofs);
5554
void update_force(ModuleBase::matrix& fcs);
5655
#ifdef __MPI
57-
void bcast_atom(void);
58-
void bcast_atom2(void);
56+
void bcast_atom();
57+
void bcast_atom2();
5958
#endif
6059
};
6160

source/module_cell/test/support/mock_unitcell.cpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,7 @@ bool UnitCell::read_atom_positions(std::ifstream& ifpos,
3333
std::ofstream& ofs_warning) {
3434
return true;
3535
}
36-
void UnitCell::update_pos_taud(double* posd_in) {}
37-
void UnitCell::update_pos_taud(const ModuleBase::Vector3<double>* posd_in) {}
38-
void UnitCell::update_vel(const ModuleBase::Vector3<double>* vel_in) {}
39-
void UnitCell::bcast_atoms_tau() {}
36+
4037
bool UnitCell::judge_big_cell() const { return true; }
4138
void UnitCell::update_stress(ModuleBase::matrix& scs) {}
4239
void UnitCell::update_force(ModuleBase::matrix& fcs) {}

source/module_cell/test/unitcell_test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1021,7 +1021,7 @@ TEST_F(UcellTest, UpdateVel)
10211021
{
10221022
vel_in[iat].set(iat * 0.1, iat * 0.1, iat * 0.1);
10231023
}
1024-
ucell->update_vel(vel_in);
1024+
unitcell::update_vel(vel_in,ucell->ntype,ucell->nat,ucell->atoms);
10251025
for (int iat = 0; iat < ucell->nat; ++iat)
10261026
{
10271027
EXPECT_DOUBLE_EQ(vel_in[iat].x, 0.1 * iat);

source/module_cell/test/unitcell_test_para.cpp

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ TEST_F(UcellTest, UpdatePosTau)
153153
}
154154
delete[] pos_in;
155155
}
156-
TEST_F(UcellTest, UpdatePosTaud)
156+
TEST_F(UcellTest, UpdatePosTaud_pointer)
157157
{
158158
double* pos_in = new double[ucell->nat * 3];
159159
ModuleBase::Vector3<double>* tmp = new ModuleBase::Vector3<double>[ucell->nat];
@@ -167,7 +167,8 @@ TEST_F(UcellTest, UpdatePosTaud)
167167
ucell->iat2iait(iat, &ia, &it);
168168
tmp[iat] = ucell->atoms[it].taud[ia];
169169
}
170-
ucell->update_pos_taud(pos_in);
170+
unitcell::update_pos_taud(ucell->lat,pos_in,ucell->ntype,
171+
ucell->nat,ucell->atoms);
171172
for (int iat = 0; iat < ucell->nat; ++iat)
172173
{
173174
int it, ia;
@@ -180,6 +181,37 @@ TEST_F(UcellTest, UpdatePosTaud)
180181
delete[] pos_in;
181182
}
182183

184+
//test update_pos_taud with ModuleBase::Vector3<double> version
185+
TEST_F(UcellTest, UpdatePosTaud_Vector3)
186+
{
187+
ModuleBase::Vector3<double>* pos_in = new ModuleBase::Vector3<double>[ucell->nat];
188+
ModuleBase::Vector3<double>* tmp = new ModuleBase::Vector3<double>[ucell->nat];
189+
ucell->set_iat2itia();
190+
for (int iat = 0; iat < ucell->nat; ++iat)
191+
{
192+
for (int ik = 0; ik < 3; ++ik)
193+
{
194+
pos_in[iat][ik] = 0.01;
195+
}
196+
int it=0;
197+
int ia=0;
198+
ucell->iat2iait(iat, &ia, &it);
199+
tmp[iat] = ucell->atoms[it].taud[ia];
200+
}
201+
unitcell::update_pos_taud(ucell->lat,pos_in,ucell->ntype,
202+
ucell->nat,ucell->atoms);
203+
for (int iat = 0; iat < ucell->nat; ++iat)
204+
{
205+
int it, ia;
206+
ucell->iat2iait(iat, &ia, &it);
207+
for (int ik = 0; ik < 3; ++ik)
208+
{
209+
EXPECT_DOUBLE_EQ(ucell->atoms[it].taud[ia][ik], tmp[iat][ik] + 0.01);
210+
}
211+
}
212+
delete[] tmp;
213+
delete[] pos_in;
214+
}
183215
TEST_F(UcellTest, ReadPseudo)
184216
{
185217
PARAM.input.pseudo_dir = pp_dir;

source/module_cell/unitcell.cpp

Lines changed: 0 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -314,65 +314,6 @@ std::vector<ModuleBase::Vector3<int>> UnitCell::get_constrain() const
314314
return constrain;
315315
}
316316

317-
318-
319-
void UnitCell::update_pos_taud(double* posd_in) {
320-
int iat = 0;
321-
for (int it = 0; it < this->ntype; it++) {
322-
Atom* atom = &this->atoms[it];
323-
for (int ia = 0; ia < atom->na; ia++) {
324-
for (int ik = 0; ik < 3; ++ik) {
325-
atom->taud[ia][ik] += posd_in[3 * iat + ik];
326-
atom->dis[ia][ik] = posd_in[3 * iat + ik];
327-
}
328-
iat++;
329-
}
330-
}
331-
assert(iat == this->nat);
332-
unitcell::periodic_boundary_adjustment(this->atoms,this->latvec, this->ntype);
333-
this->bcast_atoms_tau();
334-
}
335-
336-
// posd_in is atomic displacements here liuyu 2023-03-22
337-
void UnitCell::update_pos_taud(const ModuleBase::Vector3<double>* posd_in) {
338-
int iat = 0;
339-
for (int it = 0; it < this->ntype; it++) {
340-
Atom* atom = &this->atoms[it];
341-
for (int ia = 0; ia < atom->na; ia++) {
342-
for (int ik = 0; ik < 3; ++ik) {
343-
atom->taud[ia][ik] += posd_in[iat][ik];
344-
atom->dis[ia][ik] = posd_in[iat][ik];
345-
}
346-
iat++;
347-
}
348-
}
349-
assert(iat == this->nat);
350-
unitcell::periodic_boundary_adjustment(this->atoms,this->latvec, this->ntype);
351-
this->bcast_atoms_tau();
352-
}
353-
354-
void UnitCell::update_vel(const ModuleBase::Vector3<double>* vel_in) {
355-
int iat = 0;
356-
for (int it = 0; it < this->ntype; ++it) {
357-
Atom* atom = &this->atoms[it];
358-
for (int ia = 0; ia < atom->na; ++ia) {
359-
this->atoms[it].vel[ia] = vel_in[iat];
360-
++iat;
361-
}
362-
}
363-
assert(iat == this->nat);
364-
}
365-
366-
367-
void UnitCell::bcast_atoms_tau() {
368-
#ifdef __MPI
369-
MPI_Barrier(MPI_COMM_WORLD);
370-
for (int i = 0; i < ntype; i++) {
371-
atoms[i].bcast_atom(); // bcast tau array
372-
}
373-
#endif
374-
}
375-
376317
//==============================================================
377318
// Calculate various lattice related quantities for given latvec
378319
//==============================================================

source/module_cell/unitcell.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -200,10 +200,6 @@ class UnitCell {
200200
void print_cell(std::ofstream& ofs) const;
201201
void print_cell_xyz(const std::string& fn) const;
202202

203-
void update_pos_taud(const ModuleBase::Vector3<double>* posd_in);
204-
void update_pos_taud(double* posd_in);
205-
void update_vel(const ModuleBase::Vector3<double>* vel_in);
206-
void bcast_atoms_tau();
207203
bool judge_big_cell() const;
208204

209205
void update_stress(ModuleBase::matrix& scs); // updates stress

source/module_cell/update_cell.cpp

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -374,6 +374,75 @@ void update_pos_tau(const Lattice& lat,
374374
bcast_atoms_tau(atoms, ntype);
375375
}
376376

377+
void update_pos_taud(const Lattice& lat,
378+
const double* posd_in,
379+
const int ntype,
380+
const int nat,
381+
Atom* atoms)
382+
{
383+
int iat = 0;
384+
for (int it = 0; it < ntype; it++)
385+
{
386+
Atom* atom = &atoms[it];
387+
for (int ia = 0; ia < atom->na; ia++)
388+
{
389+
for (int ik = 0; ik < 3; ++ik)
390+
{
391+
atom->taud[ia][ik] += posd_in[3 * iat + ik];
392+
atom->dis[ia][ik] = posd_in[3 * iat + ik];
393+
}
394+
iat++;
395+
}
396+
}
397+
assert(iat == nat);
398+
periodic_boundary_adjustment(atoms,lat.latvec,ntype);
399+
bcast_atoms_tau(atoms, ntype);
400+
}
401+
402+
// posd_in is atomic displacements here liuyu 2023-03-22
403+
void update_pos_taud(const Lattice& lat,
404+
const ModuleBase::Vector3<double>* posd_in,
405+
const int ntype,
406+
const int nat,
407+
Atom* atoms)
408+
{
409+
int iat = 0;
410+
for (int it = 0; it < ntype; it++)
411+
{
412+
Atom* atom = &atoms[it];
413+
for (int ia = 0; ia < atom->na; ia++)
414+
{
415+
for (int ik = 0; ik < 3; ++ik)
416+
{
417+
atom->taud[ia][ik] += posd_in[iat][ik];
418+
atom->dis[ia][ik] = posd_in[iat][ik];
419+
}
420+
iat++;
421+
}
422+
}
423+
assert(iat == nat);
424+
periodic_boundary_adjustment(atoms,lat.latvec,ntype);
425+
bcast_atoms_tau(atoms, ntype);
426+
}
427+
428+
void update_vel(const ModuleBase::Vector3<double>* vel_in,
429+
const int ntype,
430+
const int nat,
431+
Atom* atoms)
432+
{
433+
int iat = 0;
434+
for (int it = 0; it < ntype; ++it)
435+
{
436+
Atom* atom = &atoms[it];
437+
for (int ia = 0; ia < atom->na; ++ia)
438+
{
439+
atoms[it].vel[ia] = vel_in[iat];
440+
++iat;
441+
}
442+
}
443+
assert(iat == nat);
444+
}
445+
377446
void periodic_boundary_adjustment(Atom* atoms,
378447
const ModuleBase::Matrix3& latvec,
379448
const int ntype)

source/module_cell/update_cell.h

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,48 @@ namespace unitcell
4848
const int ntype,
4949
const int nat,
5050
Atom* atoms);
51+
52+
/**
53+
* @brief update the position and tau of the atoms
54+
*
55+
* @param lat: the lattice of the atoms [in]
56+
* @param pos_in: the position of the atoms in direct coordinate system [in]
57+
* @param ntype: the number of types of the atoms [in]
58+
* @param nat: the number of atoms [in]
59+
* @param atoms: the atoms to be updated [out]
60+
*/
61+
void update_pos_taud(const Lattice& lat,
62+
const double* posd_in,
63+
const int ntype,
64+
const int nat,
65+
Atom* atoms);
66+
/**
67+
* @brief update the velocity of the atoms
68+
*
69+
* @param lat: the lattice of the atoms [in]
70+
* @param pos_in: the position of the atoms in direct coordinate system
71+
* in ModuleBase::Vector3 version [in]
72+
* @param ntype: the number of types of the atoms [in]
73+
* @param nat: the number of atoms [in]
74+
* @param atoms: the atoms to be updated [out]
75+
*/
76+
void update_pos_taud(const Lattice& lat,
77+
const ModuleBase::Vector3<double>* posd_in,
78+
const int ntype,
79+
const int nat,
80+
Atom* atoms);
81+
/**
82+
* @brief update the velocity of the atoms
83+
*
84+
* @param vel_in: the velocity of the atoms [in]
85+
* @param ntype: the number of types of the atoms [in]
86+
* @param nat: the number of atoms [in]
87+
* @param atoms: the atoms to be updated [out]
88+
*/
89+
void update_vel(const ModuleBase::Vector3<double>* vel_in,
90+
const int ntype,
91+
const int nat,
92+
Atom* atoms);
5193
}
5294
//
5395
#endif // UPDATE_CELL_H

source/module_elecstate/cal_dm.h

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,12 @@ inline void cal_dm(const Parallel_Orbitals* ParaV, const ModuleBase::matrix& wg,
2727
//dm.fix_k(ik);
2828
dm[ik].create(ParaV->ncol, ParaV->nrow);
2929
// wg_wfc(ib,iw) = wg[ib] * wfc(ib,iw);
30-
psi::Psi<double> wg_wfc(wfc, 1);
30+
psi::Psi<double> wg_wfc(1,
31+
wfc.get_nbands(),
32+
wfc.get_nbasis(),
33+
wfc.get_nbasis(),
34+
true);
35+
wg_wfc.set_all_psi(wfc.get_pointer(), wg_wfc.size());
3136

3237
int ib_global = 0;
3338
for (int ib_local = 0; ib_local < nbands_local; ++ib_local)
@@ -41,7 +46,8 @@ inline void cal_dm(const Parallel_Orbitals* ParaV, const ModuleBase::matrix& wg,
4146
ModuleBase::WARNING_QUIT("ElecStateLCAO::cal_dm", "please check global2local_col!");
4247
}
4348
}
44-
if (ib_global >= wg.nc) continue;
49+
if (ib_global >= wg.nc) { continue;
50+
}
4551
const double wg_local = wg(ik, ib_global);
4652
double* wg_wfc_pointer = &(wg_wfc(0, ib_local, 0));
4753
BlasConnector::scal(nbasis_local, wg_local, wg_wfc_pointer, 1);
@@ -99,7 +105,8 @@ inline void cal_dm(const Parallel_Orbitals* ParaV, const ModuleBase::matrix& wg,
99105
ModuleBase::WARNING_QUIT("ElecStateLCAO::cal_dm", "please check global2local_col!");
100106
}
101107
}
102-
if (ib_global >= wg.nc) continue;
108+
if (ib_global >= wg.nc) { continue;
109+
}
103110
const double wg_local = wg(ik, ib_global);
104111
std::complex<double>* wg_wfc_pointer = &(wg_wfc(0, ib_local, 0));
105112
BlasConnector::scal(nbasis_local, wg_local, wg_wfc_pointer, 1);

0 commit comments

Comments
 (0)