
Commit 4ed71d5

kwen2501 authored and pytorchmergebot committed
[2/N][SymmMem] Add MemPool allocator and tests (pytorch#161471)
(Porting most of pytorch#161008) Hooks the SymmetricMemory allocator up to MemPool so that users can create symmetric tensors with regular factories such as `torch.zeros` and `torch.arange`. It also lets our ops have functional variants that create `out` tensors on symmetric memory.

To end users, this PR supports a Python UI as follows:

```
allocator = symm_mem.get_mempool_allocator(device)
mempool = torch.cuda.MemPool(allocator)
with torch.cuda.use_mem_pool(mempool):
    tensor = torch.arange(numel, dtype=dtype, device=device)
```

Added tests for both use cases above.

Pull Request resolved: pytorch#161471
Approved by: https://github.com/ngimel
ghstack dependencies: pytorch#161470
1 parent 8dd5aa9 commit 4ed71d5
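
For orientation, here is a minimal end-to-end sketch assembled from the tests added in this PR (not part of the commit itself). It assumes a distributed process group is already initialized and the NVSHMEM backend is in use; `numel`, `dtype`, and the device setup are illustrative placeholders:

```python
import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

# Assumed setup: dist.init_process_group(...) has already run and the
# NVSHMEM SymmetricMemory backend is available.
group_name = dist.group.WORLD.group_name
symm_mem.enable_symm_mem_for_group(group_name)

device = torch.device("cuda", torch.cuda.current_device())
numel, dtype = 1024, torch.float

# New in this PR: fetch the symmetric-memory MemPool allocator and wrap it.
allocator = symm_mem.get_mempool_allocator(device)
mempool = torch.cuda.MemPool(allocator)

# Any factory op run under the pool allocates from symmetric memory.
with torch.cuda.use_mem_pool(mempool):
    tensor = torch.arange(numel, dtype=dtype, device=device)

# The tensor can then be rendezvous'd and used with symmetric-memory ops,
# e.g. the NVSHMEM broadcast exercised by the new tests.
symm_mem.rendezvous(tensor, group=group_name)
torch.ops.symm_mem.nvshmem_broadcast(tensor, group_name)
```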

File tree

10 files changed, +138 −0 lines changed


BUILD.bazel

Lines changed: 1 addition & 0 deletions
```
@@ -747,6 +747,7 @@ cc_library(
     "torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.cu",
     "torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryOps.cu",
     "torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryUtils.cpp",
+    "torch/csrc/distributed/c10d/symm_mem/cuda_mem_pool.cpp",
     "torch/csrc/distributed/c10d/symm_mem/intra_node_comm.cu",
 ],
 )) + torch_sources,
```

build_variables.bzl

Lines changed: 1 addition & 0 deletions
```
@@ -755,6 +755,7 @@ libtorch_cuda_distributed_extra_sources = [
     "torch/csrc/distributed/c10d/symm_mem/NCCLSymmetricMemory.cu",
     "torch/csrc/distributed/c10d/symm_mem/intra_node_comm.cpp",
     "torch/csrc/distributed/c10d/symm_mem/intra_node_comm.cu",
+    "torch/csrc/distributed/c10d/symm_mem/cuda_mem_pool.cpp",
     "torch/csrc/distributed/rpc/tensorpipe_cuda.cpp",
 ]
 
```

caffe2/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
```
@@ -581,6 +581,7 @@ if(USE_CUDA)
     ${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryOps.cu
     ${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryUtils.cpp
     ${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/NCCLSymmetricMemory.cu
+    ${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/cuda_mem_pool.cpp
     PROPERTIES COMPILE_FLAGS "-DPYTORCH_C10_DRIVER_API_SUPPORTED=1"
 )
 endif()
```

test/distributed/test_nvshmem.py

Lines changed: 52 additions & 0 deletions
```
@@ -65,6 +65,58 @@ def foo():
             out = symm_mem.empty(numel, dtype=dtype, device=self.device)
             symm_mem.rendezvous(out, group=group_name)
 
+    @skipIfRocm
+    def test_mempool_tensor_factory(self) -> None:
+        """
+        Test the effectiveness of MemPool on tensor factory ops.
+        """
+        self._init_device()
+        group_name = dist.group.WORLD.group_name
+        symm_mem.enable_symm_mem_for_group(group_name)
+
+        dtype = torch.float
+        numel = 1024
+        src_rank = 0
+
+        allocator = symm_mem.get_mempool_allocator(self.device)
+        mempool = torch.cuda.MemPool(allocator)
+
+        with torch.cuda.use_mem_pool(mempool):
+            if self.rank == src_rank:
+                tensor = torch.arange(numel, dtype=dtype, device=self.device)
+            else:
+                tensor = torch.zeros(numel, dtype=dtype, device=self.device)
+
+        symm_mem.rendezvous(tensor, group=group_name)
+        torch.ops.symm_mem.nvshmem_broadcast(tensor, group_name)
+        self.assertEqual(tensor, torch.arange(numel, dtype=dtype, device=self.device))
+
+    @skipIfRocm
+    def test_mempool_compute_ops(self) -> None:
+        """
+        Apply MemPool context to a compute op that creates input to collective.
+        """
+        self._init_device()
+        group_name = dist.group.WORLD.group_name
+        symm_mem.enable_symm_mem_for_group(group_name)
+
+        dtype = torch.float
+        dim = 1024
+        w = torch.ones(dim, dim, dtype=dtype, device=self.device)
+        x0 = torch.ones(1, dim, dtype=dtype, device=self.device)
+
+        allocator = symm_mem.get_mempool_allocator(self.device)
+        mempool = torch.cuda.MemPool(allocator)
+
+        with torch.cuda.use_mem_pool(mempool):
+            x = x0 + self.rank
+            y = torch.mm(x, w)
+
+        # y should be a symm tensor
+        torch.ops.symm_mem.nvshmem_broadcast(y, group_name)
+        expected = torch.mm(x0, w)
+        self.assertEqual(y, expected)
+
     @skipIfRocm
     def test_nvshmem_put(self) -> None:
         self._init_device()
```

torch/_C/_distributed_c10d.pyi

Lines changed: 2 additions & 0 deletions
```
@@ -769,6 +769,8 @@ class _SymmetricMemory:
     def set_backend(name: str) -> None: ...
     @staticmethod
     def get_backend(device: torch.device) -> Optional[str]: ...
+    @staticmethod
+    def get_mempool_allocator(device: torch.device) -> Any: ...
     @property
     def rank(self) -> int: ...
     @property
```

torch/csrc/distributed/c10d/init.cpp

Lines changed: 3 additions & 0 deletions
```
@@ -1128,6 +1128,9 @@ This class does not support ``__members__`` property.)");
           &::c10d::symmetric_memory::has_multicast_support)
       .def_static("set_backend", &::c10d::symmetric_memory::set_backend)
       .def_static("get_backend", &::c10d::symmetric_memory::get_backend)
+      .def_static(
+          "get_mempool_allocator",
+          &::c10d::symmetric_memory::get_mempool_allocator)
       .def_property_readonly("rank", &SymmetricMemory::get_rank)
       .def_property_readonly("world_size", &SymmetricMemory::get_world_size)
       .def_property_readonly(
```

torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.cpp

Lines changed: 22 additions & 0 deletions
```
@@ -266,6 +266,28 @@ TORCH_API bool has_multicast_support(
     return allocator->has_multicast_support(device_idx);
   }
 }
+
+static std::unordered_map<c10::DeviceType, std::shared_ptr<c10::Allocator>>
+    _mempool_allocators;
+
+void register_mempool_allocator(
+    c10::DeviceType device_type,
+    std::shared_ptr<c10::Allocator> allocator) {
+  _mempool_allocators[device_type] = std::move(allocator);
+}
+
+// Get allocator for MemPool given device
+std::shared_ptr<c10::Allocator> get_mempool_allocator(c10::Device device) {
+  auto it = _mempool_allocators.find(device.type());
+  if (it == _mempool_allocators.end()) {
+    TORCH_CHECK(
+        false,
+        "SymmetricMemory MemPool did not find backend for device type ",
+        device.type());
+  }
+  return it->second;
+}
+
 } // namespace c10d::symmetric_memory
 
 namespace {
```

torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.hpp

Lines changed: 7 additions & 0 deletions
```
@@ -184,4 +184,11 @@ TORCH_API void set_backend(const std::string& name);
 
 TORCH_API std::optional<std::string> get_backend(c10::Device device);
 
+C10_EXPORT void register_mempool_allocator(
+    c10::DeviceType device_type,
+    std::shared_ptr<c10::Allocator> allocator);
+
+TORCH_API std::shared_ptr<c10::Allocator> get_mempool_allocator(
+    c10::Device device);
+
 } // namespace c10d::symmetric_memory
```
torch/csrc/distributed/c10d/symm_mem/cuda_mem_pool.cpp

Lines changed: 39 additions & 0 deletions

```
@@ -0,0 +1,39 @@
+#include <torch/csrc/cuda/CUDAPluggableAllocator.h>
+#include <torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.hpp>
+
+namespace {
+using namespace c10d::symmetric_memory;
+
+// Alloc functor for MemPool
+void* cuda_symm_alloc(size_t size, int device, void* stream) {
+  static auto allocator = get_allocator(c10::DeviceType::CUDA);
+  TORCH_CHECK(
+      allocator->name() == "NVSHMEM", "Only NVSHMEM backend is supported");
+  // Note: this alloc functor works for the NVSHMEM and NCCL backends only,
+  // because only these backends take `nullopt` for the `group` argument,
+  // which is not given by MemPool's invocation (in fact these two backends
+  // require it to be `nullopt`).
+  return allocator->alloc(size, device, /*group_name=*/std::nullopt);
+}
+
+// Free functor for MemPool
+void cuda_symm_free(void* ptr, size_t size, int device, void* stream) {
+  static auto allocator = get_allocator(c10::DeviceType::CUDA);
+  TORCH_CHECK(
+      allocator->name() == "NVSHMEM", "Only NVSHMEM backend is supported");
+  allocator->free(ptr);
+}
+
+// Register allocator for CUDA MemPool
+struct RegisterCUDAMemPoolAllocator {
+  RegisterCUDAMemPoolAllocator() {
+    std::shared_ptr<c10::cuda::CUDACachingAllocator::CUDAAllocator> allocator =
+        torch::cuda::CUDAPluggableAllocator::createCustomAllocator(
+            cuda_symm_alloc, cuda_symm_free);
+    register_mempool_allocator(c10::DeviceType::CUDA, allocator);
+  }
+};
+
+static RegisterCUDAMemPoolAllocator register_cuda_mempool_allocator_;
+
+} // namespace
```

torch/distributed/_symmetric_memory/__init__.py

Lines changed: 10 additions & 0 deletions
```
@@ -1781,4 +1781,14 @@ def get_backend(device: _device) -> Optional[str]:
     return _SymmetricMemory.get_backend(torch.device(device))
 
 
+def get_mempool_allocator(device: _device):  # type: ignore[no-untyped-def]
+    r"""
+    Get the MemPool allocator for symmetric memory for a given device.
+    Args:
+        device (class:`torch.device` or str): the device for which to get the
+            MemPool allocator.
+    """
+    return _SymmetricMemory.get_mempool_allocator(torch.device(device))
+
+
 __all__ = ["empty", "rendezvous", "is_nvshmem_available", "set_backend", "get_backend"]
```
