
Commit e03a385

Add CUDA memory allocation retrying with GC to torch patch
1 parent 2631168 commit e03a385

2 files changed, +260 -0 lines changed

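What the change does: when a CUDA allocation fails, the patched allocator invokes a list of registered "out of memory retrier" callbacks and retries the allocation, up to 10 times, before reporting an OOM. The retrier that GraalPy registers runs Python's garbage collector; since GraalPy reclaims objects with a deferred GC rather than CPython-style eager reference counting, dead tensors can keep CUDA buffers alive until a collection runs, so a gc.collect() followed by a retry often lets the allocation succeed.

A minimal standalone C++ sketch of the retry pattern (the class and the malloc-based allocation are illustrative only; the real change lives in c10's DeviceCachingAllocator):

// Illustrative sketch, not PyTorch's API: an allocator that invokes
// registered callbacks and retries a bounded number of times on failure.
#include <cstddef>
#include <cstdlib>
#include <functional>
#include <utility>
#include <vector>

using OutOfMemoryRetrier = std::function<void()>;

class RetryingAllocator {
 public:
  void attachOutOfMemoryRetrier(OutOfMemoryRetrier retrier) {
    oom_retriers_.emplace_back(std::move(retrier));
  }

  void* allocate(std::size_t nbytes) {
    int retries = 10;  // same bounded budget as the patch
    for (;;) {
      if (void* p = std::malloc(nbytes)) {
        return p;
      }
      if (retries == 0 || oom_retriers_.empty()) {
        return nullptr;  // fall through to the normal OOM path
      }
      retries -= 1;
      // Give the embedder a chance to free memory (e.g. run a GC),
      // then loop around and try the allocation again.
      for (const auto& retrier : oom_retriers_) {
        retrier();
      }
    }
  }

 private:
  std::vector<OutOfMemoryRetrier> oom_retriers_;
};

The actual patch expresses the loop with a goto retry that releases and re-acquires the allocator lock; either way, the bounded budget guarantees termination, and an empty retrier list leaves the existing OOM reporting path unchanged.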

graalpython/lib-graalpython/patches/torch-2.4.1.patch

Lines changed: 130 additions & 0 deletions
@@ -1,3 +1,96 @@
+diff --git a/c10/cuda/CUDACachingAllocator.cpp b/c10/cuda/CUDACachingAllocator.cpp
+index 11bea6056..ca182f4ed 100644
+--- a/c10/cuda/CUDACachingAllocator.cpp
++++ b/c10/cuda/CUDACachingAllocator.cpp
+@@ -924,6 +924,8 @@ class DeviceCachingAllocator {
+   // XXX - maybe we should generalize and have multiple events
+   std::vector<OutOfMemoryObserver> oom_observers_;
+ 
++  std::vector<OutOfMemoryRetrier> oom_retriers_;
++
+   std::vector<AllocatorTraceTracker> trace_trackers_;
+ 
+   // mapping from block to a stream_set, containing streams on which the block
+@@ -995,6 +997,10 @@ class DeviceCachingAllocator {
+     oom_observers_.emplace_back(std::move(observer));
+   }
+ 
++  void attachOutOfMemoryRetrier(OutOfMemoryRetrier retrier) {
++    oom_retriers_.emplace_back(std::move(retrier));
++  }
++
+   void attachAllocatorTraceTracker(AllocatorTraceTracker tracker) {
+     std::unique_lock<std::recursive_mutex> lock(mutex);
+     trace_trackers_.emplace_back(std::move(tracker));
+@@ -1019,6 +1025,9 @@ class DeviceCachingAllocator {
+     // to have...
+     auto context = maybeGatherContext(RecordContext::STATE);
+ 
++    int retries = 10;
++retry:
++
+     std::unique_lock<std::recursive_mutex> lock(mutex);
+ 
+     if (C10_LIKELY(captures_underway.empty())) {
+@@ -1072,6 +1081,13 @@ class DeviceCachingAllocator {
+     }
+ 
+     if (!block_found) {
++      if (retries && !oom_retriers_.empty()) {
++        retries -= 1;
++        for (const auto& retrier : oom_retriers_) {
++          retrier();
++        }
++        goto retry;
++      }
+       // For any error code other than cudaErrorMemoryAllocation,
+       // alloc_block should have thrown an exception already.
+       TORCH_INTERNAL_ASSERT(params.err == cudaErrorMemoryAllocation);
+@@ -3046,6 +3062,12 @@ class NativeCachingAllocator : public CUDAAllocator {
+     }
+   }
+ 
++  void attachOutOfMemoryRetrier(OutOfMemoryRetrier retrier) override {
++    for (auto& allocator : device_allocator) {
++      allocator->attachOutOfMemoryRetrier(retrier);
++    }
++  }
++
+   void attachAllocatorTraceTracker(AllocatorTraceTracker tracker) override {
+     for (auto& allocator : device_allocator) {
+       allocator->attachAllocatorTraceTracker(tracker);
+diff --git a/c10/cuda/CUDACachingAllocator.h b/c10/cuda/CUDACachingAllocator.h
+index 438ed8d77..a76348e2f 100644
+--- a/c10/cuda/CUDACachingAllocator.h
++++ b/c10/cuda/CUDACachingAllocator.h
+@@ -241,6 +241,8 @@ using OutOfMemoryObserver = std::function<void(
+ 
+ using AllocatorTraceTracker = std::function<void(const TraceEntry&)>;
+ 
++using OutOfMemoryRetrier = std::function<void()>;
++
+ class CUDAAllocator : public Allocator {
+  public:
+   virtual void* raw_alloc(size_t nbytes) = 0;
+@@ -290,6 +292,7 @@ class CUDAAllocator : public Allocator {
+       size_t alloc_trace_max_entries,
+       RecordContext when) = 0;
+   virtual void attachOutOfMemoryObserver(OutOfMemoryObserver observer) = 0;
++  virtual void attachOutOfMemoryRetrier(OutOfMemoryRetrier retrier) {};
+ 
+   // Attached AllocatorTraceTracker callbacks will be called while the
+   // per-device allocator lock is held. Any additional locks taken from within
+@@ -444,6 +447,10 @@ inline void attachOutOfMemoryObserver(OutOfMemoryObserver observer) {
+   return get()->attachOutOfMemoryObserver(std::move(observer));
+ }
+ 
++inline void attachOutOfMemoryRetrier(OutOfMemoryRetrier retrier) {
++  return get()->attachOutOfMemoryRetrier(std::move(retrier));
++}
++
+ inline void attachAllocatorTraceTracker(AllocatorTraceTracker tracker) {
+   return get()->attachAllocatorTraceTracker(std::move(tracker));
+ }
 diff --git a/functorch/csrc/dim/dim.cpp b/functorch/csrc/dim/dim.cpp
 index 252cb3b14..2b71b93eb 100644
 --- a/functorch/csrc/dim/dim.cpp
@@ -527,6 +620,28 @@ index 78c4a546d..182ad0b47 100644
     throw python_error();
   }
   stop = clip_val(stop);
+diff --git a/torch/csrc/cuda/Module.cpp b/torch/csrc/cuda/Module.cpp
+index 4197c2aa5..d78e60b2b 100644
+--- a/torch/csrc/cuda/Module.cpp
++++ b/torch/csrc/cuda/Module.cpp
+@@ -1343,6 +1343,17 @@ static PyObject* THCPModule_initExtension(PyObject* self, PyObject* noargs) {
+   poison_fork();
+   at::globalContext().lazyInitCUDA();
+ 
++  // GraalPy change
++  auto retrier = [](){
++    py::gil_scoped_acquire g;
++    PyObject* gcmodule = PyImport_ImportModule("gc");
++    if (gcmodule) {
++      PyObject_CallMethod(gcmodule, "collect", NULL);
++    }
++    PyErr_Clear();
++  };
++  c10::cuda::CUDACachingAllocator::attachOutOfMemoryRetrier(std::move(retrier));
++
+   auto m = THPObjectPtr(PyImport_ImportModule("torch.cuda"));
+   if (!m)
+     throw python_error();
 diff --git a/torch/csrc/dynamo/cpython_defs.c b/torch/csrc/dynamo/cpython_defs.c
 index c301da982..a2668be20 100644
 --- a/torch/csrc/dynamo/cpython_defs.c
@@ -707,3 +822,18 @@ index 92e6e2d3a..4d2ec0bfe 100644
   auto new_frame = PyFrame_GetBack(frame);
   Py_DECREF(frame);
   frame = new_frame;
+diff --git a/torch/csrc/profiler/python/combined_traceback.cpp b/torch/csrc/profiler/python/combined_traceback.cpp
+index f9e20541e..f5d4d1375 100644
+--- a/torch/csrc/profiler/python/combined_traceback.cpp
++++ b/torch/csrc/profiler/python/combined_traceback.cpp
+@@ -86,8 +86,8 @@ struct PythonTraceback : public CapturedTraceback::Python {
+   }
+   for (const auto& f : to_symbolize) {
+     auto f_code = (PyCodeObject*)f.code;
+-    py::handle filename = f_code->co_filename;
+-    py::handle funcname = f_code->co_name;
++    py::object filename = pybind11::reinterpret_steal<py::object>(PyCode_GetFileName(f_code));
++    py::object funcname = pybind11::reinterpret_steal<py::object>(PyCode_GetName(f_code));
+     auto lineno = PyCode_Addr2Line(f_code, f.lasti);
+     result.tracebacks.emplace_back();
+     result.tracebacks.back().push_back(result.all_frames.size());
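The retrier itself is registered once, during torch.cuda extension initialization in torch/csrc/cuda/Module.cpp: it acquires the GIL, imports gc, and calls gc.collect(), clearing any Python error so a failed collection cannot poison the allocator's retry loop. A standalone equivalent using only the plain CPython C API (the patch uses pybind11's py::gil_scoped_acquire for the GIL; the reference-count releases are added here for hygiene):

// Sketch of the registered retrier using the plain CPython API.
#include <Python.h>

static void gc_collect_retrier(void) {
  // The allocator may call this from an arbitrary native thread,
  // so take the GIL explicitly before touching any Python state.
  PyGILState_STATE gstate = PyGILState_Ensure();
  PyObject* gcmodule = PyImport_ImportModule("gc");
  if (gcmodule != NULL) {
    PyObject* res = PyObject_CallMethod(gcmodule, "collect", NULL);
    Py_XDECREF(res);
    Py_DECREF(gcmodule);
  }
  // A retrier must not leave a pending exception behind.
  PyErr_Clear();
  PyGILState_Release(gstate);
}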

graalpython/lib-graalpython/patches/torch-2.7.0.patch

Lines changed: 130 additions & 0 deletions
@@ -1,3 +1,96 @@
+diff --git a/c10/cuda/CUDACachingAllocator.cpp b/c10/cuda/CUDACachingAllocator.cpp
+index 4a1e4654f..8b0ea304c 100644
+--- a/c10/cuda/CUDACachingAllocator.cpp
++++ b/c10/cuda/CUDACachingAllocator.cpp
+@@ -1099,6 +1099,8 @@ class DeviceCachingAllocator {
+   // XXX - maybe we should generalize and have multiple events
+   std::vector<OutOfMemoryObserver> oom_observers_;
+ 
++  std::vector<OutOfMemoryRetrier> oom_retriers_;
++
+   std::vector<AllocatorTraceTracker> trace_trackers_;
+ 
+   // mapping from block to a stream_set, containing streams on which the block
+@@ -1167,6 +1169,10 @@ class DeviceCachingAllocator {
+     oom_observers_.emplace_back(std::move(observer));
+   }
+ 
++  void attachOutOfMemoryRetrier(OutOfMemoryRetrier retrier) {
++    oom_retriers_.emplace_back(std::move(retrier));
++  }
++
+   void attachAllocatorTraceTracker(AllocatorTraceTracker tracker) {
+     std::unique_lock<std::recursive_mutex> lock(mutex);
+     trace_trackers_.emplace_back(std::move(tracker));
+@@ -1191,6 +1197,9 @@ class DeviceCachingAllocator {
+     // to have...
+     auto context = maybeGatherContext(RecordContext::STATE);
+ 
++    int retries = 10;
++retry:
++
+     std::unique_lock<std::recursive_mutex> lock(mutex);
+ 
+     if (C10_LIKELY(captures_underway.empty())) {
+@@ -1244,6 +1253,13 @@ class DeviceCachingAllocator {
+     }
+ 
+     if (!block_found) {
++      if (retries && !oom_retriers_.empty()) {
++        retries -= 1;
++        for (const auto& retrier : oom_retriers_) {
++          retrier();
++        }
++        goto retry;
++      }
+       // For any error code other than cudaErrorMemoryAllocation,
+       // alloc_block should have thrown an exception already.
+       TORCH_INTERNAL_ASSERT(params.err == cudaErrorMemoryAllocation);
+@@ -3486,6 +3502,12 @@ class NativeCachingAllocator : public CUDAAllocator {
+     }
+   }
+ 
++  void attachOutOfMemoryRetrier(OutOfMemoryRetrier retrier) override {
++    for (auto& allocator : device_allocator) {
++      allocator->attachOutOfMemoryRetrier(retrier);
++    }
++  }
++
+   void attachAllocatorTraceTracker(AllocatorTraceTracker tracker) override {
+     for (auto& allocator : device_allocator) {
+       allocator->attachAllocatorTraceTracker(tracker);
+diff --git a/c10/cuda/CUDACachingAllocator.h b/c10/cuda/CUDACachingAllocator.h
+index df31a11da..55b8e6225 100644
+--- a/c10/cuda/CUDACachingAllocator.h
++++ b/c10/cuda/CUDACachingAllocator.h
+@@ -191,6 +191,8 @@ using OutOfMemoryObserver = std::function<void(
+ 
+ using AllocatorTraceTracker = std::function<void(const TraceEntry&)>;
+ 
++using OutOfMemoryRetrier = std::function<void()>;
++
+ struct ShareableHandle {
+   ptrdiff_t offset;
+   std::string handle;
+@@ -268,6 +270,7 @@ class CUDAAllocator : public Allocator {
+   virtual void recordAnnotation(
+       const std::vector<std::pair<std::string, std::string>>& md) {}
+   virtual void attachOutOfMemoryObserver(OutOfMemoryObserver observer) = 0;
++  virtual void attachOutOfMemoryRetrier(OutOfMemoryRetrier retrier) {};
+ 
+   // Attached AllocatorTraceTracker callbacks will be called while the
+   // per-device allocator lock is held. Any additional locks taken from within
+@@ -440,6 +443,10 @@ inline void attachOutOfMemoryObserver(OutOfMemoryObserver observer) {
+   return get()->attachOutOfMemoryObserver(std::move(observer));
+ }
+ 
++inline void attachOutOfMemoryRetrier(OutOfMemoryRetrier retrier) {
++  return get()->attachOutOfMemoryRetrier(std::move(retrier));
++}
++
+ inline void attachAllocatorTraceTracker(AllocatorTraceTracker tracker) {
+   return get()->attachAllocatorTraceTracker(std::move(tracker));
+ }
 diff --git a/functorch/csrc/dim/dim.cpp b/functorch/csrc/dim/dim.cpp
 index 23179ad0e..ad9dbdbf7 100644
 --- a/functorch/csrc/dim/dim.cpp
@@ -567,6 +660,28 @@ index 7efab1dcf..67b3cf44e 100644
     throw python_error();
   }
   stop = clip_val(stop);
+diff --git a/torch/csrc/cuda/Module.cpp b/torch/csrc/cuda/Module.cpp
+index b81ff5d4e..c44f0b617 100644
+--- a/torch/csrc/cuda/Module.cpp
++++ b/torch/csrc/cuda/Module.cpp
+@@ -1516,6 +1516,17 @@ static PyObject* THCPModule_initExtension(PyObject* self, PyObject* noargs) {
+   poison_fork();
+   at::globalContext().lazyInitDevice(c10::DeviceType::CUDA);
+ 
++  // GraalPy change
++  auto retrier = [](){
++    py::gil_scoped_acquire g;
++    PyObject* gcmodule = PyImport_ImportModule("gc");
++    if (gcmodule) {
++      PyObject_CallMethod(gcmodule, "collect", NULL);
++    }
++    PyErr_Clear();
++  };
++  c10::cuda::CUDACachingAllocator::attachOutOfMemoryRetrier(std::move(retrier));
++
+   auto m = THPObjectPtr(PyImport_ImportModule("torch.cuda"));
+   if (!m)
+     throw python_error();
 diff --git a/torch/csrc/dynamo/cpython_defs.c b/torch/csrc/dynamo/cpython_defs.c
 index b68ef894a..0837d95be 100644
 --- a/torch/csrc/dynamo/cpython_defs.c
@@ -812,3 +927,18 @@ index 876186743..041348257 100644
   auto new_frame = PyFrame_GetBack(frame);
   Py_DECREF(frame);
   frame = new_frame;
+diff --git a/torch/csrc/profiler/python/combined_traceback.cpp b/torch/csrc/profiler/python/combined_traceback.cpp
+index f9e20541e..f5d4d1375 100644
+--- a/torch/csrc/profiler/python/combined_traceback.cpp
++++ b/torch/csrc/profiler/python/combined_traceback.cpp
+@@ -86,8 +86,8 @@ struct PythonTraceback : public CapturedTraceback::Python {
+   }
+   for (const auto& f : to_symbolize) {
+     auto f_code = (PyCodeObject*)f.code;
+-    py::handle filename = f_code->co_filename;
+-    py::handle funcname = f_code->co_name;
++    py::object filename = pybind11::reinterpret_steal<py::object>(PyCode_GetFileName(f_code));
++    py::object funcname = pybind11::reinterpret_steal<py::object>(PyCode_GetName(f_code));
+     auto lineno = PyCode_Addr2Line(f_code, f.lasti);
+     result.tracebacks.emplace_back();
+     result.tracebacks.back().push_back(result.all_frames.size());
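Both patch files also touch torch/csrc/profiler/python/combined_traceback.cpp, replacing direct reads of the co_filename and co_name struct fields with the PyCode_GetFileName and PyCode_GetName accessors, presumably because peeking into PyCodeObject internals does not work under GraalPy's C API emulation. Unlike the struct fields, which are borrowed, these accessors return new references, hence the py::reinterpret_steal into an owning py::object. A small illustrative sketch of that pattern (the helper name is hypothetical; the accessors exist in CPython 3.11+ and, as the patch relies on, in GraalPy):

// Illustrative helper showing the owning-accessor pattern; the caller
// must hold the GIL.
#include <pybind11/pybind11.h>
#include <string>
#include <utility>

namespace py = pybind11;

std::pair<std::string, std::string> describe_code(PyCodeObject* f_code) {
  // PyCode_GetFileName/PyCode_GetName return NEW references, so steal
  // them into py::object to get an automatic DECREF on scope exit.
  py::object filename =
      py::reinterpret_steal<py::object>(PyCode_GetFileName(f_code));
  py::object funcname =
      py::reinterpret_steal<py::object>(PyCode_GetName(f_code));
  return {py::cast<std::string>(filename), py::cast<std::string>(funcname)};
}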
