Skip to content

Commit 4a7b8cb

Browse files
dzzz2001Fisherd99
authored andcommitted
Perf: optimize deepks functions cal_f_delta and cal_pdm (deepmodeling#5933)
* inline base_matrix function * add openmp to deepks_force and deepks_pdm * inline more functions in base_matrix * inline functions of intarray * initialize some variables * fix some format * fix format * fix a bug
1 parent ad431d2 commit 4a7b8cb

File tree

6 files changed

+528
-546
lines changed

6 files changed

+528
-546
lines changed

source/module_base/intarray.cpp

Lines changed: 21 additions & 133 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ IntArray::IntArray(const int d1,const int d2)
2222
bound3 = bound4 = bound5 = bound6 = 0;
2323
size = bound1 * bound2;
2424
ptr = new int[size];zero_out();
25-
assert( ptr != 0);
25+
assert( ptr != nullptr);
2626
++arrayCount;
2727
}
2828

@@ -36,7 +36,7 @@ IntArray::IntArray(const int d1,const int d2,const int d3)
3636
//set_new_handler(IntArrayAlloc);
3737
size = bound1 * bound2 * bound3 ; //* sizeof(float);
3838
ptr = new int[size];zero_out();
39-
assert(ptr != 0);
39+
assert(ptr != nullptr);
4040
++arrayCount;
4141
}
4242

@@ -51,7 +51,7 @@ IntArray::IntArray(const int d1,const int d2,const int d3,const int d4)
5151
//set_new_handler(IntArrayAlloc);
5252
size = bound1 * bound2 * bound3 * bound4 ; //* sizeof(float);
5353
ptr = new int[size];zero_out();
54-
assert(ptr != 0);
54+
assert(ptr != nullptr);
5555
++arrayCount;
5656
}
5757

@@ -67,7 +67,7 @@ IntArray::IntArray(const int d1,const int d2,const int d3,
6767
//set_new_handler(IntArrayAlloc);
6868
size = bound1 * bound2 * bound3 * bound4 * bound5;
6969
ptr = new int[size];zero_out();
70-
assert(ptr != 0);
70+
assert(ptr != nullptr);
7171
++arrayCount;
7272
}
7373

@@ -84,7 +84,7 @@ IntArray::IntArray(const int d1,const int d2,const int d3,
8484
//set_new_handler(IntArrayAlloc);
8585
size = bound1 * bound2 * bound3 * bound4 * bound5 * bound6;
8686
ptr = new int[size];zero_out();
87-
assert(ptr != 0);
87+
assert(ptr != nullptr);
8888
++arrayCount;
8989
}
9090

@@ -98,10 +98,10 @@ IntArray ::~IntArray()
9898

9999
void IntArray::freemem()
100100
{
101-
if(ptr!=NULL)
101+
if(ptr!= nullptr)
102102
{
103103
delete [] ptr;
104-
ptr = NULL;
104+
ptr = nullptr;
105105
}
106106
}
107107

@@ -111,7 +111,7 @@ void IntArray::create(const int d1,const int d2,const int d3,const int d4,const
111111
dim = 6;
112112
bound1 = d1;bound2 = d2;bound3 = d3;bound4 = d4;bound5 = d5;bound6 = d6;
113113
delete[] ptr; ptr = new int[size];
114-
assert(ptr != 0);zero_out();
114+
assert(ptr != nullptr);zero_out();
115115
}
116116

117117
void IntArray::create(const int d1,const int d2,const int d3,const int d4,const int d5)
@@ -120,7 +120,7 @@ void IntArray::create(const int d1,const int d2,const int d3,const int d4,const
120120
dim = 5;
121121
bound1 = d1;bound2 = d2;bound3 = d3;bound4 = d4;bound5 = d5;
122122
delete[] ptr; ptr = new int[size];
123-
assert(ptr != 0);zero_out();
123+
assert(ptr != nullptr);zero_out();
124124
}
125125

126126
void IntArray::create(const int d1,const int d2,const int d3,const int d4)
@@ -129,7 +129,7 @@ void IntArray::create(const int d1,const int d2,const int d3,const int d4)
129129
dim = 4;
130130
bound1 = d1;bound2 = d2;bound3 = d3;bound4 = d4;
131131
delete[] ptr; ptr = new int[size];
132-
assert(ptr != 0);zero_out();
132+
assert(ptr != nullptr);zero_out();
133133
}
134134

135135
void IntArray::create(const int d1,const int d2,const int d3)
@@ -138,7 +138,7 @@ void IntArray::create(const int d1,const int d2,const int d3)
138138
dim = 3;
139139
bound1 = d1;bound2 = d2;bound3 = d3;bound4 = 1;
140140
delete [] ptr;ptr = new int[size];
141-
assert(ptr != 0);zero_out();
141+
assert(ptr != nullptr);zero_out();
142142
}
143143

144144
void IntArray::create(const int d1, const int d2)
@@ -147,134 +147,22 @@ void IntArray::create(const int d1, const int d2)
147147
dim = 2;
148148
bound1 = d1;bound2 = d2;bound3 = bound4 = 1;
149149
delete[] ptr;ptr = new int[size];
150-
assert(ptr !=0 );zero_out();
151-
}
152-
153-
const IntArray &IntArray::operator=(const IntArray &right)
154-
{
155-
assert( this->size == right.size );
156-
for (int i = 0;i < size;i++) ptr[i] = right.ptr[i];
157-
return *this;// enables x = y = z;
158-
}
159-
160-
const IntArray &IntArray::operator=(const int &value)
161-
{
162-
for (int i = 0;i < size;i++) ptr[i] = value;
163-
return *this;// enables x = y = z;
164-
}
165-
166-
//********************************************************
167-
// overloaded subscript operator for const Int Array
168-
// const reference return create an cvakue
169-
//********************************************************
170-
const int &IntArray::operator()
171-
(const int ind1,const int ind2)const
172-
{
173-
assert( ind1 < bound1 );
174-
assert( ind2 < bound2 );
175-
return ptr[ ind1 * bound2 + ind2 ];
176-
}
177-
178-
const int &IntArray::operator()
179-
(const int ind1,const int ind2,const int ind3)const
180-
{
181-
assert( ind1 < bound1 );
182-
assert( ind2 < bound2 );
183-
assert( ind3 < bound3 );
184-
return ptr[ (ind1 * bound2 + ind2) * bound3 + ind3 ];
185-
}
186-
187-
const int &IntArray::operator()
188-
(const int ind1,const int ind2,const int ind3,const int ind4)const
189-
{
190-
assert( ind1 < bound1 );
191-
assert( ind2 < bound2 );
192-
assert( ind3 < bound3 );
193-
assert( ind4 < bound4 );
194-
return ptr[ ((ind1 * bound2 + ind2) * bound3 + ind3) * bound4 + ind4 ];
195-
}
196-
197-
const int &IntArray::operator()
198-
(const int ind1,const int ind2,const int ind3,const int ind4,const int ind5)const
199-
{
200-
assert( ind1 < bound1 );
201-
assert( ind2 < bound2 );
202-
assert( ind3 < bound3 );
203-
assert( ind4 < bound4 );
204-
assert( ind5 < bound5 );
205-
return ptr[ (((ind1 * bound2 + ind2) * bound3 + ind3) * bound4 + ind4) * bound5 + ind5 ];
206-
}
207-
208-
const int &IntArray::operator()
209-
(const int ind1,const int ind2,const int ind3,const int ind4,const int ind5,const int ind6)const
210-
{
211-
assert( ind1 < bound1 );
212-
assert( ind2 < bound2 );
213-
assert( ind3 < bound3 );
214-
assert( ind4 < bound4 );
215-
assert( ind5 < bound5 );
216-
assert( ind6 < bound6 );
217-
return ptr[ ((((ind1 * bound2 + ind2) * bound3 + ind3) * bound4 + ind4) * bound5 + ind5) * bound6 + ind6 ];
218-
}
219-
220-
//********************************************************
221-
// overloaded subscript operator for non-const Int Array
222-
// const reference return creates an lvakue
223-
//********************************************************
224-
int &IntArray::operator()(const int ind1,const int ind2)
225-
{
226-
assert( ind1 < bound1 );
227-
assert( ind2 < bound2 );
228-
return ptr[ind1 * bound2 + ind2];
229-
}
230-
231-
int &IntArray::operator()(const int ind1,const int ind2,const int ind3)
232-
{
233-
assert( ind1 < bound1 );
234-
assert( ind2 < bound2 );
235-
assert( ind3 < bound3 );
236-
return ptr[ (ind1 * bound2 + ind2) * bound3 + ind3 ];
237-
}
238-
239-
int &IntArray::operator()(const int ind1,const int ind2,const int ind3,const int ind4)
240-
{
241-
assert( ind1 < bound1 );
242-
assert( ind2 < bound2 );
243-
assert( ind3 < bound3 );
244-
assert( ind4 < bound4 );
245-
return ptr[ ((ind1 * bound2 + ind2) * bound3 + ind3) * bound4 + ind4 ];
246-
}
247-
248-
int &IntArray::operator()
249-
(const int ind1,const int ind2,const int ind3,const int ind4,const int ind5)
250-
{
251-
assert( ind1 < bound1 );
252-
assert( ind2 < bound2 );
253-
assert( ind3 < bound3 );
254-
assert( ind4 < bound4 );
255-
assert( ind5 < bound5 );
256-
return ptr[ (((ind1 * bound2 + ind2) * bound3 + ind3) * bound4 + ind4) * bound5 + ind5 ];
257-
}
258-
259-
int &IntArray::operator()
260-
(const int ind1,const int ind2,const int ind3,const int ind4,const int ind5,const int ind6)
261-
{
262-
assert( ind1 < bound1 );
263-
assert( ind2 < bound2 );
264-
assert( ind3 < bound3 );
265-
assert( ind4 < bound4 );
266-
assert( ind5 < bound5 );
267-
assert( ind6 < bound6 );
268-
return ptr[ ((((ind1 * bound2 + ind2) * bound3 + ind3) * bound4 + ind4) * bound5 + ind5) * bound6 + ind6 ];
150+
assert(ptr != nullptr );zero_out();
269151
}
270152

271153
//****************************
272154
// zeroes out the whole array
273155
//****************************
274-
void IntArray::zero_out(void)
156+
void IntArray::zero_out()
275157
{
276-
if (size <= 0) return;
277-
for (int i = 0;i < size; i++) ptr[i] = 0;
158+
if (size <= 0)
159+
{
160+
return;
161+
}
162+
for (int i = 0;i < size; i++)
163+
{
164+
ptr[i] = 0;
165+
}
278166
return;
279167
}
280168

source/module_base/intarray.h

Lines changed: 91 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,12 @@ class IntArray
5353
* @param right
5454
* @return const IntArray&
5555
*/
56-
const IntArray &operator=(const IntArray &right);
56+
const IntArray &operator=(const IntArray &right)
57+
{
58+
assert( this->size == right.size );
59+
for (int i = 0;i < size;i++) ptr[i] = right.ptr[i];
60+
return *this;// enables x = y = z;
61+
};
5762

5863
/**
5964
* @brief Equal all elements of an IntArray to an
@@ -62,7 +67,11 @@ class IntArray
6267
* @param right
6368
* @return const IntArray&
6469
*/
65-
const IntArray &operator=(const int &right);
70+
const IntArray &operator=(const int &right)
71+
{
72+
for (int i = 0;i < size;i++) ptr[i] = right;
73+
return *this;// enables x = y = z;
74+
};
6675

6776
/**
6877
* @brief Access elements by using operator "()"
@@ -71,11 +80,46 @@ class IntArray
7180
* @param d2
7281
* @return int&
7382
*/
74-
int &operator()(const int d1, const int d2);
75-
int &operator()(const int d1, const int d2, const int d3);
76-
int &operator()(const int d1, const int d2, const int d3, const int d4);
77-
int &operator()(const int d1, const int d2, const int d3, const int d4, const int d5);
78-
int &operator()(const int d1, const int d2, const int d3, const int d4, const int d5, const int d6);
83+
int &operator()(const int d1, const int d2)
84+
{
85+
assert( d1 < bound1 );
86+
assert( d2 < bound2 );
87+
return ptr[ d1 * bound2 + d2 ];
88+
};
89+
int &operator()(const int d1, const int d2, const int d3)
90+
{
91+
assert( d1 < bound1 );
92+
assert( d2 < bound2 );
93+
assert( d3 < bound3 );
94+
return ptr[ (d1 * bound2 + d2) * bound3 + d3 ];
95+
};
96+
int &operator()(const int d1, const int d2, const int d3, const int d4)
97+
{
98+
assert( d1 < bound1 );
99+
assert( d2 < bound2 );
100+
assert( d3 < bound3 );
101+
assert( d4 < bound4 );
102+
return ptr[ ((d1 * bound2 + d2) * bound3 + d3) * bound4 + d4 ];
103+
};
104+
int &operator()(const int d1, const int d2, const int d3, const int d4, const int d5)
105+
{
106+
assert( d1 < bound1 );
107+
assert( d2 < bound2 );
108+
assert( d3 < bound3 );
109+
assert( d4 < bound4 );
110+
assert( d5 < bound5 );
111+
return ptr[ (((d1 * bound2 + d2) * bound3 + d3) * bound4 + d4) * bound5 + d5 ];
112+
};
113+
int &operator()(const int d1, const int d2, const int d3, const int d4, const int d5, const int d6)
114+
{
115+
assert( d1 < bound1 );
116+
assert( d2 < bound2 );
117+
assert( d3 < bound3 );
118+
assert( d4 < bound4 );
119+
assert( d5 < bound5 );
120+
assert( d6 < bound6 );
121+
return ptr[ ((((d1 * bound2 + d2) * bound3 + d3) * bound4 + d4) * bound5 + d5) * bound6 + d6 ];
122+
};
79123

80124
/**
81125
* @brief Access elements by using "()" through pointer
@@ -85,11 +129,46 @@ class IntArray
85129
* @param d2
86130
* @return const int&
87131
*/
88-
const int &operator()(const int d1, const int d2) const;
89-
const int &operator()(const int d1, const int d2, const int d3) const;
90-
const int &operator()(const int d1, const int d2, const int d3, const int d4) const;
91-
const int &operator()(const int d1, const int d2, const int d3, const int d4, const int d5) const;
92-
const int &operator()(const int d1, const int d2, const int d3, const int d4, const int d5, const int d6) const;
132+
const int &operator()(const int d1, const int d2) const
133+
{
134+
assert( d1 < bound1 );
135+
assert( d2 < bound2 );
136+
return ptr[ d1 * bound2 + d2 ];
137+
};
138+
const int &operator()(const int d1, const int d2, const int d3) const
139+
{
140+
assert( d1 < bound1 );
141+
assert( d2 < bound2 );
142+
assert( d3 < bound3 );
143+
return ptr[ (d1 * bound2 + d2) * bound3 + d3 ];
144+
};
145+
const int &operator()(const int d1, const int d2, const int d3, const int d4) const
146+
{
147+
assert( d1 < bound1 );
148+
assert( d2 < bound2 );
149+
assert( d3 < bound3 );
150+
assert( d4 < bound4 );
151+
return ptr[ ((d1 * bound2 + d2) * bound3 + d3) * bound4 + d4 ];
152+
};
153+
const int &operator()(const int d1, const int d2, const int d3, const int d4, const int d5) const
154+
{
155+
assert( d1 < bound1 );
156+
assert( d2 < bound2 );
157+
assert( d3 < bound3 );
158+
assert( d4 < bound4 );
159+
assert( d5 < bound5 );
160+
return ptr[ (((d1 * bound2 + d2) * bound3 + d3) * bound4 + d4) * bound5 + d5 ];
161+
};
162+
const int &operator()(const int d1, const int d2, const int d3, const int d4, const int d5, const int d6) const
163+
{
164+
assert( d1 < bound1 );
165+
assert( d2 < bound2 );
166+
assert( d3 < bound3 );
167+
assert( d4 < bound4 );
168+
assert( d5 < bound5 );
169+
assert( d6 < bound6 );
170+
return ptr[ ((((d1 * bound2 + d2) * bound3 + d3) * bound4 + d4) * bound5 + d5) * bound6 + d6 ];
171+
};
93172

94173
/**
95174
* @brief Set all elements of an IntArray to zero

0 commit comments

Comments
 (0)