Skip to content

Commit e93d3a7

Browse files
authored
perf the compute of check_atom_stru and add openmp (#5962)
* first version * change add * add the change * change * change gtest * add change * change the file * add format * modify back test * add the change in the check_atom_stru * move openmp * add globlv for the question * delete vector * update bug
1 parent c8204a8 commit e93d3a7

File tree

2 files changed

+156
-140
lines changed

2 files changed

+156
-140
lines changed

source/module_cell/check_atomic_stru.cpp

Lines changed: 139 additions & 131 deletions
Original file line numberDiff line numberDiff line change
@@ -1,157 +1,165 @@
11
#include "check_atomic_stru.h"
22

33
#include "module_base/element_covalent_radius.h"
4+
#include "module_base/timer.h"
45

5-
void Check_Atomic_Stru::check_atomic_stru(UnitCell& ucell, const double& factor) {
6+
void Check_Atomic_Stru::check_atomic_stru(UnitCell& ucell, const double& factor)
7+
{
68
// First we calculate all bond length in the structure,
79
// and compare with the covalent_bond_length,
810
// if there has bond length is shorter than covalent_bond_length * factor,
911
// we think this structure is unreasonable.
10-
const double warning_coef = 0.6;
1112
assert(ucell.ntype > 0);
12-
std::stringstream errorlog;
1313
bool all_pass = true;
1414
bool no_warning = true;
15-
for (int it1 = 0; it1 < ucell.ntype; it1++) {
16-
std::string symbol1 = "";
17-
for (char ch: ucell.atoms[it1].label) {
18-
if (std::isalpha(ch)) {
19-
symbol1.push_back(ch);
20-
}
21-
}
22-
// std::string symbol1 = ucell.atoms[it1].label;
23-
double symbol1_covalent_radius;
24-
if (ModuleBase::CovalentRadius.find(symbol1)
25-
!= ModuleBase::CovalentRadius.end()) {
26-
symbol1_covalent_radius = ModuleBase::CovalentRadius.at(symbol1);
27-
} else {
28-
std::stringstream mess;
29-
mess << "Notice: symbol '" << symbol1
30-
<< "' is not an element symbol!!!! ";
31-
mess << "set the covalent radius to be 0." << std::endl;
32-
GlobalV::ofs_running << mess.str();
33-
std::cout << mess.str();
34-
symbol1_covalent_radius = 0.0;
35-
}
15+
std::stringstream errorlog;
16+
errorlog.setf(std::ios_base::fixed, std::ios_base::floatfield);
3617

37-
for (int ia1 = 0; ia1 < ucell.atoms[it1].na; ia1++) {
38-
double x1 = ucell.atoms[it1].taud[ia1].x;
39-
double y1 = ucell.atoms[it1].taud[ia1].y;
40-
double z1 = ucell.atoms[it1].taud[ia1].z;
18+
if (GlobalV::MY_RANK == 0)
19+
{
20+
ModuleBase::timer::tick("Check_Atomic_Stru", "Check_Atomic_Stru");
21+
22+
const int ntype = ucell.ntype;
23+
const double lat0 = ucell.lat0;
24+
const double warning_coef = 0.6;
25+
const double max_factor_coef = std::max(warning_coef, factor);
4126

42-
for (int it2 = 0; it2 < ucell.ntype; it2++) {
43-
std::string symbol2 = ucell.atoms[it2].label;
44-
double symbol2_covalent_radius;
45-
if (ModuleBase::CovalentRadius.find(symbol2)
46-
!= ModuleBase::CovalentRadius.end()) {
47-
symbol2_covalent_radius
48-
= ModuleBase::CovalentRadius.at(symbol2);
49-
} else {
50-
symbol2_covalent_radius = 0.0;
27+
std::vector<double> symbol_covalent_radiuss(ntype);
28+
for (int it = 0; it < ntype; it++)
29+
{
30+
std::string symbol1 = "";
31+
for (char ch: ucell.atoms[it].label)
32+
{
33+
if (std::isalpha(ch))
34+
{
35+
symbol1.push_back(ch);
5136
}
37+
}
5238

53-
double covalent_length
54-
= (symbol1_covalent_radius + symbol2_covalent_radius)
55-
/ ModuleBase::BOHR_TO_A;
56-
57-
for (int ia2 = 0; ia2 < ucell.atoms[it2].na; ia2++) {
58-
for (int a = -1; a < 2; a++) {
59-
for (int b = -1; b < 2; b++) {
60-
for (int c = -1; c < 2; c++) {
61-
if (it1 > it2) {
62-
continue;
63-
} else if (it1 == it2 && ia1 > ia2) {
64-
continue;
65-
} else if (it1 == it2 && ia1 == ia2 && a == 0
66-
&& b == 0 && c == 0) {
67-
continue;
68-
}
69-
70-
double x2 = ucell.atoms[it2].taud[ia2].x + a;
71-
double y2 = ucell.atoms[it2].taud[ia2].y + b;
72-
double z2 = ucell.atoms[it2].taud[ia2].z + c;
73-
74-
double bond_length
75-
= sqrt(pow((x2 - x1) * ucell.a1.x
76-
+ (y2 - y1) * ucell.a2.x
77-
+ (z2 - z1) * ucell.a3.x,
78-
2)
79-
+ pow((x2 - x1) * ucell.a1.y
80-
+ (y2 - y1) * ucell.a2.y
81-
+ (z2 - z1) * ucell.a3.y,
82-
2)
83-
+ pow((x2 - x1) * ucell.a1.z
84-
+ (y2 - y1) * ucell.a2.z
85-
+ (z2 - z1) * ucell.a3.z,
86-
2))
87-
* ucell.lat0;
88-
89-
if (bond_length < covalent_length * factor
90-
|| bond_length
91-
< covalent_length * warning_coef) {
92-
errorlog.setf(std::ios_base::fixed,
93-
std::ios_base::floatfield);
94-
errorlog << std::setw(3) << ia1 + 1
95-
<< "-th " << std::setw(3)
96-
<< ucell.atoms[it1].label << ", ";
97-
errorlog << std::setw(3) << ia2 + 1
98-
<< "-th " << std::setw(3)
99-
<< ucell.atoms[it2].label;
100-
errorlog << " (cell:" << std::setw(2) << a
101-
<< " " << std::setw(2) << b << " "
102-
<< std::setw(2) << c << ")";
103-
errorlog << ", distance= "
104-
<< std::setprecision(3)
105-
<< bond_length << " Bohr (";
106-
errorlog
107-
<< bond_length * ModuleBase::BOHR_TO_A
108-
<< " Angstrom)" << std::endl;
39+
if (ModuleBase::CovalentRadius.find(symbol1) != ModuleBase::CovalentRadius.end())
40+
{
41+
symbol_covalent_radiuss[it] = ModuleBase::CovalentRadius.at(symbol1);
42+
}
43+
else
44+
{
45+
std::stringstream mess;
46+
mess << "Notice: symbol '" << symbol1 << "' is not an element symbol!!!! ";
47+
mess << "set the covalent radius to be 0." << std::endl;
48+
GlobalV::ofs_running << mess.str();
49+
std::cout << mess.str();
50+
}
51+
}
52+
std::vector<double> latvec (9);
53+
latvec[0] = ucell.a1.x;
54+
latvec[1] = ucell.a2.x;
55+
latvec[2] = ucell.a3.x;
56+
latvec[3] = ucell.a1.y;
57+
latvec[4] = ucell.a2.y;
58+
latvec[5] = ucell.a3.y;
59+
latvec[6] = ucell.a1.z;
60+
latvec[7] = ucell.a2.z;
61+
latvec[8] = ucell.a3.z;
62+
std::vector<double> A(27*3);
63+
std::vector<std::string> cell(27);
64+
std::vector<std::string> label(ntype);
65+
for (int i = 0; i < 27; i++)
66+
{
67+
int a = (i / 9) % 3 - 1;
68+
int b = (i / 3) % 3 - 1;
69+
int c = i % 3 - 1;
70+
A[3 * i] = a * latvec[0] + b * latvec[1] + c * latvec[2];
71+
A[3 * i + 1] = a * latvec[3] + b * latvec[4] + c * latvec[5];
72+
A[3 * i + 2] = a * latvec[6] + b * latvec[7] + c * latvec[8];
73+
std::ostringstream tmp_oss;
74+
tmp_oss << " (cell:" << std::setw(2) << a << " " << std::setw(2) << b << " " << std::setw(2) << c
75+
<< "), distance= ";
76+
cell[i] = tmp_oss.str();
77+
}
78+
for (int it = 0; it < ntype; it++)
79+
{
80+
std::ostringstream tmp_oss;
81+
tmp_oss << std::setw(3) << ucell.atoms[it].label;
82+
label[it] = tmp_oss.str();
83+
}
10984

110-
if (bond_length
111-
< covalent_length * factor) {
112-
all_pass = false;
113-
} else {
114-
no_warning = false;
115-
}
85+
const double bohr_to_a = ModuleBase::BOHR_TO_A;
86+
#pragma omp parallel
87+
{
88+
std::vector<double> delta_lat(3);
89+
#pragma omp for schedule(dynamic)
90+
for (int iat = 0; iat < ucell.nat; iat++)
91+
{
92+
const int it1 = ucell.iat2it[iat];
93+
const int ia1 = ucell.iat2ia[iat];
94+
const double symbol1_covalent_radius = symbol_covalent_radiuss[it1];
95+
double x1 = ucell.atoms[it1].taud[ia1].x;
96+
double y1 = ucell.atoms[it1].taud[ia1].y;
97+
double z1 = ucell.atoms[it1].taud[ia1].z;
98+
for (int it2 = it1; it2 < ntype; it2++)
99+
{
100+
double symbol2_covalent_radius = symbol_covalent_radiuss[it2];
101+
double covalent_length = (symbol1_covalent_radius + symbol2_covalent_radius) / bohr_to_a;
102+
const double max_error = covalent_length * max_factor_coef / ucell.lat0;
103+
const double max_error_2 = max_error * max_error;
104+
const double factor_error = covalent_length * factor;
105+
for (int ia2 = ia1; ia2 < ucell.atoms[it2].na; ia2++)
106+
{
107+
const bool is_same_atom = (it1 == it2) && (ia1 == ia2);
108+
double delta_x = ucell.atoms[it2].taud[ia2].x - x1;
109+
double delta_y = ucell.atoms[it2].taud[ia2].y - y1;
110+
double delta_z = ucell.atoms[it2].taud[ia2].z - z1;
111+
delta_lat[0] = delta_x * latvec[0] + delta_y * latvec[1] + delta_z * latvec[2];
112+
delta_lat[1] = delta_x * latvec[3] + delta_y * latvec[4] + delta_z * latvec[5];
113+
delta_lat[2] = delta_x * latvec[6] + delta_y * latvec[7] + delta_z * latvec[8];
114+
for (int i = 0; i < 27; i++)
115+
{
116+
if ((is_same_atom) && (i == 13))
117+
continue;
118+
const int offset = i * 3;
119+
const double part1 = delta_lat[0] + A[offset];
120+
const double part2 = delta_lat[1] + A[offset + 1];
121+
const double part3 = delta_lat[2] + A[offset + 2];
122+
const double bond_length = part1 * part1 + part2 * part2 + part3 * part3;
123+
const bool flag = bond_length < max_error_2 ? true : false;
124+
if (flag)
125+
{
126+
const double sqrt_bon = sqrt(bond_length) * lat0;
127+
#pragma omp critical
128+
{
129+
no_warning = false;
130+
all_pass = all_pass && (sqrt_bon < factor_error ? false : true);
131+
errorlog << std::setw(3) << ia1 + 1 << "-th " << label[it1] << ", " << std::setw(3)
132+
<< ia2 + 1 << "-th " << label[it2] << cell[i] << std::setprecision(3)
133+
<< sqrt_bon << " Bohr (" << sqrt_bon * bohr_to_a << " Angstrom)\n";
116134
}
117-
} // c
118-
} // b
119-
} // a
120-
} // ia2
121-
} // it2
122-
} // ia1
123-
} // it1
124-
125-
if (!all_pass || !no_warning) {
135+
}
136+
}
137+
} // ia2
138+
} // it2
139+
} // iat
140+
}
141+
ModuleBase::timer::tick("Check_Atomic_Stru", "Check_Atomic_Stru");
142+
}
143+
if (!all_pass || !no_warning)
144+
{
126145
std::stringstream mess;
127-
mess << "\n%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%"
128-
<< std::endl;
129-
mess << "%%%%%% WARNING WARNING WARNING WARNING WARNING %%%%%%"
130-
<< std::endl;
131-
mess << "%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%"
132-
<< std::endl;
146+
mess << "\n%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%" << std::endl;
147+
mess << "%%%%%% WARNING WARNING WARNING WARNING WARNING %%%%%%" << std::endl;
148+
mess << "%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%" << std::endl;
133149
mess << "!!! WARNING: Some atoms are too close!!!" << std::endl;
134-
mess << "!!! Please check the nearest-neighbor list in log file."
135-
<< std::endl;
136-
mess << "%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%"
137-
<< std::endl;
138-
mess << "%%%%%% WARNING WARNING WARNING WARNING WARNING %%%%%%"
139-
<< std::endl;
140-
mess << "%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%"
141-
<< std::endl;
150+
mess << "!!! Please check the nearest-neighbor list in log file." << std::endl;
151+
mess << "%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%" << std::endl;
152+
mess << "%%%%%% WARNING WARNING WARNING WARNING WARNING %%%%%%" << std::endl;
153+
mess << "%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%" << std::endl;
142154

143-
GlobalV::ofs_running << mess.str() << mess.str() << mess.str()
144-
<< errorlog.str();
155+
GlobalV::ofs_running << mess.str() << mess.str() << mess.str() << errorlog.str();
145156
std::cout << mess.str() << mess.str() << mess.str() << std::endl;
146-
147-
if (!all_pass) {
157+
if (!all_pass)
158+
{
148159
mess.clear();
149160
mess.str("");
150-
mess << "If this structure is what you want, you can set "
151-
"'min_dist_coef'"
152-
<< std::endl;
153-
mess << "as a smaller value (the current value is " << factor
154-
<< ") in INPUT file." << std::endl;
161+
mess << "If this structure is what you want, you can set 'min_dist_coef'\n";
162+
mess << "as a smaller value (the current value is " << factor << ") in INPUT file." << std::endl;
155163
GlobalV::ofs_running << mess.str();
156164
std::cout << mess.str();
157165
ModuleBase::WARNING_QUIT("Input", "The structure is unreasonable!");

source/module_cell/test/unitcell_test_readpp.cpp

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include "module_elecstate/read_pseudo.h"
1212
#include <valarray>
1313
#include <vector>
14+
#include "string.h"
1415
#ifdef __MPI
1516
#include "mpi.h"
1617
#endif
@@ -341,38 +342,45 @@ TEST_F(UcellDeathTest, CheckStructure) {
341342
EXPECT_FALSE(ucell->atoms[0].ncpp.has_so);
342343
EXPECT_FALSE(ucell->atoms[1].ncpp.has_so);
343344
// trial 1
345+
344346
testing::internal::CaptureStdout();
345347
double factor = 0.2;
348+
ucell->set_iat2itia();
346349
EXPECT_NO_THROW(Check_Atomic_Stru::check_atomic_stru(*ucell, factor));
347350
output = testing::internal::GetCapturedStdout();
348-
EXPECT_THAT(output,
349-
testing::HasSubstr("WARNING: Some atoms are too close!!!"));
351+
EXPECT_THAT(output,testing::HasSubstr("WARNING: Some atoms are too close!!!"));
350352
// trial 2
351-
testing::internal::CaptureStdout();
353+
GlobalV::ofs_running.open("CheckStructure2.txt");
354+
::testing::FLAGS_gtest_death_test_style = "threadsafe";
352355
factor = 0.4;
353356
EXPECT_EXIT(Check_Atomic_Stru::check_atomic_stru(*ucell, factor),
354357
::testing::ExitedWithCode(1),
355358
"");
356-
output = testing::internal::GetCapturedStdout();
359+
std::ifstream ifs("CheckStructure2.txt");
360+
if (ifs.is_open())
361+
{
362+
std::string line;
363+
while (std::getline(ifs, line)) {
364+
output+=line;
365+
}
366+
}
357367
EXPECT_THAT(output, testing::HasSubstr("The structure is unreasonable!"));
368+
GlobalV::ofs_running.open("running.log");
358369
// trial 3
359370
ucell->atoms[0].label = "arbitrary";
360371
testing::internal::CaptureStdout();
361372
factor = 0.2;
362373
EXPECT_NO_THROW(Check_Atomic_Stru::check_atomic_stru(*ucell, factor));
363374
output = testing::internal::GetCapturedStdout();
364-
EXPECT_THAT(
365-
output,
366-
testing::HasSubstr("Notice: symbol 'arbitrary' is not an element "
375+
EXPECT_THAT(output,testing::HasSubstr("Notice: symbol 'arbitrary' is not an element "
367376
"symbol!!!! set the covalent radius to be 0."));
368377
// trial 4
369378
ucell->atoms[0].label = "Fe1";
370379
testing::internal::CaptureStdout();
371380
factor = 0.2;
372381
EXPECT_NO_THROW(Check_Atomic_Stru::check_atomic_stru(*ucell, factor));
373382
output = testing::internal::GetCapturedStdout();
374-
EXPECT_THAT(output,
375-
testing::HasSubstr("WARNING: Some atoms are too close!!!"));
383+
EXPECT_THAT(output,testing::HasSubstr("WARNING: Some atoms are too close!!!"));
376384
}
377385

378386
TEST_F(UcellDeathTest, ReadPseudoWarning1) {

0 commit comments

Comments
 (0)