Skip to content

Commit f7db51c

Browse files
sraikund16amathewc
authored andcommitted
[GPU Snapshot] Add Clear History Flag (pytorch#149352)
Summary: Oftentimes, users complain that a bunch of extra events are prepended to their desired GPU snapshot. This is because they usually attach an OOM logger without knowing and when they go to collect the actual snapshot, it adds all the OOM logger contents. Since OOM and regular snapshot use the same backend, we currently don't have the infra in place to split these snapshots. As a solution we add a flag to the snapshot frontend to clear out the history when starting the auto-trace record memory history. A more thorough solution would be to have a user pass in a handle and to have snapshots per handle to seperate the events. However, this would likely be complicated and more work than it is worth as we would have to change the callbacks in the caching allocator and pass these objects between python and cpp. Test Plan: See diff below Differential Revision: D71159720 Pull Request resolved: pytorch#149352 Approved by: https://github.com/eqy, https://github.com/aaronenyeshi
1 parent 7bc36c2 commit f7db51c

File tree

10 files changed

+41
-20
lines changed

10 files changed

+41
-20
lines changed

c10/cuda/CUDACachingAllocator.cpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1119,14 +1119,15 @@ class DeviceCachingAllocator {
11191119
bool enabled,
11201120
CreateContextFn context_recorder,
11211121
size_t alloc_buffer_max_entries,
1122-
RecordContext when) {
1122+
RecordContext when,
1123+
bool clearHistory) {
11231124
std::unique_lock<std::recursive_mutex> lock(mutex);
11241125
TORCH_CHECK(when == RecordContext::NEVER || context_recorder);
11251126
record_history = enabled;
11261127
context_recorder_.store(record_history ? context_recorder : nullptr);
11271128
alloc_buffer.setMaxEntries(alloc_buffer_max_entries);
11281129
record_context_ = enabled ? when : RecordContext::NEVER;
1129-
if (!enabled) {
1130+
if (!enabled || clearHistory) {
11301131
alloc_buffer.clear();
11311132
}
11321133
}
@@ -3441,13 +3442,18 @@ class NativeCachingAllocator : public CUDAAllocator {
34413442
bool enabled,
34423443
CreateContextFn context_recorder,
34433444
size_t alloc_buffer_max_entries,
3444-
RecordContext when) override {
3445+
RecordContext when,
3446+
bool clearHistory) override {
34453447
record_history = enabled;
34463448
annotation_buffer.setMaxEntries(alloc_buffer_max_entries);
34473449
annotation_buffer.clear();
34483450
for (auto& allocator : device_allocator) {
34493451
allocator->recordHistory(
3450-
enabled, context_recorder, alloc_buffer_max_entries, when);
3452+
enabled,
3453+
context_recorder,
3454+
alloc_buffer_max_entries,
3455+
when,
3456+
clearHistory);
34513457
}
34523458
}
34533459

c10/cuda/CUDACachingAllocator.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,8 @@ class CUDAAllocator : public Allocator {
264264
bool enabled,
265265
CreateContextFn context_recorder,
266266
size_t alloc_trace_max_entries,
267-
RecordContext when) = 0;
267+
RecordContext when,
268+
bool clearHistory) = 0;
268269
virtual void recordAnnotation(
269270
const std::vector<std::pair<std::string, std::string>>& md) {}
270271
virtual void attachOutOfMemoryObserver(OutOfMemoryObserver observer) = 0;
@@ -414,9 +415,10 @@ inline void recordHistory(
414415
bool enabled,
415416
CreateContextFn context_recorder,
416417
size_t alloc_trace_max_entries,
417-
RecordContext when) {
418+
RecordContext when,
419+
bool clearHistory) {
418420
return get()->recordHistory(
419-
enabled, context_recorder, alloc_trace_max_entries, when);
421+
enabled, context_recorder, alloc_trace_max_entries, when, clearHistory);
420422
}
421423

422424
inline void recordAnnotation(

c10/cuda/CUDAMallocAsyncAllocator.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -648,7 +648,8 @@ struct CudaMallocAsyncAllocator : public CUDAAllocator {
648648
bool enabled,
649649
CreateContextFn context_recorder,
650650
size_t alloc_trace_max_entries,
651-
RecordContext when) override {
651+
RecordContext when,
652+
bool clearHistory) override {
652653
TORCH_CHECK(
653654
false,
654655
"cudaMallocAsync does not yet support recordHistory. "

torch/_C/__init__.pyi.in

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1917,12 +1917,14 @@ def _cuda_record_memory_history_legacy(
19171917
record_context_cpp: _bool,
19181918
alloc_trace_max_entries: _int,
19191919
alloc_trace_record_context: _bool,
1920+
clear_history: _bool,
19201921
) -> None: ...
19211922
def _cuda_record_memory_history(
19221923
enabled: Optional[str],
19231924
context: Optional[str],
19241925
stacks: str,
1925-
max_entries
1926+
max_entries: _int,
1927+
clear_history: _bool,
19261928
) -> None: ...
19271929
def _cuda_isHistoryEnabled() -> _bool: ...
19281930

torch/csrc/cuda/CUDAPluggableAllocator.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,8 @@ void CUDAPluggableAllocator::recordHistory(
290290
bool enabled,
291291
c10::cuda::CUDACachingAllocator::CreateContextFn context_recorder,
292292
size_t alloc_trace_max_entries,
293-
c10::cuda::CUDACachingAllocator::RecordContext when) {
293+
c10::cuda::CUDACachingAllocator::RecordContext when,
294+
bool clearHistory) {
294295
TORCH_CHECK(
295296
false,
296297
"CUDAPluggableAllocator does not yet support recordHistory. "

torch/csrc/cuda/CUDAPluggableAllocator.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,8 @@ struct TORCH_CUDA_CPP_API CUDAPluggableAllocator
145145
bool enabled,
146146
c10::cuda::CUDACachingAllocator::CreateContextFn context_recorder,
147147
size_t alloc_trace_max_entries,
148-
c10::cuda::CUDACachingAllocator::RecordContext when) override;
148+
c10::cuda::CUDACachingAllocator::RecordContext when,
149+
bool clearHistory) override;
149150
void attachOutOfMemoryObserver(
150151
c10::cuda::CUDACachingAllocator::OutOfMemoryObserver observer) override;
151152
void attachAllocatorTraceTracker(

torch/csrc/cuda/Module.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1128,7 +1128,7 @@ static void registerCudaDeviceProperties(PyObject* module) {
11281128

11291129
m.def(
11301130
"_cuda_record_memory_history_legacy",
1131-
static_cast<void (*)(bool, bool, int64_t, bool, bool)>(
1131+
static_cast<void (*)(bool, bool, int64_t, bool, bool, bool)>(
11321132
torch::cuda::_record_memory_history));
11331133

11341134
m.def(
@@ -1137,7 +1137,8 @@ static void registerCudaDeviceProperties(PyObject* module) {
11371137
std::optional<std::string>,
11381138
std::optional<std::string>,
11391139
const std::string&,
1140-
size_t)>(torch::cuda::_record_memory_history));
1140+
size_t,
1141+
bool)>(torch::cuda::_record_memory_history));
11411142

11421143
m.def("_cuda_isHistoryEnabled", []() {
11431144
return c10::cuda::CUDACachingAllocator::isHistoryEnabled();

torch/csrc/cuda/memory_snapshot.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,8 @@ void _record_memory_history(
124124
bool record_context,
125125
int64_t trace_alloc_max_entries,
126126
bool trace_alloc_record_context,
127-
bool record_cpp_context) {
127+
bool record_cpp_context,
128+
bool clearHistory) {
128129
c10::cuda::CUDACachingAllocator::CreateContextFn recorder = gather;
129130
if (enabled && record_cpp_context &&
130131
(trace_alloc_record_context || record_context)) {
@@ -141,7 +142,7 @@ void _record_memory_history(
141142
at::globalContext().lazyInitDevice(c10::DeviceType::CUDA);
142143
_initRecordAnnotations();
143144
c10::cuda::CUDACachingAllocator::recordHistory(
144-
enabled, recorder, trace_alloc_max_entries, when);
145+
enabled, recorder, trace_alloc_max_entries, when, clearHistory);
145146
}
146147

147148
static void checkOptionIn(
@@ -156,7 +157,8 @@ void _record_memory_history(
156157
std::optional<std::string> enabled,
157158
std::optional<std::string> context,
158159
const std::string& stacks,
159-
size_t max_entries) {
160+
size_t max_entries,
161+
bool clearHistory) {
160162
if (enabled) {
161163
checkOptionIn(
162164
*enabled,
@@ -192,7 +194,7 @@ void _record_memory_history(
192194
at::globalContext().lazyInitDevice(c10::DeviceType::CUDA);
193195
_initRecordAnnotations();
194196
c10::cuda::CUDACachingAllocator::recordHistory(
195-
enabled.has_value(), recorder, max_entries, when);
197+
enabled.has_value(), recorder, max_entries, when, clearHistory);
196198
}
197199

198200
std::string _memory_snapshot_pickled() {

torch/csrc/cuda/memory_snapshot.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,15 @@ TORCH_CUDA_CU_API void _record_memory_history(
1414
bool record_context = true,
1515
int64_t trace_alloc_max_entries = 1,
1616
bool trace_alloc_record_context = false,
17-
bool record_cpp_context = false);
17+
bool record_cpp_context = false,
18+
bool clearHistory = false);
1819

1920
TORCH_CUDA_CU_API void _record_memory_history(
2021
std::optional<std::string> enabled = "all",
2122
std::optional<std::string> context = "all",
2223
const std::string& stacks = "all",
23-
size_t max_entries = SIZE_MAX);
24+
size_t max_entries = SIZE_MAX,
25+
bool clearHistory = false);
2426

2527
TORCH_CUDA_CU_API std::string _memory_snapshot_pickled();
2628

torch/cuda/memory.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -843,13 +843,15 @@ def _record_memory_history_legacy(
843843
trace_alloc_record_context=False,
844844
device: Union[Device, int] = None,
845845
record_context_cpp=False,
846+
clear_history=False,
846847
):
847848
_C._cuda_record_memory_history_legacy(
848849
enabled,
849850
record_context,
850851
trace_alloc_max_entries,
851852
trace_alloc_record_context,
852853
record_context_cpp,
854+
clear_history,
853855
)
854856

855857

@@ -904,8 +906,9 @@ def _record_memory_history_impl(
904906
stacks: str = "all",
905907
max_entries: int = sys.maxsize,
906908
device: Union[Device, int] = None,
909+
clear_history: bool = False,
907910
):
908-
_C._cuda_record_memory_history(enabled, context, stacks, max_entries)
911+
_C._cuda_record_memory_history(enabled, context, stacks, max_entries, clear_history)
909912

910913

911914
_record_memory_history.__signature__ = signature(_record_memory_history_impl) # type: ignore[attr-defined]

0 commit comments

Comments
 (0)