@@ -119,15 +119,34 @@ class DescrptSeAOp : public OpKernel {
119119
120120 int nei_mode = 0 ;
121121 if (mesh_tensor.shape ().dim_size (0 ) == 16 ) {
122+ // lammps neighbor list
122123 nei_mode = 3 ;
123124 }
124125 else if (mesh_tensor.shape ().dim_size (0 ) == 12 ) {
126+ // user provided extended mesh
125127 nei_mode = 2 ;
126128 }
127129 else if (mesh_tensor.shape ().dim_size (0 ) == 6 ) {
130+ // manual copied pbc
128131 assert (nloc == nall);
129132 nei_mode = 1 ;
130133 }
134+ else if (mesh_tensor.shape ().dim_size (0 ) == 0 ) {
135+ // no pbc
136+ nei_mode = -1 ;
137+ }
138+ else {
139+ throw runtime_error (" invalid mesh tensor" );
140+ }
141+ bool b_pbc = true ;
142+ // if region is given extended, do not use pbc
143+ if (nei_mode >= 1 || nei_mode == -1 ) {
144+ b_pbc = false ;
145+ }
146+ bool b_norm_atom = false ;
147+ if (nei_mode == 1 ){
148+ b_norm_atom = true ;
149+ }
131150
132151 // Create an output tensor
133152 TensorShape descrpt_shape ;
@@ -196,7 +215,7 @@ class DescrptSeAOp : public OpKernel {
196215 for (int dd = 0 ; dd < 3 ; ++dd){
197216 d_coord3[ii*3 +dd] = coord (kk, ii*3 +dd);
198217 }
199- if (nei_mode <= 1 ){
218+ if (b_norm_atom ){
200219 compute_t inter[3 ];
201220 region.phys2Inter (inter, &d_coord3[3 *ii]);
202221 for (int dd = 0 ; dd < 3 ; ++dd){
@@ -259,14 +278,11 @@ class DescrptSeAOp : public OpKernel {
259278 }
260279 ::build_nlist (d_nlist_a, d_nlist_r, d_coord3, nloc, rcut_a, rcut_r, nat_stt, ncell, ext_stt, ext_end, region, ncell);
261280 }
262- else {
263- build_nlist (d_nlist_a, d_nlist_r, rcut_a, rcut_r, d_coord3, region);
281+ else if (nei_mode == - 1 ) {
282+ :: build_nlist (d_nlist_a, d_nlist_r, d_coord3, rcut_a, rcut_r, NULL );
264283 }
265-
266- bool b_pbc = true ;
267- // if region is given extended, do not use pbc
268- if (nei_mode >= 1 ) {
269- b_pbc = false ;
284+ else {
285+ throw runtime_error (" unknow neighbor mode" );
270286 }
271287
272288 // loop over atoms, compute descriptors for each atom
@@ -351,48 +367,6 @@ class DescrptSeAOp : public OpKernel {
351367 sec[ii] = sec[ii-1 ] + n_sel[ii-1 ];
352368 }
353369 }
354- void
355- build_nlist (vector<vector<int > > & nlist0,
356- vector<vector<int > > & nlist1,
357- const compute_t & rc0_,
358- const compute_t & rc1_,
359- const vector<compute_t > & posi3,
360- const SimulationRegion<compute_t > & region) const {
361- compute_t rc0 (rc0_);
362- compute_t rc1 (rc1_);
363- assert (rc0 <= rc1);
364- compute_t rc02 = rc0 * rc0;
365- // negative rc0 means not applying rc0
366- if (rc0 < 0 ) rc02 = 0 ;
367- compute_t rc12 = rc1 * rc1;
368-
369- unsigned natoms = posi3.size ()/3 ;
370- nlist0.clear ();
371- nlist1.clear ();
372- nlist0.resize (natoms);
373- nlist1.resize (natoms);
374- for (unsigned ii = 0 ; ii < natoms; ++ii){
375- nlist0[ii].reserve (60 );
376- nlist1[ii].reserve (60 );
377- }
378- for (unsigned ii = 0 ; ii < natoms; ++ii){
379- for (unsigned jj = ii+1 ; jj < natoms; ++jj){
380- compute_t diff[3 ];
381- region.diffNearestNeighbor (posi3[jj*3 +0 ], posi3[jj*3 +1 ], posi3[jj*3 +2 ],
382- posi3[ii*3 +0 ], posi3[ii*3 +1 ], posi3[ii*3 +2 ],
383- diff[0 ], diff[1 ], diff[2 ]);
384- compute_t r2 = MathUtilities::dot<compute_t > (diff, diff);
385- if (r2 < rc02) {
386- nlist0[ii].push_back (jj);
387- nlist0[jj].push_back (ii);
388- }
389- else if (r2 < rc12) {
390- nlist1[ii].push_back (jj);
391- nlist1[jj].push_back (ii);
392- }
393- }
394- }
395- }
396370};
397371
398372REGISTER_KERNEL_BUILDER (Name(" DescrptSeA" ).Device(DEVICE_CPU), DescrptSeAOp);
0 commit comments