Skip to content

Commit 22d46b5

Browse files
[CUDA] revert PR 130472 (pytorch#163379)
[CUDA] revert PR 130472 (pytorch#162950) This change may also resolve pytorch#161789, though verification is still needed. PR pytorch#130472 introduced a problem of freeing the same address without cleaning up its metadata. Per the discussion below, that PR has been reverted. Pull Request resolved: pytorch#162950 Approved by: https://github.com/ngimel, https://github.com/eqy, https://github.com/syed-ahmed (cherry picked from commit 4a160da) Co-authored-by: thenumberouscode <[email protected]>
1 parent d1b63e2 commit 22d46b5

File tree

4 files changed

+17
-105
lines changed

4 files changed

+17
-105
lines changed

aten/src/ATen/test/cuda_allocator_test.cpp

Lines changed: 0 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -5,51 +5,6 @@
55

66
#include <ATen/test/allocator_clone_test.h>
77

8-
#include <torch/csrc/cuda/CUDAPluggableAllocator.h>
9-
108
TEST(AllocatorTestCUDA, test_clone) {
119
test_allocator_clone(c10::cuda::CUDACachingAllocator::get());
1210
}
13-
14-
static int called_dummy_free_0 = 0;
15-
static int called_dummy_free_1 = 0;
16-
17-
void* dummy_alloc_0(size_t size, int device, void* stream) {return nullptr;}
18-
void dummy_free_0(void* data, size_t size, int device, void* stream) {
19-
called_dummy_free_0++;
20-
}
21-
void dummy_free_1(void* data, size_t size, int device, void* stream) {
22-
called_dummy_free_1++;
23-
}
24-
25-
// Tests that data_ptrs have their respective deleters
26-
// when mixing allocators
27-
TEST(AllocatorTestCUDA, test_pluggable_allocator_deleters) {
28-
// Create a tensor with dummy_allocator_0, where dummy_free_0 is the deleter
29-
auto dummy_allocator_0 = torch::cuda::CUDAPluggableAllocator::createCustomAllocator(dummy_alloc_0, dummy_free_0);
30-
c10::cuda::CUDACachingAllocator::allocator.store(dummy_allocator_0.get());
31-
at::Tensor a = at::empty({0}, at::TensorOptions().device(at::kCUDA));
32-
33-
// Create a tensor with dummy_allocator_1, where dummy_free_1 is the deleter
34-
auto dummy_allocator_1 = torch::cuda::CUDAPluggableAllocator::createCustomAllocator(dummy_alloc_0, dummy_free_1);
35-
c10::cuda::CUDACachingAllocator::allocator.store(dummy_allocator_1.get());
36-
at::Tensor b = at::empty({0}, at::TensorOptions().device(at::kCUDA));
37-
38-
// Manually use a's deleter
39-
auto* ctx = a.storage().data_ptr().get_context();
40-
a.storage().data_ptr().get_deleter()(ctx);
41-
a.storage().mutable_data_ptr().release_context();
42-
43-
// a's deleter is dummy_free_0
44-
// dummy_free_0 should be called above, so called_dummy_free_0 should be 1
45-
ASSERT_TRUE(called_dummy_free_0 == 1);
46-
47-
// Manually use b's deleter
48-
ctx = b.storage().data_ptr().get_context();
49-
b.storage().data_ptr().get_deleter()(ctx);
50-
b.storage().mutable_data_ptr().release_context();
51-
52-
// b's deleter is dummy_free_1
53-
// dummy_free_1 should be called above, so called_dummy_free_1 should be 1
54-
ASSERT_TRUE(called_dummy_free_1 == 1);
55-
}

torch/csrc/cuda/CUDAPluggableAllocator.cpp

Lines changed: 7 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -7,23 +7,6 @@
77

88
namespace torch::cuda::CUDAPluggableAllocator {
99

10-
CUDAPluggableAllocatorDeleterContext::CUDAPluggableAllocatorDeleterContext(
11-
std::function<FreeFuncType> free_fn,
12-
void* data,
13-
size_t size,
14-
int device,
15-
cudaStream_t stream)
16-
: free_fn_(std::move(free_fn)),
17-
data_(data),
18-
size_(size),
19-
device_(device),
20-
stream_(stream) {}
21-
22-
void CUDAPluggableAllocatorDeleterContext::free() {
23-
free_fn_(data_, size_, device_, stream_);
24-
delete this;
25-
}
26-
2710
int device_count = 0;
2811

2912
void custom_raw_deleter(void* ptr);
@@ -41,8 +24,8 @@ _AllocationMetadata::_AllocationMetadata(
4124
// This avoids having to link against libtorch for C++ based custom allocators
4225
// And also use this from python
4326
CUDAPluggableAllocator::CUDAPluggableAllocator(
44-
std::function<MallocFuncType> alloc_fn,
45-
std::function<FreeFuncType> free_fn)
27+
std::function<void*(size_t, int, cudaStream_t)> alloc_fn,
28+
std::function<void(void*, size_t, int, cudaStream_t)> free_fn)
4629
: alloc_fn_(std::move(alloc_fn)), free_fn_(std::move(free_fn)) {}
4730

4831
CUDAPluggableAllocator::CUDAPluggableAllocator(CUDAPluggableAllocator& other)
@@ -114,10 +97,8 @@ c10::DataPtr CUDAPluggableAllocator::allocate(size_t size) {
11497
C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
11598
cudaStream_t stream = c10::cuda::getCurrentCUDAStream(device);
11699
void* r = this->malloc(size, device, stream);
117-
auto* ctx = new CUDAPluggableAllocatorDeleterContext(
118-
free_fn_, r, size, device, stream);
119100
c10::DataPtr data_ptr = {
120-
r, ctx, raw_deleter(), c10::Device(c10::DeviceType::CUDA, device)};
101+
r, r, raw_deleter(), c10::Device(c10::DeviceType::CUDA, device)};
121102
return data_ptr;
122103
}
123104

@@ -382,8 +363,8 @@ getCurrentAllocator() {
382363
// TODO: add more functions in the argument
383364
std::shared_ptr<c10::cuda::CUDACachingAllocator::CUDAAllocator>
384365
createCustomAllocator(
385-
std::function<MallocFuncType> alloc_fn,
386-
std::function<FreeFuncType> free_fn) {
366+
std::function<void*(size_t, int, cudaStream_t)> alloc_fn,
367+
std::function<void(void*, size_t, int, cudaStream_t)> free_fn) {
387368
std::shared_ptr<CUDAPluggableAllocator> allocator(
388369
new CUDAPluggableAllocator(std::move(alloc_fn), std::move(free_fn)));
389370
allocator->init(device_count);
@@ -400,8 +381,8 @@ void changeCurrentAllocator(
400381
current_custom_allocator = allocator;
401382
}
402383

403-
void custom_raw_deleter(void* ctx) {
404-
reinterpret_cast<CUDAPluggableAllocatorDeleterContext*>(ctx)->free();
384+
void custom_raw_deleter(void* ptr) {
385+
current_custom_allocator->raw_delete(ptr);
405386
}
406387

407388
} // namespace torch::cuda::CUDAPluggableAllocator

torch/csrc/cuda/CUDAPluggableAllocator.h

Lines changed: 6 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -11,32 +11,6 @@
1111

1212
namespace torch::cuda::CUDAPluggableAllocator {
1313

14-
using MallocFuncType = void*(size_t, int, cudaStream_t);
15-
using FreeFuncType = void(void*, size_t, int, cudaStream_t);
16-
17-
// A CUDAPluggableAllocatorDeleterContext object is used as the `ctx`
18-
// argument for DataPtr. We need context because a user can use
19-
// multiple allocators in the same PyTorch program, and
20-
// the allocators can have different free functions, such as:
21-
// free, cudaFree, cudaFreeAsync, ncclMemFree etc.
22-
struct TORCH_CUDA_CPP_API CUDAPluggableAllocatorDeleterContext {
23-
explicit CUDAPluggableAllocatorDeleterContext(
24-
std::function<FreeFuncType> free_fn,
25-
void* data,
26-
size_t size,
27-
int device,
28-
cudaStream_t stream);
29-
30-
void free();
31-
32-
private:
33-
std::function<FreeFuncType> free_fn_;
34-
void* data_;
35-
size_t size_;
36-
int device_;
37-
cudaStream_t stream_{};
38-
};
39-
4014
#if defined(USE_ROCM)
4115
using streamType = c10::hip::HIPStream;
4216
#else
@@ -49,8 +23,8 @@ getCurrentAllocator();
4923
TORCH_CUDA_CPP_API std::shared_ptr<
5024
c10::cuda::CUDACachingAllocator::CUDAAllocator>
5125
createCustomAllocator(
52-
std::function<MallocFuncType> alloc_fn,
53-
std::function<FreeFuncType> free_fn);
26+
std::function<void*(size_t, int, cudaStream_t)> alloc_fn,
27+
std::function<void(void*, size_t, int, cudaStream_t)> free_fn);
5428
TORCH_CUDA_CPP_API void changeCurrentAllocator(
5529
const std::shared_ptr<c10::cuda::CUDACachingAllocator::CUDAAllocator>&
5630
allocator);
@@ -69,8 +43,8 @@ struct _AllocationMetadata {
6943
struct TORCH_CUDA_CPP_API CUDAPluggableAllocator
7044
: public c10::cuda::CUDACachingAllocator::CUDAAllocator {
7145
CUDAPluggableAllocator(
72-
std::function<MallocFuncType> alloc_fn,
73-
std::function<FreeFuncType> free_fn);
46+
std::function<void*(size_t, int, cudaStream_t)> alloc_fn,
47+
std::function<void(void*, size_t, int, cudaStream_t)> free_fn);
7448

7549
CUDAPluggableAllocator(CUDAPluggableAllocator& other);
7650
CUDAPluggableAllocator(CUDAPluggableAllocator&& other) = delete;
@@ -173,8 +147,8 @@ struct TORCH_CUDA_CPP_API CUDAPluggableAllocator
173147
void copy_data(void* dest, const void* src, std::size_t count) const final;
174148

175149
protected:
176-
std::function<MallocFuncType> alloc_fn_;
177-
std::function<FreeFuncType> free_fn_;
150+
std::function<void*(size_t, int, cudaStream_t)> alloc_fn_;
151+
std::function<void(void*, size_t, int, cudaStream_t)> free_fn_;
178152
std::function<void(int)> init_fn_;
179153
std::function<void()> reset_fn_;
180154
std::function<void(double, int)> memory_fraction_fn_;

torch/csrc/cuda/Module.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1274,14 +1274,16 @@ static void registerCudaPluggableAllocator(PyObject* module) {
12741274
self.set_release_pool(func);
12751275
});
12761276
m.def("_cuda_customAllocator", [](uint64_t malloc_ptr, uint64_t free_ptr) {
1277-
using namespace torch::cuda::CUDAPluggableAllocator;
1277+
using MallocFuncType = void*(size_t, int, cudaStream_t);
1278+
using FreeFuncType = void(void*, size_t, int, cudaStream_t);
12781279
std::function<MallocFuncType> malloc_fn =
12791280
// NOLINTNEXTLINE(performance-no-int-to-ptr)
12801281
reinterpret_cast<MallocFuncType*>(malloc_ptr);
12811282
std::function<FreeFuncType> free_fn =
12821283
// NOLINTNEXTLINE(performance-no-int-to-ptr)
12831284
reinterpret_cast<FreeFuncType*>(free_ptr);
1284-
return createCustomAllocator(malloc_fn, free_fn);
1285+
return torch::cuda::CUDAPluggableAllocator::createCustomAllocator(
1286+
malloc_fn, free_fn);
12851287
});
12861288

12871289
// NOLINTNEXTLINE(bugprone-unused-raii)

0 commit comments

Comments (0)