Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions source/module_cell/atom_spec.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class Atom
std::vector<bool> iw2_new;
int nw = 0; // number of local orbitals (l,n,m) of this type

void set_index(void);
void set_index();

int type = 0; // Index of atom type
int na = 0; // Number of atoms in this type.
Expand All @@ -34,8 +34,7 @@ class Atom

std::string label = "\0"; // atomic symbol
std::vector<ModuleBase::Vector3<double>> tau; // Cartesian coordinates of each atom in this type.
std::vector<ModuleBase::Vector3<double>>
dis; // direct displacements of each atom in this type in current step liuyu modift 2023-03-22
std::vector<ModuleBase::Vector3<double>> dis; // direct displacements of each atom in this type in current step liuyu modift 2023-03-22
std::vector<ModuleBase::Vector3<double>> taud; // Direct coordinates of each atom in this type.
std::vector<ModuleBase::Vector3<double>> vel; // velocities of each atom in this type.
std::vector<ModuleBase::Vector3<double>> force; // force acting on each atom in this type.
Expand All @@ -54,8 +53,8 @@ class Atom
void print_Atom(std::ofstream& ofs);
void update_force(ModuleBase::matrix& fcs);
#ifdef __MPI
void bcast_atom(void);
void bcast_atom2(void);
void bcast_atom();
void bcast_atom2();
#endif
};

Expand Down
5 changes: 1 addition & 4 deletions source/module_cell/test/support/mock_unitcell.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,7 @@ bool UnitCell::read_atom_positions(std::ifstream& ifpos,
std::ofstream& ofs_warning) {
return true;
}
void UnitCell::update_pos_taud(double* posd_in) {}
void UnitCell::update_pos_taud(const ModuleBase::Vector3<double>* posd_in) {}
void UnitCell::update_vel(const ModuleBase::Vector3<double>* vel_in) {}
void UnitCell::bcast_atoms_tau() {}

bool UnitCell::judge_big_cell() const { return true; }
void UnitCell::update_stress(ModuleBase::matrix& scs) {}
void UnitCell::update_force(ModuleBase::matrix& fcs) {}
Expand Down
2 changes: 1 addition & 1 deletion source/module_cell/test/unitcell_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1021,7 +1021,7 @@ TEST_F(UcellTest, UpdateVel)
{
vel_in[iat].set(iat * 0.1, iat * 0.1, iat * 0.1);
}
ucell->update_vel(vel_in);
unitcell::update_vel(vel_in,ucell->ntype,ucell->nat,ucell->atoms);
for (int iat = 0; iat < ucell->nat; ++iat)
{
EXPECT_DOUBLE_EQ(vel_in[iat].x, 0.1 * iat);
Expand Down
36 changes: 34 additions & 2 deletions source/module_cell/test/unitcell_test_para.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ TEST_F(UcellTest, UpdatePosTau)
}
delete[] pos_in;
}
TEST_F(UcellTest, UpdatePosTaud)
TEST_F(UcellTest, UpdatePosTaud_pointer)
{
double* pos_in = new double[ucell->nat * 3];
ModuleBase::Vector3<double>* tmp = new ModuleBase::Vector3<double>[ucell->nat];
Expand All @@ -167,7 +167,8 @@ TEST_F(UcellTest, UpdatePosTaud)
ucell->iat2iait(iat, &ia, &it);
tmp[iat] = ucell->atoms[it].taud[ia];
}
ucell->update_pos_taud(pos_in);
unitcell::update_pos_taud(ucell->lat,pos_in,ucell->ntype,
ucell->nat,ucell->atoms);
for (int iat = 0; iat < ucell->nat; ++iat)
{
int it, ia;
Expand All @@ -180,6 +181,37 @@ TEST_F(UcellTest, UpdatePosTaud)
delete[] pos_in;
}

//test update_pos_taud with ModuleBase::Vector3<double> version
TEST_F(UcellTest, UpdatePosTaud_Vector3)
{
ModuleBase::Vector3<double>* pos_in = new ModuleBase::Vector3<double>[ucell->nat];
ModuleBase::Vector3<double>* tmp = new ModuleBase::Vector3<double>[ucell->nat];
ucell->set_iat2itia();
for (int iat = 0; iat < ucell->nat; ++iat)
{
for (int ik = 0; ik < 3; ++ik)
{
pos_in[iat][ik] = 0.01;
}
int it=0;
int ia=0;
ucell->iat2iait(iat, &ia, &it);
tmp[iat] = ucell->atoms[it].taud[ia];
}
unitcell::update_pos_taud(ucell->lat,pos_in,ucell->ntype,
ucell->nat,ucell->atoms);
for (int iat = 0; iat < ucell->nat; ++iat)
{
int it, ia;
ucell->iat2iait(iat, &ia, &it);
for (int ik = 0; ik < 3; ++ik)
{
EXPECT_DOUBLE_EQ(ucell->atoms[it].taud[ia][ik], tmp[iat][ik] + 0.01);
}
}
delete[] tmp;
delete[] pos_in;
}
TEST_F(UcellTest, ReadPseudo)
{
PARAM.input.pseudo_dir = pp_dir;
Expand Down
59 changes: 0 additions & 59 deletions source/module_cell/unitcell.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -314,65 +314,6 @@ std::vector<ModuleBase::Vector3<int>> UnitCell::get_constrain() const
return constrain;
}



void UnitCell::update_pos_taud(double* posd_in) {
int iat = 0;
for (int it = 0; it < this->ntype; it++) {
Atom* atom = &this->atoms[it];
for (int ia = 0; ia < atom->na; ia++) {
for (int ik = 0; ik < 3; ++ik) {
atom->taud[ia][ik] += posd_in[3 * iat + ik];
atom->dis[ia][ik] = posd_in[3 * iat + ik];
}
iat++;
}
}
assert(iat == this->nat);
unitcell::periodic_boundary_adjustment(this->atoms,this->latvec, this->ntype);
this->bcast_atoms_tau();
}

// posd_in is atomic displacements here liuyu 2023-03-22
void UnitCell::update_pos_taud(const ModuleBase::Vector3<double>* posd_in) {
int iat = 0;
for (int it = 0; it < this->ntype; it++) {
Atom* atom = &this->atoms[it];
for (int ia = 0; ia < atom->na; ia++) {
for (int ik = 0; ik < 3; ++ik) {
atom->taud[ia][ik] += posd_in[iat][ik];
atom->dis[ia][ik] = posd_in[iat][ik];
}
iat++;
}
}
assert(iat == this->nat);
unitcell::periodic_boundary_adjustment(this->atoms,this->latvec, this->ntype);
this->bcast_atoms_tau();
}

void UnitCell::update_vel(const ModuleBase::Vector3<double>* vel_in) {
int iat = 0;
for (int it = 0; it < this->ntype; ++it) {
Atom* atom = &this->atoms[it];
for (int ia = 0; ia < atom->na; ++ia) {
this->atoms[it].vel[ia] = vel_in[iat];
++iat;
}
}
assert(iat == this->nat);
}


void UnitCell::bcast_atoms_tau() {
#ifdef __MPI
MPI_Barrier(MPI_COMM_WORLD);
for (int i = 0; i < ntype; i++) {
atoms[i].bcast_atom(); // bcast tau array
}
#endif
}

//==============================================================
// Calculate various lattice related quantities for given latvec
//==============================================================
Expand Down
4 changes: 0 additions & 4 deletions source/module_cell/unitcell.h
Original file line number Diff line number Diff line change
Expand Up @@ -200,10 +200,6 @@ class UnitCell {
void print_cell(std::ofstream& ofs) const;
void print_cell_xyz(const std::string& fn) const;

void update_pos_taud(const ModuleBase::Vector3<double>* posd_in);
void update_pos_taud(double* posd_in);
void update_vel(const ModuleBase::Vector3<double>* vel_in);
void bcast_atoms_tau();
bool judge_big_cell() const;

void update_stress(ModuleBase::matrix& scs); // updates stress
Expand Down
69 changes: 69 additions & 0 deletions source/module_cell/update_cell.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,75 @@ void update_pos_tau(const Lattice& lat,
bcast_atoms_tau(atoms, ntype);
}

void update_pos_taud(const Lattice& lat,
const double* posd_in,
const int ntype,
const int nat,
Atom* atoms)
{
int iat = 0;
for (int it = 0; it < ntype; it++)
{
Atom* atom = &atoms[it];
for (int ia = 0; ia < atom->na; ia++)
{
for (int ik = 0; ik < 3; ++ik)
{
atom->taud[ia][ik] += posd_in[3 * iat + ik];
atom->dis[ia][ik] = posd_in[3 * iat + ik];
}
iat++;
}
}
assert(iat == nat);
periodic_boundary_adjustment(atoms,lat.latvec,ntype);
bcast_atoms_tau(atoms, ntype);
}

// posd_in is atomic displacements here liuyu 2023-03-22
void update_pos_taud(const Lattice& lat,
const ModuleBase::Vector3<double>* posd_in,
const int ntype,
const int nat,
Atom* atoms)
{
int iat = 0;
for (int it = 0; it < ntype; it++)
{
Atom* atom = &atoms[it];
for (int ia = 0; ia < atom->na; ia++)
{
for (int ik = 0; ik < 3; ++ik)
{
atom->taud[ia][ik] += posd_in[iat][ik];
atom->dis[ia][ik] = posd_in[iat][ik];
}
iat++;
}
}
assert(iat == nat);
periodic_boundary_adjustment(atoms,lat.latvec,ntype);
bcast_atoms_tau(atoms, ntype);
}

void update_vel(const ModuleBase::Vector3<double>* vel_in,
const int ntype,
const int nat,
Atom* atoms)
{
int iat = 0;
for (int it = 0; it < ntype; ++it)
{
Atom* atom = &atoms[it];
for (int ia = 0; ia < atom->na; ++ia)
{
atoms[it].vel[ia] = vel_in[iat];
++iat;
}
}
assert(iat == nat);
}

void periodic_boundary_adjustment(Atom* atoms,
const ModuleBase::Matrix3& latvec,
const int ntype)
Expand Down
42 changes: 42 additions & 0 deletions source/module_cell/update_cell.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,48 @@ namespace unitcell
const int ntype,
const int nat,
Atom* atoms);

/**
* @brief update the position and tau of the atoms
*
* @param lat: the lattice of the atoms [in]
* @param pos_in: the position of the atoms in direct coordinate system [in]
* @param ntype: the number of types of the atoms [in]
* @param nat: the number of atoms [in]
* @param atoms: the atoms to be updated [out]
*/
void update_pos_taud(const Lattice& lat,
const double* posd_in,
const int ntype,
const int nat,
Atom* atoms);
/**
* @brief update the velocity of the atoms
*
* @param lat: the lattice of the atoms [in]
* @param pos_in: the position of the atoms in direct coordinate system
* in ModuleBase::Vector3 version [in]
* @param ntype: the number of types of the atoms [in]
* @param nat: the number of atoms [in]
* @param atoms: the atoms to be updated [out]
*/
void update_pos_taud(const Lattice& lat,
const ModuleBase::Vector3<double>* posd_in,
const int ntype,
const int nat,
Atom* atoms);
/**
* @brief update the velocity of the atoms
*
* @param vel_in: the velocity of the atoms [in]
* @param ntype: the number of types of the atoms [in]
* @param nat: the number of atoms [in]
* @param atoms: the atoms to be updated [out]
*/
void update_vel(const ModuleBase::Vector3<double>* vel_in,
const int ntype,
const int nat,
Atom* atoms);
}
//
#endif // UPDATE_CELL_H
5 changes: 2 additions & 3 deletions source/module_md/md_base.cpp
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
#include "md_base.h"

#include "md_func.h"
#ifdef __MPI
#include "mpi.h"
#endif
#include "module_io/print_info.h"

#include "module_cell/update_cell.h"
MD_base::MD_base(const Parameter& param_in, UnitCell& unit_in) : mdp(param_in.mdp), ucell(unit_in)
{
my_rank = param_in.globalv.myrank;
Expand Down Expand Up @@ -112,7 +111,7 @@ void MD_base::update_pos()
MPI_Bcast(pos, ucell.nat * 3, MPI_DOUBLE, 0, MPI_COMM_WORLD);
#endif

ucell.update_pos_taud(pos);
unitcell::update_pos_taud(ucell.lat,pos,ucell.ntype,ucell.nat,ucell.atoms);

return;
}
Expand Down
4 changes: 2 additions & 2 deletions source/module_md/run_md.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
#include "msst.h"
#include "nhchain.h"
#include "verlet.h"

#include "module_cell/update_cell.h"
namespace Run_MD
{

Expand Down Expand Up @@ -97,7 +97,7 @@ void md_line(UnitCell& unit_in, ModuleESolver::ESolver* p_esolver, const Paramet

if ((mdrun->step_ + mdrun->step_rst_) % param_in.mdp.md_restartfreq == 0)
{
unit_in.update_vel(mdrun->vel);
unitcell::update_vel(mdrun->vel,unit_in.ntype,unit_in.nat,unit_in.atoms);
std::stringstream file;
file << PARAM.globalv.global_stru_dir << "STRU_MD_" << mdrun->step_ + mdrun->step_rst_;
// changelog 20240509
Expand Down
2 changes: 1 addition & 1 deletion source/module_relax/relax_new/relax.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -633,7 +633,7 @@ void Relax::move_cell_ions(UnitCell& ucell, const bool is_new_dir)
ucell.symm.symmetrize_vec3_nat(move_ion);
}

ucell.update_pos_taud(move_ion);
unitcell::update_pos_taud(ucell.lat,move_ion,ucell.ntype,ucell.nat,ucell.atoms);

// Print the structure file.
ucell.print_tau();
Expand Down
Loading
Loading