# NOTE(review): in this copy of the patch every run of whitespace (line breaks and indentation
# included) was collapsed to a single space. The line breaks are restored below; the indentation
# inside each line is left as found, so apply with whitespace-insensitive matching
# (e.g. `git apply --ignore-whitespace`) and confirm against the real CMakeLists.txt.
# Summary: adds the three `messages` sources to SRC_FILES, registers the TestMessages and
# TestPassMessages tests, removes the two `neighbors` Python files from the install rule that
# targets ${Python_SITEARCH}/${NAME}, and installs the `messages` Python files into
# ${Python_SITEARCH}/${NAME}/messages.
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 894d792..c98b426 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -21,6 +21,9 @@ set(SRC_FILES src/ani/CpuANISymmetryFunctions.cpp
 src/pytorch/CFConv.cpp
 src/pytorch/CFConvNeighbors.cpp
 src/pytorch/SymmetryFunctions.cpp
+ src/pytorch/messages/messages.cpp
+ src/pytorch/messages/passMessagesCPU.cpp
+ src/pytorch/messages/passMessagesCUDA.cu
 src/pytorch/neighbors/getNeighborPairsCPU.cpp
 src/pytorch/neighbors/getNeighborPairsCUDA.cu
 src/pytorch/neighbors/neighbors.cpp
@@ -65,7 +68,9 @@ add_test(TestEnergyShifter pytest -v ${CMAKE_SOURCE_DIR}/src/pytorch/TestEne
 add_test(TestOptimizedTorchANI pytest -v ${CMAKE_SOURCE_DIR}/src/pytorch/TestOptimizedTorchANI.py)
 add_test(TestSpeciesConverter pytest -v ${CMAKE_SOURCE_DIR}/src/pytorch/TestSpeciesConverter.py)
 add_test(TestSymmetryFunctions pytest -v ${CMAKE_SOURCE_DIR}/src/pytorch/TestSymmetryFunctions.py)
+add_test(TestMessages pytest -v ${CMAKE_SOURCE_DIR}/src/pytorch/messages/TestMessages.py)
 add_test(TestNeighbors pytest -v ${CMAKE_SOURCE_DIR}/src/pytorch/neighbors/TestNeighbors.py)
+add_test(TestPassMessages pytest -v --doctest-modules ${CMAKE_SOURCE_DIR}/src/pytorch/messages/passMessages.py)
 add_test(TestGetNeighborPairs pytest -v --doctest-modules ${CMAKE_SOURCE_DIR}/src/pytorch/neighbors/getNeighborPairs.py)
 
 # Installation
@@ -78,9 +83,10 @@ install(FILES src/pytorch/__init__.py
 src/pytorch/OptimizedTorchANI.py
 src/pytorch/SpeciesConverter.py
 src/pytorch/SymmetryFunctions.py
- src/pytorch/neighbors/__init__.py
- src/pytorch/neighbors/getNeighborPairs.py
 DESTINATION ${Python_SITEARCH}/${NAME})
+install(FILES src/pytorch/messages/__init__.py
+ src/pytorch/messages/passMessages.py
+ DESTINATION ${Python_SITEARCH}/${NAME}/messages)
 install(FILES src/pytorch/neighbors/__init__.py
 src/pytorch/neighbors/getNeighborPairs.py
 DESTINATION ${Python_SITEARCH}/${NAME}/neighbors)
\ No newline at end of file
diff --git a/src/pytorch/messages/TestMessages.py
b/src/pytorch/messages/TestMessages.py
new file mode 100644
index 0000000..872731e
--- /dev/null
+++ b/src/pytorch/messages/TestMessages.py
@@ -0,0 +1,90 @@
+import pytest
+import torch as pt
+from NNPOps.messages import passMessages
+
+
+@pytest.mark.parametrize('device', ['cpu', 'cuda'])
+@pytest.mark.parametrize('dtype', [pt.float32, pt.float64])
+@pytest.mark.parametrize('num_pairs', [1, 2, 3, 4, 5, 10, 100])
+@pytest.mark.parametrize('num_atoms', [1, 2, 3, 4, 5, 10, 100])
+@pytest.mark.parametrize('num_states', [32, 64, 1024])
+def testPassMessageValues(device, dtype, num_pairs, num_atoms, num_states):
+    '''Compare passMessages against an index_add reference on the requested device.'''
+
+    if device == 'cuda' and not pt.cuda.is_available():
+        pytest.skip('No GPU')
+
+    # Random pair indices; the pairs picked by a random mask are flagged as ignored (-1)
+    neighbors = pt.randint(0, num_atoms, (2, num_pairs), dtype=pt.int32, device=device)
+    ignored = pt.rand(num_pairs) > 0.5
+    neighbors[:, ignored] = -1
+
+    # Random pair messages and atom states
+    messages = pt.randn((num_pairs, num_states), dtype=dtype, device=device)
+    states = pt.randn((num_atoms, num_states), dtype=dtype, device=device)
+
+    # Reference result: add the messages of the valid pairs to both atoms of each pair
+    valid = (neighbors[0] > -1) & (neighbors[1] > -1)
+    first, second = neighbors[:, valid].to(pt.long)
+    valid_messages = messages[valid, :]
+    expected = states.index_add(0, first, valid_messages)
+    expected = expected.index_add(0, second, valid_messages)
+
+    # Result of the operation under test
+    new_states = passMessages(neighbors, messages, states)
+
+    # The result has to keep the device and the data type of the inputs
+    assert new_states.device == neighbors.device
+    assert new_states.dtype == dtype
+
+    # Select the tolerances for the precision and the number of messages added per atom
+    if dtype != pt.float32:
+        atol, rtol = 1e-12, 1e-8
+    elif num_pairs > 10 and num_atoms < 10:
+        atol, rtol = 1e-5, 1e-3
+    else:
+        atol, rtol = 1e-6, 1e-4
+    assert pt.allclose(expected, new_states, atol=atol, rtol=rtol)
+
+@pytest.mark.parametrize('dtype', [pt.float32, pt.float64])
+@pytest.mark.parametrize('num_pairs', [1, 2, 3, 4, 5, 10, 100])
+@pytest.mark.parametrize('num_atoms', [1, 2, 3, 4, 5, 10, 100])
+@pytest.mark.parametrize('num_states', [32, 64, 1024])
+def testPassMessagesGrads(dtype, num_pairs, num_atoms, num_states):
+    '''Compare the gradients computed on CUDA with the ones computed on CPU.'''
+
+    if not pt.cuda.is_available():
+        pytest.skip('No GPU')
+
+    # Random pair indices; the pairs picked by a random mask are flagged as ignored (-1)
+    neighbors = pt.randint(0, num_atoms, (2, num_pairs), dtype=pt.int32)
+    ignored = pt.rand(num_pairs) > 0.5
+    neighbors[:, ignored] = -1
+
+    # Random pair messages and atom states
+    messages = pt.randn((num_pairs, num_states), dtype=dtype)
+    states = pt.randn((num_atoms, num_states), dtype=dtype)
+
+    def gradients(device):
+        # Back-propagate the norm of the result from fresh leaf tensors placed on the device
+        dev_neighbors = neighbors.detach().to(device)
+        dev_messages = messages.detach().to(device)
+        dev_states = states.detach().to(device)
+        dev_messages.requires_grad_()
+        dev_states.requires_grad_()
+        passMessages(dev_neighbors, dev_messages, dev_states).norm().backward()
+        return dev_neighbors.device, dev_messages.grad, dev_states.grad
+
+    _, messages_grad_cpu, states_grad_cpu = gradients('cpu')
+    cuda_device, messages_grad_cuda, states_grad_cuda = gradients('cuda')
+
+    # The gradients have to keep the data type and the device of the inputs
+    assert messages_grad_cuda.dtype == dtype
+    assert states_grad_cuda.dtype == dtype
+    assert messages_grad_cuda.device == cuda_device
+    assert states_grad_cuda.device == cuda_device
+
+    # Both devices have to agree within the tolerances of the precision
+    atol, rtol = (1e-6, 1e-4) if dtype == pt.float32 else (1e-12, 1e-8)
+    assert pt.allclose(messages_grad_cpu, messages_grad_cuda.cpu(), atol=atol, rtol=rtol)
+    assert pt.allclose(states_grad_cpu, states_grad_cuda.cpu(), atol=atol, rtol=rtol)
\ No newline at end of file
diff --git a/src/pytorch/messages/__init__.py b/src/pytorch/messages/__init__.py
new file mode 100644
index 0000000..978fc67
--- /dev/null
+++ b/src/pytorch/messages/__init__.py
@@ -0,0
+1,5 @@
+'''
+Message passing operations
+'''
+
+from NNPOps.messages.passMessages import passMessages
\ No newline at end of file
diff --git a/src/pytorch/messages/messages.cpp b/src/pytorch/messages/messages.cpp
new file mode 100644
index 0000000..c30adc6
--- /dev/null
+++ b/src/pytorch/messages/messages.cpp
@@ -0,0 +1,6 @@
+#include <torch/extension.h>
+
+// Operator schema only; the implementations are registered in passMessagesCPU.cpp and passMessagesCUDA.cu
+TORCH_LIBRARY(messages, m) {
+    m.def("passMessages(Tensor neighbors, Tensor messages, Tensor states) -> (Tensor states)");
+}
\ No newline at end of file
diff --git a/src/pytorch/messages/passMessages.py b/src/pytorch/messages/passMessages.py
new file mode 100644
index 0000000..c13ca8a
--- /dev/null
+++ b/src/pytorch/messages/passMessages.py
@@ -0,0 +1,60 @@
+from torch import ops, Tensor
+
+
+def passMessages(neighbors: Tensor, messages: Tensor, states: Tensor) -> Tensor:
+    '''
+    Pass messages between the neighbor atoms.
+
+    Given a set of `num_atoms` atoms (each atom has a state with `num_features`
+    features) and a set of `num_pairs` neighbor atom pairs (each pair has a
+    message with `num_features` features), the messages of the pairs are added
+    to the corresponding atom states.
+
+    Parameters
+    ----------
+    neighbors: `torch.Tensor`
+        Atom pair indices. The shape of the tensor is `(2, num_pairs)`.
+        The indices can be `[0, num_atoms)` or `-1` (ignored pairs).
+        See the documentation of `NNPOps.neighbors.getNeighborPairs` for
+        details.
+    messages: `torch.Tensor`
+        Atom pair messages. The shape of the tensor is `(num_pairs, num_features)`.
+        For efficiency, `num_features` has to be a multiple of 32 and <= 1024.
+    states: `torch.Tensor`
+        Atom states. The shape of the tensor is `(num_atoms, num_features)`.
+
+    Returns
+    -------
+    new_states: `torch.Tensor`
+        Updated atom states. The shape of the tensor is `(num_atoms, num_features)`.
+
+    Note
+    ----
+    The operation is compatible with CUDA Graphs, i.e. the shapes of the output
+    tensors are independent of the values of input tensors.
+
+    Examples
+    --------
+    >>> import torch as pt
+    >>> from NNPOps.messages import passMessages
+
+    >>> num_atoms = 4
+    >>> num_neighbors = 3
+    >>> num_features = 32
+
+    >>> neighbors = pt.tensor([[0, -1, 1], [0, -1, 3]], dtype=pt.int32)
+
+    >>> messages = pt.ones((num_neighbors, num_features)); messages[1] = 5
+    >>> messages[:, 0]
+    tensor([1., 5., 1.])
+
+    >>> states = pt.zeros((num_atoms, num_features)); states[1] = 3
+    >>> states[:, 0]
+    tensor([0., 3., 0., 0.])
+
+    >>> new_states = passMessages(neighbors, messages, states)
+    >>> new_states[:, 0]
+    tensor([2., 4., 0., 1.])
+    '''
+
+    return ops.messages.passMessages(neighbors, messages, states)
\ No newline at end of file
diff --git a/src/pytorch/messages/passMessagesCPU.cpp b/src/pytorch/messages/passMessagesCPU.cpp
new file mode 100644
index 0000000..ae3d4cb
--- /dev/null
+++ b/src/pytorch/messages/passMessagesCPU.cpp
@@ -0,0 +1,50 @@
+#include <torch/extension.h>
+
+using torch::kInt32;
+using torch::logical_and;
+using torch::Tensor;
+
+// CPU implementation of passMessages. Returns a copy of "states" in which the message of each
+// pair whose two indices are both > -1 has been added to the states of both atoms of the pair.
+// Invalid inputs are rejected by TORCH_CHECK.
+static Tensor forward(const Tensor& neighbors, const Tensor& messages, const Tensor& states) {
+
+    TORCH_CHECK(neighbors.dim() == 2, "Expected \"neighbors\" to have two dimensions");
+    TORCH_CHECK(neighbors.size(0) == 2, "Expected the 1st dimension size of \"neighbors\" to be 2");
+    TORCH_CHECK(neighbors.scalar_type() == kInt32, "Expected \"neighbors\" to have data type of int32");
+    TORCH_CHECK(neighbors.is_contiguous(), "Expected \"neighbors\" to be contiguous");
+
+    TORCH_CHECK(messages.dim() == 2, "Expected \"messages\" to have two dimensions");
+    TORCH_CHECK(messages.size(0) == neighbors.size(1), "Expected the 1st dimension size of \"messages\" to be the same as the 2nd dimension size of \"neighbors\"");
+    TORCH_CHECK(messages.size(1) % 32 == 0, "Expected the 2nd dimension size of \"messages\" to be a multiple of 32");
+    TORCH_CHECK(messages.size(1) <= 1024, "Expected the 2nd dimension size of \"messages\" to be less than or equal to 1024");
+    TORCH_CHECK(messages.is_contiguous(), "Expected \"messages\" to be contiguous");
+
+    TORCH_CHECK(states.dim() == 2, "Expected \"states\" to have two dimensions");
+    TORCH_CHECK(states.size(1) == messages.size(1), "Expected the 2nd dimension size of \"messages\" and \"states\" to be the same");
+    TORCH_CHECK(states.scalar_type() == messages.scalar_type(), "Expected the data type of \"messages\" and \"states\" to be the same");
+    TORCH_CHECK(states.is_contiguous(), "Expected \"states\" to be contiguous");
+
+    const Tensor rows = neighbors[0];
+    const Tensor columns = neighbors[1];
+
+    const int64_t num_features = messages.size(1);
+
+    // A pair is ignored as soon as one of its two indices is negative
+    const Tensor mask = logical_and(rows > -1, columns > -1);
+    const Tensor masked_rows = rows.masked_select(mask).to(torch::kLong);
+    const Tensor masked_columns = columns.masked_select(mask).to(torch::kLong);
+    // Use the explicit row count (not -1) so that the reshape is also valid when num_features is 0
+    const Tensor masked_messages = messages.masked_select(mask.unsqueeze(1)).reshape({masked_rows.size(0), num_features});
+
+    // Each message is added to both atoms of its pair
+    Tensor new_states = states.clone();
+    new_states.index_add_(0, masked_rows, masked_messages);
+    new_states.index_add_(0, masked_columns, masked_messages);
+
+    return new_states;
+}
+
+TORCH_LIBRARY_IMPL(messages, CPU, m) {
+    m.impl("passMessages", &forward);
+}
\ No newline at end of file
diff --git a/src/pytorch/messages/passMessagesCUDA.cu b/src/pytorch/messages/passMessagesCUDA.cu
new file mode 100644
index 0000000..7553767
--- /dev/null
+++ b/src/pytorch/messages/passMessagesCUDA.cu
@@ -0,0 +1,146 @@
+#include <c10/cuda/CUDAGuard.h>
+#include <c10/cuda/CUDAStream.h>
+#include <torch/extension.h>
+
+#include "common/accessor.cuh"
+#include "common/atomicAdd.cuh"
+
+using c10::cuda::CUDAStreamGuard;
+using c10::cuda::getCurrentCUDAStream;
+using torch::autograd::AutogradContext;
+using torch::autograd::Function;
+using torch::autograd::tensor_list;
+using torch::kInt32;
+using torch::Tensor;
+using torch::TensorOptions;
+
+// NOTE(review): Accessor and get_accessor are declared in common/accessor.cuh, which is not part
+// of this patch; the <scalar type, rank> template arguments used below are assumed, confirm them.
+
+// Forward kernel: block (x, y) handles pair x in direction y, thread t handles feature t.
+// Adds the message of the pair to the state of the atom selected by the direction.
+template <typename scalar_t> __global__ void kernel_forward(
+    const Accessor<int32_t, 2> neighbors,
+    const Accessor<scalar_t, 2> messages,
+    Accessor<scalar_t, 2> new_states
+) {
+    const int32_t i_neig = blockIdx.x;
+    // Ignore the pair as soon as one of its indices is negative (same rule as the CPU implementation)
+    if (neighbors[0][i_neig] < 0 || neighbors[1][i_neig] < 0) return;
+
+    const int32_t i_dir = blockIdx.y;
+    const int32_t i_atom = neighbors[i_dir][i_neig];
+    const int32_t i_feat = threadIdx.x;
+    atomicAdd(&new_states[i_atom][i_feat], messages[i_neig][i_feat]);
+}
+
+// Backward kernel with the same launch layout: accumulates into the gradient of the message
+// of a pair the gradient of the state of the atom selected by the direction.
+template <typename scalar_t> __global__ void kernel_backward(
+    const Accessor<int32_t, 2> neighbors,
+    const Accessor<scalar_t, 2> grad_new_state,
+    Accessor<scalar_t, 2> grad_messages
+) {
+    const int32_t i_neig = blockIdx.x;
+    // Ignored pairs received no message in the forward pass, so their gradient stays zero
+    if (neighbors[0][i_neig] < 0 || neighbors[1][i_neig] < 0) return;
+
+    const int32_t i_dir = blockIdx.y;
+    const int32_t i_atom = neighbors[i_dir][i_neig];
+    const int32_t i_feat = threadIdx.x;
+    atomicAdd(&grad_messages[i_neig][i_feat], grad_new_state[i_atom][i_feat]);
+}
+
+// Autograd function implementing passMessages for CUDA tensors
+class Autograd : public Function<Autograd> {
+public:
+    // Validate the inputs and return {new_states}: a copy of "states" with the pair messages added
+    static tensor_list forward(AutogradContext* ctx,
+                               const Tensor& neighbors,
+                               const Tensor& messages,
+                               const Tensor& states) {
+
+        TORCH_CHECK(neighbors.dim() == 2, "Expected \"neighbors\" to have two dimensions");
+        TORCH_CHECK(neighbors.size(0) == 2, "Expected the 1st dimension size of \"neighbors\" to be 2");
+        TORCH_CHECK(neighbors.scalar_type() == kInt32, "Expected \"neighbors\" to have data type of int32");
+        TORCH_CHECK(neighbors.is_contiguous(), "Expected \"neighbors\" to be contiguous");
+        TORCH_CHECK(neighbors.is_cuda(), "Expected \"neighbors\" to be on a CUDA device");
+
+        TORCH_CHECK(messages.dim() == 2, "Expected \"messages\" to have two dimensions");
+        TORCH_CHECK(messages.size(0) == neighbors.size(1), "Expected the 1st dimension size of \"messages\" to be the same as the 2nd dimension size of \"neighbors\"");
+        TORCH_CHECK(messages.size(1) % 32 == 0, "Expected the 2nd dimension size of \"messages\" to be a multiple of 32");
+        TORCH_CHECK(messages.size(1) <= 1024, "Expected the 2nd dimension size of \"messages\" to be less than or equal to 1024");
+        TORCH_CHECK(messages.is_contiguous(), "Expected \"messages\" to be contiguous");
+        TORCH_CHECK(messages.device() == neighbors.device(), "Expected \"messages\" and \"neighbors\" to be on the same device");
+
+        TORCH_CHECK(states.dim() == 2, "Expected \"states\" to have two dimensions");
+        TORCH_CHECK(states.size(1) == messages.size(1), "Expected the 2nd dimension size of \"messages\" and \"states\" to be the same");
+        TORCH_CHECK(states.scalar_type() == messages.scalar_type(), "Expected the data type of \"messages\" and \"states\" to be the same");
+        TORCH_CHECK(states.is_contiguous(), "Expected \"states\" to be contiguous");
+        TORCH_CHECK(states.device() == neighbors.device(), "Expected \"states\" and \"neighbors\" to be on the same device");
+
+        const int num_neighbors = neighbors.size(1);
+        const int num_features = messages.size(1);
+
+        const dim3 blocks(num_neighbors, 2);
+        const dim3 threads(num_features);
+        const auto stream = getCurrentCUDAStream(neighbors.get_device());
+
+        Tensor new_states = states.clone();
+
+        // A kernel cannot be launched with an empty grid or an empty block
+        if (num_neighbors > 0 && num_features > 0) {
+            AT_DISPATCH_FLOATING_TYPES(messages.scalar_type(), "passMessages::forward", [&]() {
+                const CUDAStreamGuard guard(stream);
+                kernel_forward<<<blocks, threads, 0, stream>>>(
+                    get_accessor<int32_t, 2>(neighbors),
+                    get_accessor<scalar_t, 2>(messages),
+                    get_accessor<scalar_t, 2>(new_states));
+            });
+        }
+
+        // Only the pair indices are needed to compute the gradients
+        ctx->save_for_backward({neighbors});
+
+        return {new_states};
+    }
+
+    // Return the gradients with respect to (neighbors, messages, states); "neighbors" has none
+    static tensor_list backward(AutogradContext* ctx, tensor_list grad_inputs) {
+
+        const Tensor neighbors = ctx->get_saved_variables()[0];
+        const Tensor grad_new_state = grad_inputs[0];
+
+        const int num_neighbors = neighbors.size(1);
+        const int num_features = grad_new_state.size(1);
+
+        const dim3 blocks(num_neighbors, 2);
+        const dim3 threads(num_features);
+        const auto stream = getCurrentCUDAStream(neighbors.get_device());
+
+        Tensor grad_messages = torch::zeros({num_neighbors, num_features}, grad_new_state.options());
+
+        // A kernel cannot be launched with an empty grid or an empty block
+        if (num_neighbors > 0 && num_features > 0) {
+            AT_DISPATCH_FLOATING_TYPES(grad_new_state.scalar_type(), "passMessages::backward", [&]() {
+                const CUDAStreamGuard guard(stream);
+                kernel_backward<<<blocks, threads, 0, stream>>>(
+                    get_accessor<int32_t, 2>(neighbors),
+                    get_accessor<scalar_t, 2>(grad_new_state),
+                    get_accessor<scalar_t, 2>(grad_messages));
+            });
+        }
+
+        return {Tensor(), // grad_neighbors
+                grad_messages,
+                grad_new_state.clone()}; // grad_state
+    }
+};
+
+TORCH_LIBRARY_IMPL(messages, AutogradCUDA, m) {
+    m.impl("passMessages", [](const Tensor& neighbors,
+                              const Tensor& messages,
+                              const Tensor& states) {
+        return Autograd::apply(neighbors, messages, states)[0];
+    });
+}
\ No newline at end of file