You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: source/op/descrpt_norot.cc
+18-11Lines changed: 18 additions & 11 deletions
Original file line number
Diff line number
Diff line change
@@ -73,6 +73,7 @@ class DescrptNorotOp : public OpKernel {
73
73
nnei_r = sec_r.back();
74
74
nnei = nnei_a + nnei_r;
75
75
fill_nei_a = (rcut_a < 0);
76
+
count_nei_idx_overflow = 0;
76
77
}
77
78
78
79
voidCompute(OpKernelContext* context) override {
@@ -92,26 +93,29 @@ class DescrptNorotOp : public OpKernel {
92
93
OP_REQUIRES (context, (natoms_tensor.shape().dims() == 1), errors::InvalidArgument ("Dim of natoms should be 1"));
93
94
OP_REQUIRES (context, (box_tensor.shape().dims() == 2), errors::InvalidArgument ("Dim of box should be 2"));
94
95
OP_REQUIRES (context, (mesh_tensor.shape().dims() == 1), errors::InvalidArgument ("Dim of mesh should be 1"));
95
-
OP_REQUIRES (context, (avg_tensor.shape().dims() == 1), errors::InvalidArgument ("Dim of avg should be 1"));
96
-
OP_REQUIRES (context, (std_tensor.shape().dims() == 1), errors::InvalidArgument ("Dim of std should be 1"));
96
+
OP_REQUIRES (context, (avg_tensor.shape().dims() == 2), errors::InvalidArgument ("Dim of avg should be 2"));
97
+
OP_REQUIRES (context, (std_tensor.shape().dims() == 2), errors::InvalidArgument ("Dim of std should be 2"));
97
98
OP_REQUIRES (context, (fill_nei_a), errors::InvalidArgument ("Rotational free descriptor only support the case rcut_a < 0"));
98
99
OP_REQUIRES (context, (sec_r.back() == 0), errors::InvalidArgument ("Rotational free descriptor only support all-angular information: sel_r should be all zero."));
99
100
100
101
OP_REQUIRES (context, (natoms_tensor.shape().dim_size(0) >= 3), errors::InvalidArgument ("number of atoms should be larger than (or equal to) 3"));
101
102
auto natoms = natoms_tensor .flat<int>();
102
103
int nloc = natoms(0);
103
104
int nall = natoms(1);
105
+
int ntypes = natoms_tensor.shape().dim_size(0) - 2;
104
106
int nsamples = coord_tensor.shape().dim_size(0);
105
107
106
108
// check the sizes
107
109
OP_REQUIRES (context, (nsamples == type_tensor.shape().dim_size(0)), errors::InvalidArgument ("number of samples should match"));
108
110
OP_REQUIRES (context, (nsamples == box_tensor.shape().dim_size(0)), errors::InvalidArgument ("number of samples should match"));
109
-
OP_REQUIRES (context, (ndescrpt == avg_tensor.shape().dim_size(0)), errors::InvalidArgument ("number of avg should be ndescrpt"));
110
-
OP_REQUIRES (context, (ndescrpt == std_tensor.shape().dim_size(0)), errors::InvalidArgument ("number of std should be ndescrpt"));
111
+
OP_REQUIRES (context, (ntypes == avg_tensor.shape().dim_size(0)), errors::InvalidArgument ("number of avg should be ntype"));
112
+
OP_REQUIRES (context, (ntypes == std_tensor.shape().dim_size(0)), errors::InvalidArgument ("number of std should be ntype"));
111
113
112
114
OP_REQUIRES (context, (nall * 3 == coord_tensor.shape().dim_size(1)), errors::InvalidArgument ("number of atoms should match"));
113
115
OP_REQUIRES (context, (nall == type_tensor.shape().dim_size(1)), errors::InvalidArgument ("number of atoms should match"));
114
116
OP_REQUIRES (context, (9 == box_tensor.shape().dim_size(1)), errors::InvalidArgument ("number of box should be 9"));
117
+
OP_REQUIRES (context, (ndescrpt == avg_tensor.shape().dim_size(1)), errors::InvalidArgument ("number of avg should be ndescrpt"));
118
+
OP_REQUIRES (context, (ndescrpt == std_tensor.shape().dim_size(1)), errors::InvalidArgument ("number of std should be ndescrpt"));
115
119
116
120
int nei_mode = 0;
117
121
if (mesh_tensor.shape().dim_size(0) == 16) {
@@ -161,8 +165,8 @@ class DescrptNorotOp : public OpKernel {
161
165
auto type = type_tensor .matrix<int>();
162
166
auto box = box_tensor .matrix<VALUETYPE>();
163
167
auto mesh = mesh_tensor .flat<int>();
164
-
auto avg = avg_tensor .flat<VALUETYPE>();
165
-
auto std = std_tensor .flat<VALUETYPE>();
168
+
auto avg = avg_tensor .matrix<VALUETYPE>();
169
+
auto std = std_tensor .matrix<VALUETYPE>();
166
170
auto descrpt = descrpt_tensor ->matrix<VALUETYPE>();
167
171
auto descrpt_deriv = descrpt_deriv_tensor ->matrix<VALUETYPE>();
168
172
auto rij = rij_tensor ->matrix<VALUETYPE>();
@@ -174,7 +178,6 @@ class DescrptNorotOp : public OpKernel {
174
178
// if (type(0, ii) > max_type_v) max_type_v = type(0, ii);
175
179
// }
176
180
// int ntypes = max_type_v + 1;
177
-
int ntypes = natoms_tensor.shape().dim_size(0) - 2;
178
181
OP_REQUIRES (context, (ntypes == int(sel_a.size())), errors::InvalidArgument ("number of types should match the length of sel array"));
179
182
OP_REQUIRES (context, (ntypes == int(sel_r.size())), errors::InvalidArgument ("number of types should match the length of sel array"));
180
183
@@ -274,8 +277,11 @@ class DescrptNorotOp : public OpKernel {
274
277
int ret = -1;
275
278
if (fill_nei_a){
276
279
if ((ret = format_nlist_fill_a (fmt_nlist_a, fmt_nlist_r, d_coord3, ntypes, d_type, region, b_pbc, ii, d_nlist_a[ii], d_nlist_r[ii], rcut_r, sec_a, sec_r)) != -1){
277
-
cout << "Radial neighbor list length of type " << ret << " is not enough" << endl;
278
-
exit(1);
280
+
if (count_nei_idx_overflow == 0) {
281
+
cout << "WARNING: Radial neighbor list length of type " << ret << " is not enough" << endl;
282
+
flush(cout);
283
+
count_nei_idx_overflow ++;
284
+
}
279
285
}
280
286
}
281
287
@@ -306,10 +312,10 @@ class DescrptNorotOp : public OpKernel {
0 commit comments