Skip to content

Commit 9920ccd

Browse files
wenfei-liwenfei-li
andauthored
Test : UT for relax_new (#2093)
* Test : brent line search UT * test : prepare for relax * modify brent * add some comments to line search * change back line_search * update test of line search * Test : UT for setup_gradient * Test : UT for relax_step --------- Co-authored-by: wenfei-li <[email protected]>
1 parent aa8358c commit 9920ccd

File tree

12 files changed

+1908
-52
lines changed

12 files changed

+1908
-52
lines changed

source/module_relax/relax_driver.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ void Relax_Driver::relax_driver(ModuleESolver::ESolver *p_esolver)
1616
}
1717
else
1818
{
19-
rl.init_relax(GlobalC::ucell.nat);
19+
rl.init_relax(GlobalC::ucell.nat, Lattice_Change_Basic::out_stru);
2020
}
2121
}
2222

source/module_relax/relax_new/line_search.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -276,15 +276,16 @@ bool Line_Search::brent(
276276

277277
xnew = xb;
278278
if(std::abs(dy)<conv_thr) return true;
279-
if(ls_step == 4)
279+
if(ls_step == 4) //I'm not sure if this is a good choice, but the idea is there should not be so many line search steps
280+
//I feel if the line search does not converge, we'd better change the direction and restart line search
280281
{
281-
GlobalV::ofs_running << "Too much Brent steps, let's do next CG step" << std::endl;
282+
GlobalV::ofs_running << "Too many Brent steps, let's do next CG step" << std::endl;
282283
return true;
283284
//ModuleBase::WARNING_QUIT("Brent","too many steps in line search, something wrong");
284285
}
285286

286287
return false;
287-
}//end ibrack
288+
}//end bracked
288289
else
289290
{
290291
if(!( (xa<=xb && xb<=xc) || (xc<=xb && xb<=xa) ))
@@ -380,7 +381,7 @@ bool Line_Search::brent(
380381
if(std::abs(dy)<conv_thr) return true;
381382
if(ls_step == 4)
382383
{
383-
GlobalV::ofs_running << "Too much Brent steps, let's do next CG step" << std::endl;
384+
GlobalV::ofs_running << "Too many Brent steps, let's do next CG step" << std::endl;
384385
return true;
385386
//ModuleBase::WARNING_QUIT("Brent","too many steps in line search, something wrong");
386387
}

source/module_relax/relax_new/relax.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,13 @@
77
#include "module_base/tool_title.h"
88
#include "module_base/parallel_common.h"
99

10-
void Relax::init_relax(const int nat_in)
10+
void Relax::init_relax(const int nat_in, const int out_stru_in)
1111
{
1212
ModuleBase::TITLE("Relax","init_relax");
1313

1414
//set some initial conditions / constants
1515
nat = nat_in;
16+
out_stru = out_stru_in;
1617
istep = 0;
1718
cg_step = 0;
1819
ltrial = false;
@@ -594,7 +595,7 @@ void Relax::move_cell_ions(const bool is_new_dir)
594595
// =================================================================
595596
std::stringstream ss;
596597
ss << GlobalV::global_out_dir << "STRU_ION";
597-
if (Lattice_Change_Basic::out_stru == 1)
598+
if (out_stru == 1)
598599
{
599600
ss << istep;
600601
istep ++;

source/module_relax/relax_new/relax.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ class Relax
1515
~Relax(){};
1616

1717
//prepare for relaxation
18-
void init_relax(const int nat);
18+
void init_relax(const int nat_in, const int out_stru_in);
1919
//perform a single relaxation step
2020
bool relax_step(const ModuleBase::matrix& force, const ModuleBase::matrix &stress, const double etot_in);
2121

@@ -42,6 +42,7 @@ class Relax
4242
void move_cell_ions(const bool is_new_dir);
4343

4444
int nat; // number of atoms
45+
int out_stru;
4546
bool ltrial; // if last step is trial step
4647

4748
double step_size;

source/module_relax/relax_new/test/CMakeLists.txt

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,19 @@ remove_definitions(-D__LCAO)
33
remove_definitions(-D__DEEPKS)
44
remove_definitions(-D__CUDA)
55
remove_definitions(-D__ROCM)
6+
7+
install(DIRECTORY support DESTINATION ${CMAKE_CURRENT_BINARY_DIR})
8+
69
AddTest(
710
TARGET relax_new_line_search
8-
SOURCES line_search_test.cpp ../line_search.cpp ../../../module_base/tool_quit.cpp ../../../module_base/global_variable.cpp ../../../module_base/global_file.cpp ../../../module_base/memory.cpp ../../../module_base/timer.cpp
11+
SOURCES line_search_test.cpp ../line_search.cpp ../../../module_base/global_variable.cpp ../../../module_base/global_file.cpp ../../../module_base/memory.cpp ../../../module_base/timer.cpp
912
)
13+
14+
AddTest(
15+
TARGET relax_new_relax
16+
SOURCES relax_test.cpp ../relax.cpp ../line_search.cpp ../../../module_base/tool_quit.cpp ../../../module_base/global_variable.cpp ../../../module_base/global_file.cpp ../../../module_base/memory.cpp ../../../module_base/timer.cpp
17+
../../../module_base/matrix3.cpp ../../../module_base/intarray.cpp ../../../module_base/tool_title.cpp ../../../module_io/input.cpp
18+
../../../module_base/global_function.cpp ../../../module_base/complexmatrix.cpp ../../../module_base/matrix.cpp
19+
../../../module_base/complexarray.cpp
20+
LIBS ${math_libs}
21+
)
Lines changed: 169 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#include "gtest/gtest.h"
2-
2+
#include "../line_search.h"
33
/************************************************
44
* unit test of class Line_Search
55
***********************************************/
@@ -13,57 +13,183 @@
1313
* - Line_Search::update_brent()
1414
* - Line_Search::brent()
1515
*/
16+
namespace ModuleBase
17+
{
18+
void WARNING_QUIT(const std::string &file,const std::string &description) {return;}
19+
}
20+
21+
// test of all cases of third_order()
22+
class Test_TO : public testing::Test
23+
{
24+
protected:
25+
std::vector<double> xnew_arr;
1626

27+
void SetUp()
28+
{
29+
Line_Search ls;
1730

18-
#define private public
19-
#include "module_relax/relax_new/line_search.h"
31+
double x,xnew,y,f;
2032

21-
class LineSearchPrepare
22-
{
23-
public:
24-
LineSearchPrepare(
25-
bool restart_in,
26-
double x_in,
27-
double y_in,
28-
double f_in,
29-
double xnew_in,
30-
double conv_thr_in):
31-
restart(restart_in),
32-
x(x_in),
33-
y(y_in),
34-
f(f_in),
35-
xnew(xnew_in),
36-
conv_thr(conv_thr_in)
37-
{}
38-
bool restart;
39-
double x;
40-
double y;
41-
double f;
42-
double xnew;
43-
double conv_thr;
44-
};
33+
// 3rd order, harmonic, dmove > 0
34+
x=0; y=0; f=-2;
35+
ls.line_search(1,x,y,f,xnew,1e-10);
36+
xnew_arr.push_back(xnew);
37+
38+
x=1; y=0; f=2;
39+
ls.line_search(0,x,y,f,xnew,1e-10);
40+
xnew_arr.push_back(xnew);
41+
42+
// 3rd order, harmonic, dmove < 0
43+
x=0; y=0; f=1;
44+
ls.line_search(1,x,y,f,xnew,1e-10);
45+
xnew_arr.push_back(xnew);
46+
47+
x=1; y=2; f=3;
48+
ls.line_search(0,x,y,f,xnew,1e-10);
49+
xnew_arr.push_back(xnew);
50+
51+
// 3rd order, harmonic, dmove > 4
52+
x=0; y=0; f=-20;
53+
ls.line_search(1,x,y,f,xnew,1e-10);
54+
xnew_arr.push_back(xnew);
55+
56+
x=1; y=-9; f=-18;
57+
ls.line_search(0,x,y,f,xnew,1e-10);
58+
xnew_arr.push_back(xnew);
59+
60+
// 3rd order, anharmonic, use dmove1 & dmove
61+
x=0; y=0; f=-1;
62+
ls.line_search(1,x,y,f,xnew,1e-10);
63+
xnew_arr.push_back(xnew);
64+
65+
x=1; y=2; f=3.5;
66+
ls.line_search(0,x,y,f,xnew,1e-10);
67+
xnew_arr.push_back(xnew);
68+
69+
// 3rd order, anharmonic, use dmove2 & dmoveh
70+
x=0; y=0; f=4.5;
71+
ls.line_search(1,x,y,f,xnew,1e-10);
72+
xnew_arr.push_back(xnew);
4573

74+
x=1; y=2; f=-1;
75+
ls.line_search(0,x,y,f,xnew,1e-10);
76+
xnew_arr.push_back(xnew);
4677

47-
class LineSearchTest : public ::testing::TestWithParam<LineSearchPrepare>{};
78+
// 3rd order, anharmonic, dmoveh > 4
79+
x=0; y=0; f=-20;
80+
ls.line_search(1,x,y,f,xnew,1e-10);
81+
xnew_arr.push_back(xnew);
4882

49-
TEST_P(LineSearchTest,LineSearch)
83+
x=1; y=-5; f=-18;
84+
ls.line_search(0,x,y,f,xnew,1e-10);
85+
xnew_arr.push_back(xnew);
86+
}
87+
};
88+
89+
TEST_F(Test_TO, line_search)
5090
{
51-
LineSearchPrepare lsp = GetParam();
52-
Line_Search ls;
53-
EXPECT_EQ(ls.ls_step,0);
54-
bool test_conv = ls.line_search(lsp.restart,lsp.x,lsp.y,lsp.f,lsp.xnew,lsp.conv_thr);
55-
EXPECT_EQ(ls.ls_step,1);
56-
while (!test_conv)
91+
std::vector<double> xnew_arr_ref =
92+
{1,0.5,1,4,1,4,1,0.1180828963,1,0.8181818182,1,4};
93+
94+
for(int i=0;i<xnew_arr.size();i++)
5795
{
58-
test_conv = ls.line_search(false,lsp.x,lsp.y,lsp.f,lsp.xnew,lsp.conv_thr);
96+
EXPECT_NEAR(xnew_arr[i],xnew_arr_ref[i],1e-8);
5997
}
60-
EXPECT_GE(ls.ls_step,3);
61-
EXPECT_TRUE(test_conv);
6298
}
6399

64-
INSTANTIATE_TEST_SUITE_P(LineSearch,LineSearchTest,::testing::Values(
65-
// restart, x, y, f, xnew, conv_thr
66-
LineSearchPrepare(true,5.0,100.1,-10.0,0.0,1e-6)
67-
));
100+
class Test_LS : public testing::Test
101+
{
102+
protected:
103+
std::vector<double> xnew_arr;
104+
105+
void SetUp()
106+
{
107+
Line_Search ls;
108+
109+
double x=0,xnew=0;
110+
111+
// here are 150 random numbers as values of x,y,f
112+
// try to catch as many cases as possible
113+
std::vector<double> rand=
114+
{
115+
7.84824536,9.64663789,3.02069873,
116+
11.26507127,0.57183349,13.78382791,
117+
11.35073038,13.93731114,17.79133148,
118+
8.21541546,14.93677637,18.98757025,
119+
3.9381696,16.94282377,15.3640563,
120+
7.42831671,17.3805471,8.80330788,
121+
0.99259854,6.1560873,2.81682224,
122+
8.53026114,6.09436657,11.7394269,
123+
12.90193598,5.89789934,8.64853882,
124+
4.73933146,3.36094493,14.25637065,
125+
14.60564235,13.66375342,4.4094395,
126+
13.24080597,6.94100136,6.74562183,
127+
4.93984237,3.28646424,12.95371468,
128+
2.28366581,14.10115138,17.89334777,
129+
17.85486276,11.66796459,3.83858509,
130+
17.67856229,14.09360901,6.32802805,
131+
9.81252062,13.00909712,18.59789075,
132+
2.20899305,17.91715464,17.65862859,
133+
0.49912591,15.47188694,3.74374207,
134+
14.36382216,3.95503455,9.02487212,
135+
4.93351326,8.21473905,7.10799664,
136+
6.13204507,1.85529643,5.98573821,
137+
8.77475213,4.84374422,8.91737412,
138+
12.26443836,16.89717106,19.53642954,
139+
4.69604119,12.24565644,14.7098488,
140+
12.98241239,8.19982363,19.52610301,
141+
6.18040644,9.14778002,0.3568356,
142+
7.09769476,19.18409854,15.39537538,
143+
8.41229496,12.21441438,17.09138446,
144+
10.10983535,10.93238563,13.22155721,
145+
8.86644395,12.70139813,10.73434842,
146+
11.82537801,2.14964817,14.59896678,
147+
15.41267417,2.81137988,3.71229432,
148+
1.83056217,7.41129002,1.37004508,
149+
17.76889137,1.89905894,14.16419654,
150+
4.33564906,9.08449031,3.62527869,
151+
12.79424987,18.55812915,3.74804462,
152+
2.75107002,2.60441846,7.82141378,
153+
4.16003289,9.68131555,12.5484824,
154+
9.68999296,9.7243542,18.39408735,
155+
19.00070811,11.87098669,3.78418022,
156+
6.1277818,10.78841329,5.24128999,
157+
16.13326588,3.58595401,5.21059945,
158+
19.06084549,0.5703917,9.66890009,
159+
0.73642789,2.74387619,5.27688081,
160+
14.07082298,9.87873532,10.62883574,
161+
10.89594589,5.06857927,9.08512489,
162+
8.75285552,12.71893686,9.90408215,
163+
2.96173193,0.8821268,2.75896956,
164+
18.05164767,17.20957441,8.86361389
165+
};
68166

69-
#undef private
167+
int ind=0;
168+
for(int i=0;i<50;i++)
169+
{
170+
bool restart = false;
171+
if(i%10==0) restart = true;
172+
173+
ls.line_search(restart,rand[ind],rand[ind+1],rand[ind+2],xnew,1e-10);
174+
ind = ind+3;
175+
176+
xnew_arr.push_back(xnew);
177+
}
178+
}
179+
};
180+
181+
TEST_F(Test_LS, line_search)
182+
{
183+
std::vector<double> xnew_arr_ref = {
184+
8.84824536,11.84824536,11.5220486,1.94478562,-4.61632212,14.40861093,-11.8788378,23.60558634,21.64528566,-11.58587758,
185+
15.60564235,18.60564235,-11.66208483,-3.02868731,22.1076108,17.32596135,-5.91956272,-12.99806209,0.03909330126,42.09321466,
186+
5.93351326,8.93351326,14.06016625,19.24381082,-10.44075315,29.55515479,-7.42360546,8.9322714,11.04149536,13.50491613,
187+
9.86644395,12.86644395,16.63592206,-25.33366183,49.64554977,-0.2852472955,29.71145149,-17.33528968,6.97795863,20.7499131,
188+
20.00070811,23.00070811,36.14423404,24.91600471,-35.91240731,40.73961316,4.54619171,4.46667478,-8.62051525,48.23147915
189+
};
190+
191+
for(int i=0;i<xnew_arr.size();i++)
192+
{
193+
EXPECT_NEAR(xnew_arr[i],xnew_arr_ref[i],1e-8);
194+
}
195+
}

0 commit comments

Comments
 (0)