@@ -24,6 +24,8 @@ limitations under the License.
2424 typedef float VALUETYPE;
2525#endif
2626
27+ typedef double compute_t ;
28+
2729typedef unsigned long long int_64;
2830
2931#define cudaErrcheck (res ) { cudaAssert ((res), __FILE__, __LINE__); }
@@ -132,12 +134,12 @@ __global__ void format_nlist_fill_a_se_a(const VALUETYPE * coord,
132134
133135 int_64 * key_in = key + idx * MAGIC_NUMBER;
134136
135- VALUETYPE diff[3 ];
137+ compute_t diff[3 ];
136138 const int & j_idx = nei_idx[idy];
137139 for (int dd = 0 ; dd < 3 ; dd++) {
138140 diff[dd] = coord[j_idx * 3 + dd] - coord[idx * 3 + dd];
139141 }
140- VALUETYPE rr = sqrt (dev_dot (diff, diff));
142+ compute_t rr = sqrt (dev_dot (diff, diff));
141143 if (rr <= rcut) {
142144 key_in[idy] = type[j_idx] * 1E15 + (int_64)(rr * 1.0E13 ) / 100000 * 100000 + j_idx;
143145 }
@@ -179,18 +181,19 @@ __global__ void format_nlist_fill_b_se_a(int * nlist,
179181}
180182// it's ok!
181183
182- __global__ void compute_descriptor_se_a (VALUETYPE* descript,
184+ template <typename FPTYPE>
185+ __global__ void compute_descriptor_se_a (FPTYPE* descript,
183186 const int ndescrpt,
184- VALUETYPE * descript_deriv,
187+ FPTYPE * descript_deriv,
185188 const int descript_deriv_size,
186- VALUETYPE * rij,
189+ FPTYPE * rij,
187190 const int rij_size,
188191 const int * type,
189- const VALUETYPE * avg,
190- const VALUETYPE * std,
192+ const FPTYPE * avg,
193+ const FPTYPE * std,
191194 int * nlist,
192195 const int nlist_size,
193- const VALUETYPE * coord,
196+ const FPTYPE * coord,
194197 const float rmin,
195198 const float rmax,
196199 const int sec_a_size)
@@ -203,24 +206,24 @@ __global__ void compute_descriptor_se_a (VALUETYPE* descript,
203206 if (idy >= sec_a_size) {return ;}
204207
205208 // else {return;}
206- VALUETYPE * row_descript = descript + idx * ndescrpt;
207- VALUETYPE * row_descript_deriv = descript_deriv + idx * descript_deriv_size;
208- VALUETYPE * row_rij = rij + idx * rij_size;
209+ FPTYPE * row_descript = descript + idx * ndescrpt;
210+ FPTYPE * row_descript_deriv = descript_deriv + idx * descript_deriv_size;
211+ FPTYPE * row_rij = rij + idx * rij_size;
209212 int * row_nlist = nlist + idx * nlist_size;
210213
211214 if (row_nlist[idy] >= 0 ) {
212215 const int & j_idx = row_nlist[idy];
213216 for (int kk = 0 ; kk < 3 ; kk++) {
214217 row_rij[idy * 3 + kk] = coord[j_idx * 3 + kk] - coord[idx * 3 + kk];
215218 }
216- const VALUETYPE * rr = &row_rij[idy * 3 + 0 ];
217- VALUETYPE nr2 = dev_dot (rr, rr);
218- VALUETYPE inr = 1 ./sqrt (nr2);
219- VALUETYPE nr = nr2 * inr;
220- VALUETYPE inr2 = inr * inr;
221- VALUETYPE inr4 = inr2 * inr2;
222- VALUETYPE inr3 = inr4 * nr;
223- VALUETYPE sw, dsw;
219+ const FPTYPE * rr = &row_rij[idy * 3 + 0 ];
220+ FPTYPE nr2 = dev_dot (rr, rr);
221+ FPTYPE inr = 1 ./sqrt (nr2);
222+ FPTYPE nr = nr2 * inr;
223+ FPTYPE inr2 = inr * inr;
224+ FPTYPE inr4 = inr2 * inr2;
225+ FPTYPE inr3 = inr4 * nr;
226+ FPTYPE sw, dsw;
224227 spline5_switch (sw, dsw, nr, rmin, rmax);
225228 row_descript[idx_value + 0 ] = (1 ./nr) ;// * sw;
226229 row_descript[idx_value + 1 ] = (rr[0 ] / nr2) ;// * sw;
0 commit comments