Skip to content

Commit a4f343e

Browse files
authored
Merge pull request #773 from pxlxingliang/develop
Modify: modify Gram schmidt to be CGS3 algorithm and add UT
2 parents ac9e050 + 200daf0 commit a4f343e

File tree

4 files changed

+157
-7
lines changed

4 files changed

+157
-7
lines changed

source/module_base/gram_schmidt_orth-inl.h

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -43,15 +43,21 @@ std::vector<std::vector<Func_Type>> Gram_Schmidt_Orth<Func_Type,R_Type>::cal_ort
4343

4444
for( size_t if1=0; if1!=func.size(); ++if1 )
4545
{
46+
//use CGS2 algorithm to do twice orthogonalization
47+
//DOI 10.1007/s00211-005-0615-4
4648
std::vector<Func_Type> func_try = func[if1];
47-
for( size_t if_minus=0; if_minus!=func_new.size(); ++if_minus )
49+
for(int niter=0;niter<3;niter++)
4850
{
49-
// (hn,ei)
50-
const std::vector<Func_Type> && mul_func = Mathzone::Pointwise_Product( func[if1], func_new[if_minus] );
51-
const Func_Type in_product = cal_norm(mul_func);
51+
std::vector<Func_Type> func_tmp = func_try;
52+
for( size_t if_minus=0; if_minus!=func_new.size(); ++if_minus )
53+
{
54+
// (hn,ei)
55+
const std::vector<Func_Type> && mul_func = Mathzone::Pointwise_Product( func_tmp, func_new[if_minus] );
56+
const Func_Type in_product = cal_norm(mul_func);
5257

53-
// hn - (hn,ei)ei
54-
BlasConnector::axpy( mul_func.size(), -in_product, ModuleBase::GlobalFunc::VECTOR_TO_PTR(func_new[if_minus]), 1, ModuleBase::GlobalFunc::VECTOR_TO_PTR(func_try), 1);
58+
// hn - (hn,ei)ei
59+
BlasConnector::axpy( mul_func.size(), -in_product, ModuleBase::GlobalFunc::VECTOR_TO_PTR(func_new[if_minus]), 1, ModuleBase::GlobalFunc::VECTOR_TO_PTR(func_try), 1);
60+
}
5561
}
5662

5763
// ||gn||

source/module_base/gram_schmidt_orth.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
#define GRAM_SCHMIDT_ORTH_H
88

99
#include<limits>
10-
10+
#include<vector>
1111
namespace ModuleBase
1212
{
1313

source/module_base/test/CMakeLists.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,11 @@ AddTest(
7777
LIBS ${math_libs}
7878
SOURCES math_ylmreal_test.cpp ../math_ylmreal.cpp ../ylm.cpp ../realarray.cpp ../timer.cpp ../matrix.cpp ../vector3.h
7979
)
80+
AddTest(
81+
TARGET base_gram_schmidt_orth
82+
LIBS ${math_libs}
83+
SOURCES gram_schmidt_orth_test.cpp ../gram_schmidt_orth.h ../gram_schmidt_orth-inl.h ../global_function.h ../mathzone.h ../math_integral.cpp
84+
)
8085
AddTest(
8186
TARGET base_mathzone_add1
8287
LIBS ${math_libs}
Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
#include"../gram_schmidt_orth.h"
2+
#include"../gram_schmidt_orth-inl.h"
3+
#include"gtest/gtest.h"
4+
5+
6+
#define DOUBLETHRESHOLD 1e-8
7+
8+
9+
/************************************************
10+
* unit test of class Gram_Schmidt_Orth
11+
***********************************************/
12+
13+
/**
14+
* Based on an linearly independent, but not orthonormal,
15+
* set of functions x:{x1,x2,x3,...}, we can construct an
16+
* orthonormal set X:{X1, X2, X3, ...} by using Gram-Schmidt
17+
* orthogonalization.
18+
* The new set X should has below properties:
19+
* 1. X1 = x1/||x1||
20+
* 2. <Xi,Xj> = 1 if (i == j) else 0
21+
*
22+
* Note:in this class, for coordinate of sphere, the inner product
23+
* of two radial function f(r) and g(r) equals the integral of r^2*f(r)*g(r)
24+
* $$ (f(r),g(r)) = {\int}r^2f(r)g(r)dr $$
25+
*
26+
*/
27+
28+
class GramSchmidtOrth
29+
{
30+
public:
31+
int nbasis;
32+
int ndim;
33+
double dr;
34+
std::vector<double> r2;
35+
double norm0;
36+
ModuleBase::Gram_Schmidt_Orth<double,double>::Coordinate coordinate;
37+
std::vector<double> rab;
38+
std::vector<std::vector<double>> basis;
39+
40+
GramSchmidtOrth(int nbasis, int ndim, double dr,
41+
ModuleBase::Gram_Schmidt_Orth<double,double>::Coordinate coordinate):
42+
nbasis(nbasis),ndim(ndim),dr(dr),coordinate(coordinate)
43+
{
44+
basis.resize(nbasis,std::vector<double>(ndim));
45+
rab.resize(ndim,dr);
46+
r2.resize(ndim,1.0);
47+
48+
norm0 = sqrt(1.0/3.0 * pow(dr*(static_cast<double>(ndim-1)),3.0));
49+
if (ModuleBase::Gram_Schmidt_Orth<double,double>::Coordinate::Sphere == this->coordinate)
50+
{
51+
for(int i=0;i<ndim;++i) {r2[i] = dr*i*dr*i;}
52+
norm0 = sqrt(1.0/5.0 * pow(dr*(static_cast<double>(ndim-1)),5.0));
53+
}
54+
55+
//build the function basis
56+
for(int i=0;i<nbasis;++i)
57+
{
58+
for(int j=0;j<ndim;++j)
59+
{
60+
//function: f_i(x) = x^(i+1)
61+
basis[i][j] = pow(static_cast<double>(j) * dr, static_cast<double>(i+1));
62+
}
63+
}
64+
}
65+
66+
//calculate the inner product of two vector
67+
double inner_product(std::vector<double> a, std::vector<double> b)
68+
{
69+
double ip;
70+
std::vector<double> mul_func = ModuleBase::Mathzone::Pointwise_Product(a,b);
71+
std::vector<double> mul_func1 = ModuleBase::Mathzone::Pointwise_Product(mul_func,r2);
72+
ModuleBase::Integral::Simpson_Integral(mul_func1.size(),ModuleBase::GlobalFunc::VECTOR_TO_PTR(mul_func1),ModuleBase::GlobalFunc::VECTOR_TO_PTR(rab),ip);
73+
return ip;
74+
}
75+
76+
};
77+
78+
class GramSchmidtOrthTest : public ::testing::TestWithParam<GramSchmidtOrth> {};
79+
80+
81+
TEST_P(GramSchmidtOrthTest,CalOrth)
82+
{
83+
GramSchmidtOrth gsot = GetParam();
84+
ModuleBase::Gram_Schmidt_Orth<double,double> gso_sphere(gsot.rab,gsot.coordinate);
85+
std::vector<std::vector<double>> old_basis = gsot.basis;
86+
std::vector<std::vector<double>> new_basis = gso_sphere.cal_orth(old_basis);
87+
88+
//==========================================================
89+
// VERIFY X0=x0/|x0|
90+
// the integral of old_basis[0] = {\int}_{0}^{dr*(ndim-1)} r^2*r*r dr
91+
// =1/5*r^5|_{0}^{dr*(ndim-1)}
92+
//==========================================================
93+
for(int i=0;i<gsot.ndim;i++)
94+
{
95+
EXPECT_NEAR(old_basis[0][i]/gsot.norm0,new_basis[0][i],DOUBLETHRESHOLD) << "the first basis is wrong";
96+
}
97+
98+
//==========================================================
99+
// VERIFY <Xi,Xj> = 0 for i!=j
100+
//==========================================================
101+
int niter = 1;
102+
int maxiter = 1;
103+
bool pass = false;
104+
double maxip;
105+
106+
//do iteration.
107+
while (true)
108+
{
109+
int nbasis = new_basis.size();
110+
maxip = abs(gsot.inner_product(new_basis[nbasis-1],new_basis[nbasis-2]));
111+
for(int i=0;i<nbasis-1;++i)
112+
{
113+
for(int j=i+1;j<nbasis;++j)
114+
{
115+
double ip = gsot.inner_product(new_basis[i],new_basis[j]);
116+
//std::cout << "i=" << i << ", j=" << j << ": " << ip << std::endl;
117+
if(abs(ip) > maxip) {maxip = abs(ip);}
118+
}
119+
}
120+
if (maxip < DOUBLETHRESHOLD) {pass = true; break;};
121+
if (niter >= maxiter) {break;}
122+
123+
niter += 1;
124+
old_basis = gso_sphere.cal_orth(new_basis); new_basis = old_basis;
125+
}
126+
127+
//std::cout << "nbasis=" << gsot.nbasis << "niter=" << niter << " max_inner_product=" << std::setprecision(15) << maxip << std::endl;
128+
EXPECT_TRUE(pass) << "nbasis=" << gsot.nbasis << "niter=" << niter << " max_inner_product=" << std::setprecision(15) << maxip;
129+
}
130+
131+
INSTANTIATE_TEST_SUITE_P(VerifyOrth,GramSchmidtOrthTest,::testing::Values(
132+
GramSchmidtOrth(10,101,0.1,ModuleBase::Gram_Schmidt_Orth<double,double>::Coordinate::Sphere),
133+
GramSchmidtOrth(20,1001,0.01,ModuleBase::Gram_Schmidt_Orth<double,double>::Coordinate::Sphere),
134+
GramSchmidtOrth(50,10001,0.001,ModuleBase::Gram_Schmidt_Orth<double,double>::Coordinate::Sphere),
135+
GramSchmidtOrth(10,10001,0.001,ModuleBase::Gram_Schmidt_Orth<double,double>::Coordinate::Cartesian),
136+
GramSchmidtOrth(20,1001,0.01,ModuleBase::Gram_Schmidt_Orth<double,double>::Coordinate::Cartesian),
137+
GramSchmidtOrth(50,101,0.1,ModuleBase::Gram_Schmidt_Orth<double,double>::Coordinate::Cartesian)
138+
));
139+

0 commit comments

Comments
 (0)