diff --git a/source/api_cc/include/common.h b/source/api_cc/include/common.h index 612f699ea4..0c5cde2f3b 100644 --- a/source/api_cc/include/common.h +++ b/source/api_cc/include/common.h @@ -102,6 +102,9 @@ void select_real_atoms_coord(std::vector& dcoord, const int& nall, const bool aparam_nall = false); +void select_real_atoms_sendlist(const deepmd::InputNlist& inlist, + const std::vector& fwd_map); + /** * @brief Apply the given map to a vector. * @param[out] out The output vector. diff --git a/source/api_cc/src/DeepPotPD.cc b/source/api_cc/src/DeepPotPD.cc index d81a63b131..f6d5c76057 100644 --- a/source/api_cc/src/DeepPotPD.cc +++ b/source/api_cc/src/DeepPotPD.cc @@ -393,6 +393,7 @@ void DeepPotPD::compute(ENERGYVTYPE& ener, auto sendlist_tensor = predictor_fl->GetInputHandle("send_list"); int nswap = lmp_list.nswap; + select_real_atoms_sendlist(lmp_list, fwd_map); sendproc_tensor->Reshape({nswap}); sendproc_tensor->CopyFromCpu(lmp_list.sendproc); diff --git a/source/api_cc/src/DeepPotPT.cc b/source/api_cc/src/DeepPotPT.cc index 0f3a72b87f..070c69256d 100644 --- a/source/api_cc/src/DeepPotPT.cc +++ b/source/api_cc/src/DeepPotPT.cc @@ -174,6 +174,7 @@ void DeepPotPT::compute(ENERGYVTYPE& ener, nlist_data.padding(); if (do_message_passing) { int nswap = lmp_list.nswap; + select_real_atoms_sendlist(lmp_list, fwd_map); torch::Tensor sendproc_tensor = torch::from_blob(lmp_list.sendproc, {nswap}, int32_option); torch::Tensor recvproc_tensor = diff --git a/source/api_cc/src/DeepSpinPT.cc b/source/api_cc/src/DeepSpinPT.cc index 8ccf2fd383..c3b41d8b7d 100644 --- a/source/api_cc/src/DeepSpinPT.cc +++ b/source/api_cc/src/DeepSpinPT.cc @@ -182,6 +182,7 @@ void DeepSpinPT::compute(ENERGYVTYPE& ener, nlist_data.padding(); if (do_message_passing) { int nswap = lmp_list.nswap; + select_real_atoms_sendlist(lmp_list, fwd_map); torch::Tensor sendproc_tensor = torch::from_blob(lmp_list.sendproc, {nswap}, int32_option); torch::Tensor recvproc_tensor = diff --git a/source/api_cc/src/common.cc b/source/api_cc/src/common.cc index eace577f89..ff2f43a378 100644 --- a/source/api_cc/src/common.cc +++ b/source/api_cc/src/common.cc @@ -232,6 +232,31 @@ template void deepmd::select_real_atoms_coord( const int& nall, const bool aparam_nall); +void deepmd::select_real_atoms_sendlist(const deepmd::InputNlist& inlist, + const std::vector& fwd_map) { + int nswap = inlist.nswap; + std::vector> sendlist_new; + sendlist_new.resize(nswap); + // select real atoms in sendlist + for (int s = 0; s < nswap; ++s) { + int cnt = 0; + sendlist_new[s].reserve(inlist.sendnum[s]); + for (int k = 0; k < inlist.sendnum[s]; ++k) { + const int old_idx = inlist.sendlist[s][k]; + int mapped = (old_idx >= 0 && old_idx < (int)fwd_map.size()) + ? fwd_map[old_idx] + : -1; + if (mapped >= 0) { + sendlist_new[s].push_back(mapped); + ++cnt; + } + } + std::memcpy(inlist.sendlist[s], sendlist_new[s].data(), cnt * sizeof(int)); + inlist.sendnum[s] = cnt; + inlist.recvnum[s] = cnt; + } +} + void deepmd::NeighborListData::copy_from_nlist(const InputNlist& inlist, const int natoms) { int inum = natoms >= 0 ? natoms : inlist.inum; diff --git a/source/lmp/tests/test_lammps_dpa_pt.py b/source/lmp/tests/test_lammps_dpa_pt.py index 2768332c71..ace418c28e 100644 --- a/source/lmp/tests/test_lammps_dpa_pt.py +++ b/source/lmp/tests/test_lammps_dpa_pt.py @@ -475,6 +475,12 @@ def test_pair_deepmd_type_map(lammps_type_map) -> None: lammps_type_map.run(1) +def test_pair_deepmd_type_map_with_null(lammps_type_map) -> None: + lammps_type_map.pair_style(f"deepmd {pb_file.resolve()}") + lammps_type_map.pair_coeff("* * H NULL") + lammps_type_map.run(0) + + def test_pair_deepmd_real(lammps_real) -> None: lammps_real.pair_style(f"deepmd {pb_file.resolve()}") lammps_real.pair_coeff("* *")