Skip to content

Commit 143afce

Browse files
LuLu
authored andcommitted
fix bug of error compilation when using float precision
1 parent b00e29d commit 143afce

File tree

1 file changed

+20
-21
lines changed

1 file changed

+20
-21
lines changed

source/op/cuda/descrpt_se_a.cu

Lines changed: 20 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,6 @@ limitations under the License.
2424
typedef float VALUETYPE;
2525
#endif
2626

27-
typedef double compute_t;
28-
2927
typedef unsigned long long int_64;
3028

3129
#define cudaErrcheck(res) { cudaAssert((res), __FILE__, __LINE__); }
@@ -77,19 +75,20 @@ __device__ inline T dev_dot(T * arr1, T * arr2) {
7775
return arr1[0] * arr2[0] + arr1[1] * arr2[1] + arr1[2] * arr2[2];
7876
}
7977

80-
__device__ inline void spline5_switch(compute_t & vv,
81-
compute_t & dd,
82-
compute_t & xx,
83-
const compute_t & rmin,
84-
const compute_t & rmax)
78+
template<typename FPTYPE>
79+
__device__ inline void spline5_switch(FPTYPE & vv,
80+
FPTYPE & dd,
81+
FPTYPE & xx,
82+
const float & rmin,
83+
const float & rmax)
8584
{
8685
if (xx < rmin) {
8786
dd = 0;
8887
vv = 1;
8988
}
9089
else if (xx < rmax) {
91-
compute_t uu = (xx - rmin) / (rmax - rmin) ;
92-
compute_t du = 1. / (rmax - rmin) ;
90+
FPTYPE uu = (xx - rmin) / (rmax - rmin) ;
91+
FPTYPE du = 1. / (rmax - rmin) ;
9392
vv = uu*uu*uu * (-6 * uu*uu + 15 * uu - 10) + 1;
9493
dd = ( 3 * uu*uu * (-6 * uu*uu + 15 * uu - 10) + uu*uu*uu * (-12 * uu + 15) ) * du;
9594
}
@@ -133,12 +132,12 @@ __global__ void format_nlist_fill_a_se_a(const VALUETYPE * coord,
133132

134133
int_64 * key_in = key + idx * MAGIC_NUMBER;
135134

136-
compute_t diff[3];
135+
VALUETYPE diff[3];
137136
const int & j_idx = nei_idx[idy];
138137
for (int dd = 0; dd < 3; dd++) {
139138
diff[dd] = coord[j_idx * 3 + dd] - coord[idx * 3 + dd];
140139
}
141-
compute_t rr = sqrt(dev_dot(diff, diff));
140+
VALUETYPE rr = sqrt(dev_dot(diff, diff));
142141
if (rr <= rcut) {
143142
key_in[idy] = type[j_idx] * 1E15+ (int_64)(rr * 1.0E13) / 100000 * 100000 + j_idx;
144143
}
@@ -192,8 +191,8 @@ __global__ void compute_descriptor_se_a (VALUETYPE* descript,
192191
int* nlist,
193192
const int nlist_size,
194193
const VALUETYPE* coord,
195-
const VALUETYPE rmin,
196-
const VALUETYPE rmax,
194+
const float rmin,
195+
const float rmax,
197196
const int sec_a_size)
198197
{
199198
// <<<nloc, sec_a.back()>>>
@@ -214,14 +213,14 @@ __global__ void compute_descriptor_se_a (VALUETYPE* descript,
214213
for (int kk = 0; kk < 3; kk++) {
215214
row_rij[idy * 3 + kk] = coord[j_idx * 3 + kk] - coord[idx * 3 + kk];
216215
}
217-
const compute_t * rr = &row_rij[idy * 3 + 0];
218-
compute_t nr2 = dev_dot(rr, rr);
219-
compute_t inr = 1./sqrt(nr2);
220-
compute_t nr = nr2 * inr;
221-
compute_t inr2 = inr * inr;
222-
compute_t inr4 = inr2 * inr2;
223-
compute_t inr3 = inr4 * nr;
224-
compute_t sw, dsw;
216+
const VALUETYPE * rr = &row_rij[idy * 3 + 0];
217+
VALUETYPE nr2 = dev_dot(rr, rr);
218+
VALUETYPE inr = 1./sqrt(nr2);
219+
VALUETYPE nr = nr2 * inr;
220+
VALUETYPE inr2 = inr * inr;
221+
VALUETYPE inr4 = inr2 * inr2;
222+
VALUETYPE inr3 = inr4 * nr;
223+
VALUETYPE sw, dsw;
225224
spline5_switch(sw, dsw, nr, rmin, rmax);
226225
row_descript[idx_value + 0] = (1./nr) ;//* sw;
227226
row_descript[idx_value + 1] = (rr[0] / nr2) ;//* sw;

0 commit comments

Comments
 (0)