Skip to content

Commit 138ffef

Browse files
authored
Fix max nbor size related issues (#3157)
1 parent e5f9117 commit 138ffef

File tree

1 file changed

+29
-0
lines changed

1 file changed

+29
-0
lines changed

source/op/prod_env_mat_multi_device.cc

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -577,6 +577,15 @@ class ProdEnvMatAOp : public OpKernel {
577577
mesh_tensor.flat<int>().data(), mesh_tensor_size, nloc, nei_mode,
578578
rcut_r, max_cpy_trial, max_nnei_trial);
579579

580+
// max_nbor_size may be changed after _prepare_coord_nlist_gpu
581+
// So we need to update the uint64_temp tensor if necessary
582+
if (uint64_temp.NumElements() < int_64(nloc) * max_nbor_size * 2) {
583+
TensorShape uint64_shape;
584+
uint64_shape.AddDim(int_64(nloc) * max_nbor_size * 2);
585+
OP_REQUIRES_OK(context, context->allocate_temp(
586+
DT_UINT64, uint64_shape, &uint64_temp));
587+
array_longlong = uint64_temp.flat<unsigned long long>().data();
588+
}
580589
// launch the gpu(nv) compute function
581590
deepmd::prod_env_mat_a_gpu(em, em_deriv, rij, nlist, coord, type,
582591
gpu_inlist, array_int, array_longlong,
@@ -875,6 +884,16 @@ class ProdEnvMatROp : public OpKernel {
875884
mesh_tensor.flat<int>().data(), mesh_tensor_size, nloc, nei_mode,
876885
rcut, max_cpy_trial, max_nnei_trial);
877886

887+
// max_nbor_size may be changed after _prepare_coord_nlist_gpu
888+
// So we need to update the uint64_temp tensor if necessary
889+
if (uint64_temp.NumElements() < int_64(nloc) * max_nbor_size * 2) {
890+
TensorShape uint64_shape;
891+
uint64_shape.AddDim(int_64(nloc) * max_nbor_size * 2);
892+
OP_REQUIRES_OK(context, context->allocate_temp(
893+
DT_UINT64, uint64_shape, &uint64_temp));
894+
array_longlong = uint64_temp.flat<unsigned long long>().data();
895+
}
896+
878897
// launch the gpu(nv) compute function
879898
deepmd::prod_env_mat_r_gpu(em, em_deriv, rij, nlist, coord, type,
880899
gpu_inlist, array_int, array_longlong,
@@ -1221,6 +1240,16 @@ class ProdEnvMatAMixOp : public OpKernel {
12211240
mesh_tensor.flat<int>().data(), mesh_tensor_size, nloc, nei_mode,
12221241
rcut_r, max_cpy_trial, max_nnei_trial);
12231242

1243+
// max_nbor_size may be changed after _prepare_coord_nlist_gpu
1244+
// So we need to update the uint64_temp tensor if necessary
1245+
if (uint64_temp.NumElements() < int_64(nloc) * max_nbor_size * 2) {
1246+
TensorShape uint64_shape;
1247+
uint64_shape.AddDim(int_64(nloc) * max_nbor_size * 2);
1248+
OP_REQUIRES_OK(context, context->allocate_temp(
1249+
DT_UINT64, uint64_shape, &uint64_temp));
1250+
array_longlong = uint64_temp.flat<unsigned long long>().data();
1251+
}
1252+
12241253
// launch the gpu(nv) compute function
12251254
deepmd::prod_env_mat_a_gpu(em, em_deriv, rij, nlist, coord, type,
12261255
gpu_inlist, array_int, array_longlong,

0 commit comments

Comments
 (0)