Skip to content
Open
90 changes: 80 additions & 10 deletions source/api_cc/src/DeepPotPT.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,13 @@
#include "device.h"
#include "errors.h"

#ifdef USE_MPI
#include <mpi.h>
#ifdef OMPI_MPI_H
#include <mpi-ext.h>
#endif
#endif

using namespace deepmd;

void DeepPotPT::translate_error(std::function<void()> f) {
Expand Down Expand Up @@ -174,16 +181,79 @@ void DeepPotPT::compute(ENERGYVTYPE& ener,
nlist_data.padding();
if (do_message_passing) {
int nswap = lmp_list.nswap;

std::vector<int> sendnum_new(nswap, 0);
std::vector<int> sendlist_new;
sendlist_new.reserve(
std::accumulate(lmp_list.sendnum, lmp_list.sendnum + nswap, 0)
);
for (int s = 0; s < nswap; ++s) {
int cnt = 0;
for (int k = 0; k < lmp_list.sendnum[s]; ++k) {
const int old_idx = lmp_list.sendlist[s][k];
int mapped = (old_idx >= 0 && old_idx < (int)fwd_map.size())
? fwd_map[old_idx]
: -1;
if (mapped >= 0) {
sendlist_new.push_back(mapped);
++cnt;
}
}
sendnum_new[s] = cnt;
}

// Determine how many entries this rank receives in each swap after the send
// lists were filtered: every rank tells its send partner the filtered count
// and learns the matching count from its receive partner.
//
// The MPI path is guarded by USE_MPI, the same macro that guards the
// <mpi.h> include at the top of this file. Guarding it with a different
// macro would either compile the exchange out (leaving only the fallback
// below) or reference MPI symbols without their header.
std::vector<int> recvnum_new(nswap, 0);
#ifdef USE_MPI
if (lmp_list.world) {
  // lmp_list.world is read here as a pointer to an MPI_Comm.
  MPI_Comm comm = *static_cast<MPI_Comm*>(lmp_list.world);
  // Base tag for the count exchange; the swap index is added so that the
  // messages of different swaps cannot be matched against each other.
  const int TAG_BASE = 0x7a31;
  for (int s = 0; s < nswap; ++s) {
    const int send_to = lmp_list.sendproc[s];
    const int recv_from = lmp_list.recvproc[s];
    int send_cnt = sendnum_new[s];
    int recv_cnt = 0;
    // Combined send/receive so the paired exchange cannot deadlock.
    MPI_Sendrecv(&send_cnt, 1, MPI_INT, send_to, TAG_BASE + s,
                 &recv_cnt, 1, MPI_INT, recv_from, TAG_BASE + s,
                 comm, MPI_STATUS_IGNORE);
    recvnum_new[s] = recv_cnt;
  }
} else
#endif
{
  // No communicator available (or built without MPI support): use the
  // send counts as the receive counts.
  for (int s = 0; s < nswap; ++s) {
    recvnum_new[s] = sendnum_new[s];
  }
}

// Start offset of each swap's received entries: the exclusive running sum of
// recvnum_new, i.e. firstrecv_new[s] = recvnum_new[0] + ... + recvnum_new[s-1].
// NOTE(review): offsets begin at 0 and lmp_list.firstrecv is not consulted
// here -- confirm the consumer expects offsets relative to 0.
std::vector<int> firstrecv_new(nswap, 0);
for (int s = 1; s < nswap; ++s) {
  firstrecv_new[s] = firstrecv_new[s - 1] + recvnum_new[s - 1];
}

torch::Tensor sendproc_tensor =
torch::from_blob(lmp_list.sendproc, {nswap}, int32_option);
torch::Tensor recvproc_tensor =
torch::from_blob(lmp_list.recvproc, {nswap}, int32_option);
torch::Tensor firstrecv_tensor =
torch::from_blob(lmp_list.firstrecv, {nswap}, int32_option);
torch::Tensor recvnum_tensor =
torch::from_blob(lmp_list.recvnum, {nswap}, int32_option);
torch::Tensor sendnum_tensor =
torch::from_blob(lmp_list.sendnum, {nswap}, int32_option);

torch::Tensor firstrecv_tensor =
torch::from_blob(firstrecv_new.data(), {nswap}, int32_option).clone();
torch::Tensor recvnum_tensor =
torch::from_blob(recvnum_new.data(), {nswap}, int32_option).clone();
torch::Tensor sendnum_tensor =
torch::from_blob(sendnum_new.data(), {nswap}, int32_option).clone();

torch::Tensor sendlist_tensor =
torch::from_blob(sendlist_new.data(),
{ static_cast<long>(sendlist_new.size()) },
int32_option).clone();


// torch::Tensor firstrecv_tensor =
// torch::from_blob(lmp_list.firstrecv, {nswap}, int32_option);
// torch::Tensor recvnum_tensor =
// torch::from_blob(lmp_list.recvnum, {nswap}, int32_option);
// torch::Tensor sendnum_tensor =
// torch::from_blob(lmp_list.sendnum, {nswap}, int32_option);
torch::Tensor communicator_tensor;
if (lmp_list.world == 0) {
communicator_tensor = torch::empty({1}, torch::kInt64);
Expand All @@ -193,10 +263,10 @@ void DeepPotPT::compute(ENERGYVTYPE& ener,
}

torch::Tensor nswap_tensor = torch::tensor(nswap, int32_option);
int total_send =
std::accumulate(lmp_list.sendnum, lmp_list.sendnum + nswap, 0);
torch::Tensor sendlist_tensor =
torch::from_blob(lmp_list.sendlist, {total_send}, int32_option);
// int total_send =
// std::accumulate(lmp_list.sendnum, lmp_list.sendnum + nswap, 0);
// torch::Tensor sendlist_tensor =
// torch::from_blob(lmp_list.sendlist, {total_send}, int32_option);
comm_dict.insert_or_assign("send_list", sendlist_tensor);
comm_dict.insert_or_assign("send_proc", sendproc_tensor);
comm_dict.insert_or_assign("recv_proc", recvproc_tensor);
Expand Down
Loading