@@ -48,36 +48,26 @@ using executorch::runtime::Result;
 using executorch::runtime::Span;
 using executorch::runtime::etensor::Tensor;
 
-// Structure to hold cached GPU tensor data for "keep on device" optimization
-struct CachedGpuData {
-  void* data_ptr; // GPU memory pointer
+// Structure to hold a reference to a GPU tensor for "keep on device" optimization.
+// Owns the tensor handle - must be deleted when no longer needed.
+struct GpuTensorRef {
+  AOTITensorHandle handle; // Tensor handle (owned, for later deletion)
+  void* data_ptr; // GPU memory pointer (for D2D copy)
   size_t size_bytes; // Total size in bytes
-  int32_t scalar_type; // Data type
-  std::vector<int64_t> sizes; // Original shape
 };
 
-// Global device cache - maps name to cached GPU data
-// Using raw GPU pointers instead of tensor handles for format independence
-// Note: This cache is NOT thread-safe. Callers must ensure execute() is called
-// from a single thread.
-static std::unordered_map<std::string, CachedGpuData> g_device_cache;
-
-// Helper function to clear all cached GPU memory
-// Should be called during backend cleanup
-static void clear_device_cache() {
-  for (auto& pair : g_device_cache) {
-    if (pair.second.data_ptr != nullptr) {
-      cudaError_t err = cudaFree(pair.second.data_ptr);
-      if (err != cudaSuccess) {
-        ET_LOG(
-            Warning,
-            "Failed to free cached GPU memory for '%s': %s",
-            pair.first.c_str(),
-            cudaGetErrorString(err));
-      }
+// Global map of named GPU tensor references.
+// Note: NOT thread-safe. Callers must ensure execute() is called from a single thread.
+static std::unordered_map<std::string, GpuTensorRef> g_gpu_tensors;
+
+// Helper to clear stored GPU tensors and free their memory
+static void clear_gpu_tensors() {
+  for (auto& pair : g_gpu_tensors) {
+    if (pair.second.handle != nullptr) {
+      aoti_torch_delete_tensor_object(pair.second.handle);
     }
   }
-  g_device_cache.clear();
+  g_gpu_tensors.clear();
 }
 
 class ET_EXPERIMENTAL CudaBackend final
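
The hunk above replaces the raw-pointer `g_device_cache` with a map of owned tensor handles. The following is a minimal standalone sketch of that ownership pattern (storing under an existing name releases the previous entry; clearing releases everything); the `Handle` type and `release_handle` function are placeholders for `AOTITensorHandle` and `aoti_torch_delete_tensor_object`, not the actual ExecuTorch/AOTI API.

```cpp
#include <cstddef>
#include <cstdio>
#include <string>
#include <unordered_map>

struct Handle { int id; };                // stand-in for AOTITensorHandle
static void release_handle(Handle* h) {   // stand-in for aoti_torch_delete_tensor_object
  std::printf("released handle %d\n", h->id);
  delete h;
}

struct OwnedRef {
  Handle* handle = nullptr;  // owned; released on overwrite or clear
  size_t size_bytes = 0;
};

static std::unordered_map<std::string, OwnedRef> g_refs;

// Storing under an existing name releases the old entry first (no leak).
static void store(const std::string& name, Handle* h, size_t bytes) {
  auto it = g_refs.find(name);
  if (it != g_refs.end() && it->second.handle != nullptr) {
    release_handle(it->second.handle);
  }
  g_refs[name] = OwnedRef{h, bytes};
}

// Mirrors clear_gpu_tensors(): release every owned handle, then clear the map.
static void clear_all() {
  for (auto& pair : g_refs) {
    if (pair.second.handle != nullptr) {
      release_handle(pair.second.handle);
    }
  }
  g_refs.clear();
}

int main() {
  store("encoder_output", new Handle{1}, 1024);
  store("encoder_output", new Handle{2}, 2048);  // releases handle 1
  clear_all();                                   // releases handle 2
}
```
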
@@ -354,40 +344,40 @@ class ET_EXPERIMENTAL CudaBackend final
 
       gpu_inputs[i] = gpu_input_handle;
 
-      // Check if this input slot should use cached GPU data
+      // Check if this input slot should use a stored GPU tensor
       if (i == use_cache_input_slot_ && !use_cache_input_name_.empty()) {
-        auto cache_it = g_device_cache.find(use_cache_input_name_);
-        if (cache_it != g_device_cache.end()) {
-          const CachedGpuData& cached = cache_it->second;
+        auto it = g_gpu_tensors.find(use_cache_input_name_);
+        if (it != g_gpu_tensors.end()) {
+          const GpuTensorRef& ref = it->second;
           // GPU-to-GPU copy: fast DMA transfer, normalizes tensor format
           size_t numel = gpu_inputs[i]->numel();
           size_t elem_size = gpu_inputs[i]->element_size();
           size_t copy_bytes = numel * elem_size;
 
           ET_CHECK_OR_RETURN_ERROR(
-              copy_bytes == cached.size_bytes,
+              copy_bytes == ref.size_bytes,
               Internal,
-              "Cached tensor size mismatch: expected %zu bytes, got %zu",
+              "Stored tensor size mismatch: expected %zu bytes, got %zu",
               copy_bytes,
-              cached.size_bytes);
+              ref.size_bytes);
 
           cudaError_t cuda_err = cudaMemcpy(
               gpu_inputs[i]->data_ptr(),
-              cached.data_ptr,
+              ref.data_ptr,
               copy_bytes,
               cudaMemcpyDeviceToDevice);
 
           ET_CHECK_OR_RETURN_ERROR(
               cuda_err == cudaSuccess,
               Internal,
-              "Failed GPU-to-GPU copy for cached input %d: %s",
+              "Failed GPU-to-GPU copy for input %d: %s",
               i,
               cudaGetErrorString(cuda_err));
 
           // Skip the CPU-to-GPU copy below
           continue;
         }
-        // Cache miss: fall through to normal CPU-to-GPU copy
+        // Not found: fall through to normal CPU-to-GPU copy
       }
 
       // Copy data from CPU to GPU (normal path)
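
The hunk above replaces the cached-pointer lookup with a device-to-device copy from a stored tensor into the freshly created input tensor. The sketch below is a standalone CUDA runtime illustration of that fast path (allocate two device buffers, verify the byte counts match, then `cudaMemcpyDeviceToDevice`); the buffer names are illustrative only and not taken from the backend.

```cpp
#include <cstdio>
#include <cuda_runtime.h>

#define CUDA_CHECK(expr)                                                  \
  do {                                                                    \
    cudaError_t err_ = (expr);                                            \
    if (err_ != cudaSuccess) {                                            \
      std::fprintf(stderr, "%s: %s\n", #expr, cudaGetErrorString(err_));  \
      return 1;                                                           \
    }                                                                     \
  } while (0)

int main() {
  const size_t stored_bytes = 1024 * sizeof(float);
  const size_t input_bytes = 1024 * sizeof(float);

  void* stored = nullptr;  // plays the role of ref.data_ptr
  void* input = nullptr;   // plays the role of gpu_inputs[i]->data_ptr()
  CUDA_CHECK(cudaMalloc(&stored, stored_bytes));
  CUDA_CHECK(cudaMalloc(&input, input_bytes));

  // Size check before copying, as in the ET_CHECK_OR_RETURN_ERROR above.
  if (input_bytes != stored_bytes) {
    std::fprintf(stderr, "size mismatch\n");
    return 1;
  }

  // DMA transfer stays on the GPU; no host round-trip is needed.
  CUDA_CHECK(cudaMemcpy(input, stored, input_bytes, cudaMemcpyDeviceToDevice));

  CUDA_CHECK(cudaFree(stored));
  CUDA_CHECK(cudaFree(input));
  return 0;
}
```
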
@@ -442,62 +432,27 @@ class ET_EXPERIMENTAL CudaBackend final
         "AOTInductorModelContainerRun failed with error code %d",
         error);
 
-    // Cache output GPU tensor data if requested
-    // We store the raw GPU pointer for later GPU-to-GPU copy
+    // Store reference to output GPU tensor if requested.
+    // The tensor will be kept alive for later D2D copy to decoder inputs.
     if (cache_output_slot_ >= 0 && cache_output_slot_ < static_cast<int>(n_outputs) &&
         !cache_output_name_.empty()) {
       auto* gpu_tensor = gpu_outputs[cache_output_slot_];
       size_t numel = gpu_tensor->numel();
       size_t elem_size = gpu_tensor->element_size();
       size_t size_bytes = numel * elem_size;
 
-      // Allocate persistent GPU memory for the cache
-      void* cache_ptr = nullptr;
-      cudaError_t alloc_err = cudaMalloc(&cache_ptr, size_bytes);
-      ET_CHECK_OR_RETURN_ERROR(
-          alloc_err == cudaSuccess,
-          Internal,
-          "Failed to allocate GPU cache memory: %s",
-          cudaGetErrorString(alloc_err));
-
-      // Copy from tensor to cache (GPU-to-GPU)
-      cudaError_t copy_err = cudaMemcpy(
-          cache_ptr,
-          gpu_tensor->data_ptr(),
-          size_bytes,
-          cudaMemcpyDeviceToDevice);
-      if (copy_err != cudaSuccess) {
-        // Free allocated memory before returning error
-        cudaFree(cache_ptr);
-        ET_LOG(
-            Error,
-            "Failed to copy output to GPU cache: %s",
-            cudaGetErrorString(copy_err));
-        return Error::Internal;
-      }
-
-      // Free old cache if exists
-      auto old_it = g_device_cache.find(cache_output_name_);
-      if (old_it != g_device_cache.end()) {
-        cudaError_t free_err = cudaFree(old_it->second.data_ptr);
-        if (free_err != cudaSuccess) {
-          ET_LOG(
-              Warning,
-              "Failed to free old cached GPU memory for '%s': %s",
-              cache_output_name_.c_str(),
-              cudaGetErrorString(free_err));
-        }
-        g_device_cache.erase(old_it);
+      // Delete old tensor if overwriting
+      auto old_it = g_gpu_tensors.find(cache_output_name_);
+      if (old_it != g_gpu_tensors.end() && old_it->second.handle != nullptr) {
+        aoti_torch_delete_tensor_object(old_it->second.handle);
       }
 
-      // Store in cache
-      CachedGpuData cached;
-      cached.data_ptr = cache_ptr;
-      cached.size_bytes = size_bytes;
-      cached.scalar_type = static_cast<int32_t>(gpu_tensor->scalar_type());
-      auto sizes = gpu_tensor->sizes();
-      cached.sizes.assign(sizes.begin(), sizes.end());
-      g_device_cache[cache_output_name_] = std::move(cached);
+      // Store tensor reference (we now own this tensor)
+      GpuTensorRef ref;
+      ref.handle = gpu_tensor;
+      ref.data_ptr = gpu_tensor->data_ptr();
+      ref.size_bytes = size_bytes;
+      g_gpu_tensors[cache_output_name_] = ref;
 
       // Reset cache_output settings after caching
      cache_output_slot_ = -1;
@@ -523,6 +478,26 @@ class ET_EXPERIMENTAL CudaBackend final
     // reuse cached encoder output. The caller should explicitly clear
     // these settings using the "clear_cache_input" option when done.
 
+    // Cleanup: delete GPU tensors to avoid memory leaks across execute() calls.
+    // Input tensors are no longer needed after AOTI execution.
+    for (size_t i = 0; i < n_inputs; i++) {
+      aoti_torch_delete_tensor_object(gpu_inputs[i]);
+    }
+    // Output tensors are no longer needed after copying to CPU,
+    // EXCEPT for tensors stored in g_gpu_tensors (for later D2D copy).
+    for (size_t i = 0; i < n_outputs; i++) {
+      bool is_stored = false;
+      for (const auto& pair : g_gpu_tensors) {
+        if (pair.second.handle == gpu_outputs[i]) {
+          is_stored = true;
+          break;
+        }
+      }
+      if (!is_stored) {
+        aoti_torch_delete_tensor_object(gpu_outputs[i]);
+      }
+    }
+
     return Error::Ok;
   }
 
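
The cleanup loop added above deletes every output handle except those still referenced from `g_gpu_tensors`; the linear scan is cheap because both the output list and the stored map are small. A minimal sketch of that "skip stored handles" rule follows, again with placeholder types rather than the AOTI shim API.

```cpp
#include <algorithm>
#include <cstdio>
#include <string>
#include <unordered_map>
#include <vector>

struct Handle { int id; };               // stand-in for AOTITensorHandle
static void delete_handle(Handle* h) {   // stand-in for aoti_torch_delete_tensor_object
  std::printf("deleted handle %d\n", h->id);
  delete h;
}

struct OwnedRef { Handle* handle = nullptr; };

int main() {
  std::unordered_map<std::string, OwnedRef> stored;
  std::vector<Handle*> outputs = {new Handle{1}, new Handle{2}};

  // Output 0 is kept alive for a later device-to-device copy.
  stored["encoder_output"] = OwnedRef{outputs[0]};

  // Delete every output handle that is not referenced from the stored map.
  for (Handle* out : outputs) {
    bool is_stored = std::any_of(
        stored.begin(), stored.end(),
        [&](const auto& entry) { return entry.second.handle == out; });
    if (!is_stored) {
      delete_handle(out);  // only the unstored handle (id 2) is deleted here
    }
  }

  // Stored handles are released later, e.g. on overwrite or backend destroy.
  for (auto& entry : stored) {
    delete_handle(entry.second.handle);
  }
}
```
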
@@ -532,8 +507,8 @@ class ET_EXPERIMENTAL CudaBackend final
     }
     AOTIDelegateHandle* handle = (AOTIDelegateHandle*)handle_;
 
-    // Clear all cached GPU memory
-    clear_device_cache();
+    // Delete stored GPU tensors
+    clear_gpu_tensors();
 
     // Destroy the CUDA stream if it exists
     if (handle->cuda_stream != nullptr) {