Skip to content

Commit b31576c

Browse files
committed
RAII
1 parent e535aee commit b31576c

File tree

1 file changed

+52
-25
lines changed

1 file changed

+52
-25
lines changed

backends/cuda/runtime/cuda_backend.cpp

Lines changed: 52 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
#include <executorch/runtime/core/error.h>
1313
#include <executorch/runtime/core/evalue.h>
1414
#include <executorch/runtime/core/exec_aten/util/tensor_util.h>
15-
#include <climits>
1615
#include <cstdio>
1716

1817
#include <filesystem>
@@ -305,8 +304,41 @@ class ET_EXPERIMENTAL CudaBackend final
305304
std::vector<AOTITensorHandle> gpu_outputs(
306305
n_outputs); // GPU tensors for kernel output
307306

307+
// RAII helper to ensure GPU tensors are cleaned up on all exit paths.
308+
// Prevents memory leaks when errors occur during execute().
309+
struct TensorCleanup {
310+
std::vector<AOTITensorHandle>& inputs;
311+
std::vector<AOTITensorHandle>& outputs;
312+
const std::unordered_map<std::string, GpuTensorRef>& stored_tensors;
313+
314+
~TensorCleanup() {
315+
// Clean up input tensors
316+
for (auto* handle : inputs) {
317+
if (handle != nullptr) {
318+
aoti_torch_delete_tensor_object(handle);
319+
}
320+
}
321+
// Clean up output tensors, except those that are stored
322+
for (auto* handle : outputs) {
323+
if (handle != nullptr) {
324+
bool is_stored = false;
325+
for (const auto& pair : stored_tensors) {
326+
if (pair.second.handle == handle) {
327+
is_stored = true;
328+
break;
329+
}
330+
}
331+
if (!is_stored) {
332+
aoti_torch_delete_tensor_object(handle);
333+
}
334+
}
335+
}
336+
}
337+
};
338+
TensorCleanup cleanup{gpu_inputs, gpu_outputs, gpu_tensors_};
339+
308340
// Process input tensors: ExecuTorch provides CPU tensors, create GPU
309-
// copies. For cached inputs, use GPU-to-GPU copy instead of CPU-to-GPU.
341+
// copies. For stored inputs, use GPU-to-GPU copy instead of CPU-to-GPU.
310342
for (int i = 0; i < n_inputs; i++) {
311343
// Get tensor dimensions and properties from ExecuTorch CPU tensor
312344
auto cpu_tensor = &(args[i]->toTensor());
@@ -334,7 +366,9 @@ class ET_EXPERIMENTAL CudaBackend final
334366

335367
gpu_inputs[i] = gpu_input_handle;
336368

337-
// Check if this input matches a stored GPU tensor (by size)
369+
// Check if this input matches a stored GPU tensor (by size).
370+
// Note: Size-based matching assumes only one input will match. If multiple
371+
// inputs have the same byte size as the stored tensor, the first match wins.
338372
if (!use_stored_input_name_.empty()) {
339373
auto it = gpu_tensors_.find(use_stored_input_name_);
340374
if (it != gpu_tensors_.end()) {
@@ -345,6 +379,13 @@ class ET_EXPERIMENTAL CudaBackend final
345379

346380
// Match by size: use stored tensor if sizes match
347381
if (copy_bytes == ref.size_bytes) {
382+
ET_LOG(
383+
Debug,
384+
"Using stored tensor '%s' for input %d (%zu bytes, D2D copy)",
385+
use_stored_input_name_.c_str(),
386+
i,
387+
copy_bytes);
388+
348389
// GPU-to-GPU copy: fast DMA transfer, normalizes tensor format
349390
cudaError_t cuda_err = cudaMemcpy(
350391
gpu_inputs[i]->data_ptr(),
@@ -418,9 +459,14 @@ class ET_EXPERIMENTAL CudaBackend final
418459
error);
419460

420461
// Store reference to output GPU tensor if requested.
421-
// Always uses gpu_outputs[0] (encoder has single output).
422462
// The tensor will be kept alive for later D2D copy to decoder inputs.
423-
if (!store_output_name_.empty() && n_outputs > 0) {
463+
if (!store_output_name_.empty()) {
464+
ET_CHECK_OR_RETURN_ERROR(
465+
n_outputs == 1,
466+
InvalidArgument,
467+
"store_output only supports single-output methods, got %zu outputs",
468+
n_outputs);
469+
424470
auto* gpu_tensor = gpu_outputs[0];
425471
size_t numel = gpu_tensor->numel();
426472
size_t elem_size = gpu_tensor->element_size();
@@ -462,6 +508,7 @@ class ET_EXPERIMENTAL CudaBackend final
462508
}
463509

464510
// Memory management notes:
511+
// - GPU tensor cleanup is handled by TensorCleanup RAII guard above.
465512
// - use_stored_input setting persists across execute() calls to support
466513
// decoder loops that reuse the stored encoder output.
467514
// - Stored GPU tensors (in gpu_tensors_) remain in memory until:
@@ -470,26 +517,6 @@ class ET_EXPERIMENTAL CudaBackend final
470517
// - The "reset_stored_input" option only resets the input name setting,
471518
// NOT the stored GPU tensors themselves.
472519

473-
// Cleanup: delete GPU tensors to avoid memory leak across execute() calls.
474-
// Input tensors are no longer needed after AOTI execution.
475-
for (size_t i = 0; i < n_inputs; i++) {
476-
aoti_torch_delete_tensor_object(gpu_inputs[i]);
477-
}
478-
// Output tensors are no longer needed after copying to CPU,
479-
// EXCEPT for tensors stored in gpu_tensors_ (for later D2D copy).
480-
for (size_t i = 0; i < n_outputs; i++) {
481-
bool is_stored = false;
482-
for (const auto& pair : gpu_tensors_) {
483-
if (pair.second.handle == gpu_outputs[i]) {
484-
is_stored = true;
485-
break;
486-
}
487-
}
488-
if (!is_stored) {
489-
aoti_torch_delete_tensor_object(gpu_outputs[i]);
490-
}
491-
}
492-
493520
return Error::Ok;
494521
}
495522

0 commit comments

Comments
 (0)