@@ -24,8 +24,6 @@ limitations under the License.
2424 typedef float VALUETYPE;
2525#endif
2626
27- typedef double compute_t ;
28-
2927typedef unsigned long long int_64;
3028
3129#define cudaErrcheck (res ) { cudaAssert ((res), __FILE__, __LINE__); }
@@ -77,19 +75,20 @@ __device__ inline T dev_dot(T * arr1, T * arr2) {
7775 return arr1[0 ] * arr2[0 ] + arr1[1 ] * arr2[1 ] + arr1[2 ] * arr2[2 ];
7876}
7977
80- __device__ inline void spline5_switch (compute_t & vv,
81- compute_t & dd,
82- compute_t & xx,
83- const compute_t & rmin,
84- const compute_t & rmax)
78+ template <typename FPTYPE>
79+ __device__ inline void spline5_switch (FPTYPE & vv,
80+ FPTYPE & dd,
81+ FPTYPE & xx,
82+ const float & rmin,
83+ const float & rmax)
8584{
8685 if (xx < rmin) {
8786 dd = 0 ;
8887 vv = 1 ;
8988 }
9089 else if (xx < rmax) {
91- compute_t uu = (xx - rmin) / (rmax - rmin) ;
92- compute_t du = 1 . / (rmax - rmin) ;
90+ FPTYPE uu = (xx - rmin) / (rmax - rmin) ;
91+ FPTYPE du = 1 . / (rmax - rmin) ;
9392 vv = uu*uu*uu * (-6 * uu*uu + 15 * uu - 10 ) + 1 ;
9493 dd = ( 3 * uu*uu * (-6 * uu*uu + 15 * uu - 10 ) + uu*uu*uu * (-12 * uu + 15 ) ) * du;
9594 }
@@ -133,12 +132,12 @@ __global__ void format_nlist_fill_a_se_a(const VALUETYPE * coord,
133132
134133 int_64 * key_in = key + idx * MAGIC_NUMBER;
135134
136- compute_t diff[3 ];
135+ VALUETYPE diff[3 ];
137136 const int & j_idx = nei_idx[idy];
138137 for (int dd = 0 ; dd < 3 ; dd++) {
139138 diff[dd] = coord[j_idx * 3 + dd] - coord[idx * 3 + dd];
140139 }
141- compute_t rr = sqrt (dev_dot (diff, diff));
140+ VALUETYPE rr = sqrt (dev_dot (diff, diff));
142141 if (rr <= rcut) {
143142 key_in[idy] = type[j_idx] * 1E15 + (int_64)(rr * 1.0E13 ) / 100000 * 100000 + j_idx;
144143 }
@@ -192,8 +191,8 @@ __global__ void compute_descriptor_se_a (VALUETYPE* descript,
192191 int * nlist,
193192 const int nlist_size,
194193 const VALUETYPE* coord,
195- const VALUETYPE rmin,
196- const VALUETYPE rmax,
194+ const float rmin,
195+ const float rmax,
197196 const int sec_a_size)
198197{
199198 // <<<nloc, sec_a.back()>>>
@@ -214,14 +213,14 @@ __global__ void compute_descriptor_se_a (VALUETYPE* descript,
214213 for (int kk = 0 ; kk < 3 ; kk++) {
215214 row_rij[idy * 3 + kk] = coord[j_idx * 3 + kk] - coord[idx * 3 + kk];
216215 }
217- const compute_t * rr = &row_rij[idy * 3 + 0 ];
218- compute_t nr2 = dev_dot (rr, rr);
219- compute_t inr = 1 ./sqrt (nr2);
220- compute_t nr = nr2 * inr;
221- compute_t inr2 = inr * inr;
222- compute_t inr4 = inr2 * inr2;
223- compute_t inr3 = inr4 * nr;
224- compute_t sw, dsw;
216+ const VALUETYPE * rr = &row_rij[idy * 3 + 0 ];
217+ VALUETYPE nr2 = dev_dot (rr, rr);
218+ VALUETYPE inr = 1 ./sqrt (nr2);
219+ VALUETYPE nr = nr2 * inr;
220+ VALUETYPE inr2 = inr * inr;
221+ VALUETYPE inr4 = inr2 * inr2;
222+ VALUETYPE inr3 = inr4 * nr;
223+ VALUETYPE sw, dsw;
225224 spline5_switch (sw, dsw, nr, rmin, rmax);
226225 row_descript[idx_value + 0 ] = (1 ./nr) ;// * sw;
227226 row_descript[idx_value + 1 ] = (rr[0 ] / nr2) ;// * sw;
0 commit comments