@@ -577,6 +577,15 @@ class ProdEnvMatAOp : public OpKernel {
577577 mesh_tensor.flat <int >().data (), mesh_tensor_size, nloc, nei_mode,
578578 rcut_r, max_cpy_trial, max_nnei_trial);
579579
580+ // max_nbor_size may be changed after _prepare_coord_nlist_gpu
581+ // So we need to update the uint64_temp tensor if necessary
582+ if (uint64_temp.NumElements () < int_64 (nloc) * max_nbor_size * 2 ) {
583+ TensorShape uint64_shape;
584+ uint64_shape.AddDim (int_64 (nloc) * max_nbor_size * 2 );
585+ OP_REQUIRES_OK (context, context->allocate_temp (
586+ DT_UINT64, uint64_shape, &uint64_temp));
587+ array_longlong = uint64_temp.flat <unsigned long long >().data ();
588+ }
580589 // launch the gpu(nv) compute function
581590 deepmd::prod_env_mat_a_gpu (em, em_deriv, rij, nlist, coord, type,
582591 gpu_inlist, array_int, array_longlong,
@@ -875,6 +884,16 @@ class ProdEnvMatROp : public OpKernel {
875884 mesh_tensor.flat <int >().data (), mesh_tensor_size, nloc, nei_mode,
876885 rcut, max_cpy_trial, max_nnei_trial);
877886
887+ // max_nbor_size may be changed after _prepare_coord_nlist_gpu
888+ // So we need to update the uint64_temp tensor if necessary
889+ if (uint64_temp.NumElements () < int_64 (nloc) * max_nbor_size * 2 ) {
890+ TensorShape uint64_shape;
891+ uint64_shape.AddDim (int_64 (nloc) * max_nbor_size * 2 );
892+ OP_REQUIRES_OK (context, context->allocate_temp (
893+ DT_UINT64, uint64_shape, &uint64_temp));
894+ array_longlong = uint64_temp.flat <unsigned long long >().data ();
895+ }
896+
878897 // launch the gpu(nv) compute function
879898 deepmd::prod_env_mat_r_gpu (em, em_deriv, rij, nlist, coord, type,
880899 gpu_inlist, array_int, array_longlong,
@@ -1221,6 +1240,16 @@ class ProdEnvMatAMixOp : public OpKernel {
12211240 mesh_tensor.flat <int >().data (), mesh_tensor_size, nloc, nei_mode,
12221241 rcut_r, max_cpy_trial, max_nnei_trial);
12231242
1243+ // max_nbor_size may be changed after _prepare_coord_nlist_gpu
1244+ // So we need to update the uint64_temp tensor if necessary
1245+ if (uint64_temp.NumElements () < int_64 (nloc) * max_nbor_size * 2 ) {
1246+ TensorShape uint64_shape;
1247+ uint64_shape.AddDim (int_64 (nloc) * max_nbor_size * 2 );
1248+ OP_REQUIRES_OK (context, context->allocate_temp (
1249+ DT_UINT64, uint64_shape, &uint64_temp));
1250+ array_longlong = uint64_temp.flat <unsigned long long >().data ();
1251+ }
1252+
12241253 // launch the gpu(nv) compute function
12251254 deepmd::prod_env_mat_a_gpu (em, em_deriv, rij, nlist, coord, type,
12261255 gpu_inlist, array_int, array_longlong,
0 commit comments