Skip to content
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 111 additions & 1 deletion source/module_cell/bcast_cell.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
#include "unitcell.h"

#include "module_base/parallel_common.h"
#include "module_parameter/parameter.h"
#ifdef __EXX
#include "module_ri/serialization_cereal.h"
#include "module_hamilt_pw/hamilt_pwdft/global.h"
#endif
namespace unitcell
{
void bcast_atoms_tau(Atom* atoms,
Expand All @@ -12,4 +17,109 @@ namespace unitcell
}
#endif
}

void bcast_atoms_pseudo(Atom* atoms,
const int ntype)
{
#ifdef __MPI
MPI_Barrier(MPI_COMM_WORLD);
for (int i = 0; i < ntype; i++)
{
atoms[i].bcast_atom2();
}
#endif
}

void bcast_Lattice(Lattice& lat)
{
#ifdef __MPI
MPI_Barrier(MPI_COMM_WORLD);
// distribute lattice parameters.
ModuleBase::Matrix3& latvec = lat.latvec;
ModuleBase::Matrix3& latvec_supercell = lat.latvec_supercell;
Parallel_Common::bcast_string(lat.Coordinate);
Parallel_Common::bcast_double(lat.lat0);
Parallel_Common::bcast_double(lat.lat0_angstrom);
Parallel_Common::bcast_double(lat.tpiba);
Parallel_Common::bcast_double(lat.tpiba2);
Parallel_Common::bcast_double(lat.omega);
Parallel_Common::bcast_string(lat.latName);

// distribute lattice vectors.
Parallel_Common::bcast_double(latvec.e11);
Parallel_Common::bcast_double(latvec.e12);
Parallel_Common::bcast_double(latvec.e13);
Parallel_Common::bcast_double(latvec.e21);
Parallel_Common::bcast_double(latvec.e22);
Parallel_Common::bcast_double(latvec.e23);
Parallel_Common::bcast_double(latvec.e31);
Parallel_Common::bcast_double(latvec.e32);
Parallel_Common::bcast_double(latvec.e33);

// distribute lattice vectors.
for (int i = 0; i < 3; i++)
{
Parallel_Common::bcast_double(lat.a1[i]);
Parallel_Common::bcast_double(lat.a2[i]);
Parallel_Common::bcast_double(lat.a3[i]);
Parallel_Common::bcast_double(lat.latcenter[i]);
Parallel_Common::bcast_int(lat.lc[i]);
}

// distribute superlattice vectors.
Parallel_Common::bcast_double(latvec_supercell.e11);
Parallel_Common::bcast_double(latvec_supercell.e12);
Parallel_Common::bcast_double(latvec_supercell.e13);
Parallel_Common::bcast_double(latvec_supercell.e21);
Parallel_Common::bcast_double(latvec_supercell.e22);
Parallel_Common::bcast_double(latvec_supercell.e23);
Parallel_Common::bcast_double(latvec_supercell.e31);
Parallel_Common::bcast_double(latvec_supercell.e32);
Parallel_Common::bcast_double(latvec_supercell.e33);

// distribute Change the lattice vectors or not
#endif
}

void bcast_magnetism(Magnetism& magnet, const int ntype)
{
#ifdef __MPI
MPI_Barrier(MPI_COMM_WORLD);
Parallel_Common::bcast_double(magnet.start_magnetization, ntype);
if (PARAM.inp.nspin == 4)
{
Parallel_Common::bcast_double(magnet.ux_[0]);
Parallel_Common::bcast_double(magnet.ux_[1]);
Parallel_Common::bcast_double(magnet.ux_[2]);
}
#endif
}

void bcast_unitcell(UnitCell& ucell)
{
#ifdef __MPI
const int ntype = ucell.ntype;
Parallel_Common::bcast_int(ucell.nat);

bcast_Lattice(ucell.lat);
bcast_magnetism(ucell.magnet,ntype);
bcast_atoms_tau(ucell.atoms,ntype);

if(ucell.orbital_fn == nullptr)
{
ucell.orbital_fn = new std::string[ntype];
}
for (int i = 0; i < ntype; i++)
{
Parallel_Common::bcast_string(ucell.orbital_fn[i]);
}

#ifdef __EXX
ModuleBase::bcast_data_cereal(GlobalC::exx_info.info_ri.files_abfs,
MPI_COMM_WORLD,
0);
#endif
return;
#endif
}
}
40 changes: 40 additions & 0 deletions source/module_cell/bcast_cell.h
Original file line number Diff line number Diff line change
@@ -1,10 +1,50 @@
#ifndef BCAST_CELL_H
#define BCAST_CELL_H

#include "module_cell/unitcell.h"
namespace unitcell
{
/**
* @brief broadcast the tau array of the atoms
*
* @param atoms: the atoms to be broadcasted [in/out]
* @param ntype: the number of types of the atoms [in]
*/
void bcast_atoms_tau(Atom* atoms,
const int ntype);

/**
* @brief broadcast the pseduo of the atoms
*
* @param atoms: the atoms to be broadcasted [in/out]
* @param ntype: the number of types of the atoms [in]
*/
void bcast_atoms_pseudo(Atom* atoms,
const int ntype);
/**
* @brief broadcast the lattice
*
* @param lat: the lattice to be broadcasted [in/out]
*/
void bcast_Lattice(Lattice& lat);

/**
* @brief broadcast the magnetism
*
* @param magnet: the magnetism to be broadcasted [in/out]
* @param nytpe: the number of types of the atoms [in]
*/
void bcast_magnetism(Magnetism& magnet,
const int ntype);

/**
* @brief broadcast the unitcell
*
* @param ucell: the unitcell to be broadcasted [in/out]
*/
void bcast_unitcell(UnitCell& ucell);


}

#endif // BCAST_CELL_H
4 changes: 0 additions & 4 deletions source/module_cell/test/support/mock_unitcell.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,6 @@ void UnitCell::bcast_atoms_tau() {}
bool UnitCell::judge_big_cell() const { return true; }
void UnitCell::update_stress(ModuleBase::matrix& scs) {}
void UnitCell::update_force(ModuleBase::matrix& fcs) {}
#ifdef __MPI
void UnitCell::bcast_unitcell() {}
void UnitCell::bcast_unitcell2() {}
#endif
void UnitCell::set_iat2itia() {}
void UnitCell::setup_cell(const std::string& fn, std::ofstream& log) {}
void UnitCell::read_orb_file(int it,
Expand Down
44 changes: 32 additions & 12 deletions source/module_cell/test/unitcell_test_para.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,23 +94,28 @@ class UcellTest : public ::testing::Test
};

#ifdef __MPI
TEST_F(UcellTest, BcastUnitcell2)

TEST_F(UcellTest, BcastUnitcell)
{
elecstate::read_cell_pseudopots(pp_dir, ofs, *ucell);
ucell->bcast_unitcell2();
PARAM.input.nspin = 4;
unitcell::bcast_unitcell(*ucell);
if (GlobalV::MY_RANK != 0)
{
EXPECT_EQ(ucell->atoms[0].ncpp.nbeta, 4);
EXPECT_EQ(ucell->atoms[0].ncpp.nchi, 2);
EXPECT_EQ(ucell->atoms[1].ncpp.nbeta, 3);
EXPECT_EQ(ucell->atoms[1].ncpp.nchi, 1);
EXPECT_EQ(ucell->Coordinate, "Direct");
EXPECT_DOUBLE_EQ(ucell->a1.x, 10.0);
EXPECT_EQ(ucell->atoms[0].na, 1);
EXPECT_EQ(ucell->atoms[1].na, 2);
/// this is to ensure all processes have the atom label info
auto atom_labels = ucell->get_atomLabels();
std::string atom_type1_expected = "C";
std::string atom_type2_expected = "H";
EXPECT_EQ(atom_labels[0], atom_type1_expected);
EXPECT_EQ(atom_labels[1], atom_type2_expected);
}
}

TEST_F(UcellTest, BcastUnitcell)
TEST_F(UcellTest, BcastLattice)
{
PARAM.input.nspin = 4;
ucell->bcast_unitcell();
unitcell::bcast_Lattice(ucell->lat);
if (GlobalV::MY_RANK != 0)
{
EXPECT_EQ(ucell->Coordinate, "Direct");
Expand All @@ -125,6 +130,22 @@ TEST_F(UcellTest, BcastUnitcell)
EXPECT_EQ(atom_labels[1], atom_type2_expected);
}
}

TEST_F(UcellTest, BcastMagnitism)
{
unitcell::bcast_magnetism(ucell->magnet, ucell->ntype);
PARAM.input.nspin = 4;
if (GlobalV::MY_RANK != 0)
{
EXPECT_DOUBLE_EQ(ucell->magnet.start_magnetization[0], 0.0);
EXPECT_DOUBLE_EQ(ucell->magnet.start_magnetization[1], 0.0);
for (int i = 0; i < 3; ++i)
{
EXPECT_DOUBLE_EQ(ucell->magnet.ux_[i], 0.0);
}
}
}

TEST_F(UcellTest, UpdatePosTau)
{
double* pos_in = new double[ucell->nat * 3];
Expand Down Expand Up @@ -204,7 +225,6 @@ TEST_F(UcellTest, ReadPseudo)
EXPECT_EQ(error2, 0);
}
// read_cell_pseudopots
// bcast_unitcell2
EXPECT_FALSE(ucell->atoms[0].ncpp.has_so);
EXPECT_FALSE(ucell->atoms[1].ncpp.has_so);
EXPECT_EQ(ucell->atoms[0].ncpp.nbeta, 4);
Expand Down
Loading
Loading