[CUDA]: GPU Device Caching for Encoder Output in CUDA Backend #16060
@@ -8,6 +8,7 @@

#include <cuda_runtime.h>
#include <executorch/runtime/backend/interface.h>
#include <executorch/runtime/backend/options.h>
#include <executorch/runtime/core/error.h>
#include <executorch/runtime/core/evalue.h>
#include <executorch/runtime/core/exec_aten/util/tensor_util.h>

@@ -16,6 +17,7 @@

#include <filesystem>
#include <fstream>
#include <string>
#include <unordered_map>
#include <vector>

// Include our shim layer headers

@@ -46,9 +48,88 @@ using executorch::runtime::Result;

using executorch::runtime::Span;
using executorch::runtime::etensor::Tensor;
// Structure to hold a reference to a GPU tensor for the "keep on device"
// optimization. Owns the tensor handle - it must be deleted when no longer
// needed.
struct GpuTensorRef {
  AOTITensorHandle handle; // Tensor handle (owned, for later deletion)
  void* data_ptr; // GPU memory pointer (for D2D copy)
  size_t size_bytes; // Total size in bytes
};
class ET_EXPERIMENTAL CudaBackend final
    : public ::executorch::runtime::BackendInterface {
 private:
  // ==========================================================================
  // GPU Tensor Storage for D2D Copy Optimization
  // ==========================================================================
  //
  // This backend supports storing GPU tensors between execute() calls to
  // enable device-to-device (D2D) copies instead of slower host-to-device
  // (H2D) copies. This is useful for encoder-decoder models where the
  // encoder output is reused across many decoder iterations.
  //
  // SUPPORTED OPTIONS (via set_option):
  //
  // "store_output" (string): Store the output tensor under this name after
  //   the next execute() call. The tensor remains on GPU until cleared.
  //   Only single-output methods are supported.
  //   Example: opts.set_option("store_output", "encoder_output");
  //
  // "use_stored_input" (string): For inputs matching the stored tensor's
  //   size, use a D2D copy from the stored tensor instead of an H2D copy
  //   from CPU. This setting persists across execute() calls until reset.
  //   Example: opts.set_option("use_stored_input", "encoder_output");
  //
  // "reset_stored_input" (bool): Clear the use_stored_input setting.
  //   Does NOT delete the stored tensor - only stops using it for D2D.
  //   Example: opts.set_option("reset_stored_input", true);
  //
  // "clear_stored_tensor" (string): Delete the named tensor from storage,
  //   freeing GPU memory. Use after the decoder loop completes.
  //   Example: opts.set_option("clear_stored_tensor", "encoder_output");
  //
  // TYPICAL USAGE PATTERN (encoder-decoder model; see the sketch after the
  // review thread below):
  //
  // 1. Before encoder: set_option("store_output", "encoder_output")
  // 2. Execute encoder (output is stored on GPU)
  // 3. Before decoder loop: set_option("use_stored_input", "encoder_output")
  // 4. Execute decoder N times (D2D copies for the encoder-output input)
  // 5. After decoder loop:
  //    set_option("reset_stored_input", true)
  //    set_option("clear_stored_tensor", "encoder_output")
  //
  // ==========================================================================
Comment on lines +92 to +102

Contributor
Trying to understand the intention: is this using a backend option to let the encode/decode methods share the output memory? In an ideal world, if the encode/decode methods could share memory planning, would we not need this?

Contributor
It's trying to avoid CPU->GPU copies. If we had device tensors we wouldn't need this, but that's WIP and perf here is time-sensitive, so Mergen is hacking around it until it's properly fixed upstream.

Contributor
Ah I see, that seems fine to me. Maybe worth adding this as part of the comment, because I can't tell from the PR.
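For reference, here is the usage pattern above sketched from the runner's side. The set_option calls mirror the examples in the doc comment; the BackendOptions<N> container is assumed from executorch/runtime/backend/options.h, and run_encoder()/run_decoder() plus the step that hands the options to the CUDA backend are hypothetical placeholders, not real APIs in this PR.

#include <executorch/runtime/backend/options.h>

using executorch::runtime::BackendOptions;

// Whisper-style generation loop; step numbers match the usage pattern above.
void generate() {
  // 1. Ask the backend to keep the next execute() output resident on GPU.
  BackendOptions<1> store_opts; // capacity-1 container (assumed API shape)
  store_opts.set_option("store_output", "encoder_output");
  // ... pass store_opts to the CUDA backend, then run_encoder() (step 2) ...

  // 3. Substitute the stored tensor (via D2D copy) for size-matching inputs.
  BackendOptions<1> use_opts;
  use_opts.set_option("use_stored_input", "encoder_output");
  // ... pass use_opts to the backend, then run_decoder() N times (step 4) ...

  // 5. Stop the substitution and free the cached GPU memory.
  BackendOptions<2> done_opts;
  done_opts.set_option("reset_stored_input", true);
  done_opts.set_option("clear_stored_tensor", "encoder_output");
  // ... pass done_opts to the backend ...
}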
  // Storage control options (set via set_option before execute)
  mutable std::string
      store_output_name_; // Name to store output under (empty = none)
  mutable std::string
      use_stored_input_name_; // Name of stored tensor to use (empty = none)

  // Per-instance map of named GPU tensor references.
  // Mutable because execute() is const but needs to modify this.
  //
  // LIFETIME CONTRACT:
  // - Stored tensors are valid until overwritten or destroy() is called.
  // - The caller must ensure the producing execute() call (e.g., encoder)
  //   completes before any consuming execute() call (e.g., decoder) begins.
  // - The caller must not call destroy() while execute() is in progress.
  // - Overwriting a tensor (same name) deletes the old tensor immediately,
  //   so the caller must ensure no concurrent execute() is using it.
  mutable std::unordered_map<std::string, GpuTensorRef> gpu_tensors_;

  // Helper to clear stored GPU tensors and free their memory.
  // Only call when no execute() is in progress.
  void clear_gpu_tensors() const {
    for (auto& pair : gpu_tensors_) {
      if (pair.second.handle != nullptr) {
        aoti_torch_delete_tensor_object(pair.second.handle);
      }
    }
    gpu_tensors_.clear();
  }

  Error load_function_pointers_into_handle(
      void* so_handle,
      AOTIDelegateHandle* handle) const {
@@ -91,6 +172,70 @@

    return 1;
  }
  Error set_option(
      __ET_UNUSED executorch::runtime::BackendOptionContext& context,
      const executorch::runtime::Span<executorch::runtime::BackendOption>&
          backend_options) override {
    for (size_t i = 0; i < backend_options.size(); i++) {
      const auto& option = backend_options[i];
      // Handle store_output: expects a string name (e.g., "encoder_output")
      if (strcmp(option.key, "store_output") == 0) {
        if (auto* arr = std::get_if<
                std::array<char, executorch::runtime::kMaxOptionValueLength>>(
                &option.value)) {
          store_output_name_ = std::string(arr->data());
        } else {
          ET_LOG(Error, "store_output option expects a string value");
          return Error::InvalidArgument;
        }
      }
      // Handle use_stored_input: expects a string name (e.g.,
      // "encoder_output")
      else if (strcmp(option.key, "use_stored_input") == 0) {
        if (auto* arr = std::get_if<
                std::array<char, executorch::runtime::kMaxOptionValueLength>>(
                &option.value)) {
          use_stored_input_name_ = std::string(arr->data());
        } else {
          ET_LOG(Error, "use_stored_input option expects a string value");
          return Error::InvalidArgument;
        }
      }
      // Handle reset_stored_input: expects a boolean value.
      // Note: This only resets the name setting. The stored GPU tensor
      // remains in memory until overwritten or destroy() is called.
      else if (strcmp(option.key, "reset_stored_input") == 0) {
        if (auto* val = std::get_if<bool>(&option.value)) {
          if (*val) {
            use_stored_input_name_.clear();
          }
        } else {
          ET_LOG(Error, "reset_stored_input option expects a boolean value");
          return Error::InvalidArgument;
        }
      }
      // Handle clear_stored_tensor: expects a string name.
      // Deletes the named GPU tensor from storage, freeing GPU memory.
      else if (strcmp(option.key, "clear_stored_tensor") == 0) {
        if (auto* arr = std::get_if<
                std::array<char, executorch::runtime::kMaxOptionValueLength>>(
                &option.value)) {
          std::string name(arr->data());
          auto it = gpu_tensors_.find(name);
          if (it != gpu_tensors_.end()) {
            if (it->second.handle != nullptr) {
              aoti_torch_delete_tensor_object(it->second.handle);
            }
            gpu_tensors_.erase(it);
          }
        } else {
          ET_LOG(Error, "clear_stored_tensor option expects a string value");
          return Error::InvalidArgument;
        }
      }
    }
    return Error::Ok;
  }
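A standalone sketch of the std::get_if dispatch used in set_option above. OptionValue here is a stand-in for the real BackendOption::value variant (which may carry more alternatives, e.g. int), and kLen stands in for kMaxOptionValueLength:

#include <array>
#include <cstddef>
#include <cstdio>
#include <string>
#include <variant>

constexpr size_t kLen = 256; // stand-in for kMaxOptionValueLength
// Stand-in for BackendOption::value.
using OptionValue = std::variant<bool, std::array<char, kLen>>;

void handle_option(const char* key, const OptionValue& value) {
  // String options arrive as a fixed-size, NUL-terminated char array.
  if (auto* arr = std::get_if<std::array<char, kLen>>(&value)) {
    std::string name(arr->data());
    std::printf("%s = \"%s\"\n", key, name.c_str());
  } else if (auto* flag = std::get_if<bool>(&value)) {
    std::printf("%s = %s\n", key, *flag ? "true" : "false");
  }
}

int main() {
  std::array<char, kLen> name{};
  std::snprintf(name.data(), name.size(), "encoder_output");
  handle_option("store_output", OptionValue{name});
  handle_option("reset_stored_input", OptionValue{true});
  return 0;
}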
  // Once per loaded binary blob
  Result<DelegateHandle*> init(
      BackendInitContext& context,

@@ -222,15 +367,52 @@

    std::vector<AOTITensorHandle> gpu_outputs(
        n_outputs); // GPU tensors for kernel output
    // RAII helper to ensure GPU tensors are cleaned up on all exit paths.
    // Prevents memory leaks when errors occur during execute().
    struct TensorCleanup {
      std::vector<AOTITensorHandle>& inputs;
      std::vector<AOTITensorHandle>& outputs;
      const std::unordered_map<std::string, GpuTensorRef>& stored_tensors;

      ~TensorCleanup() {
        // Clean up input tensors
        for (auto* handle : inputs) {
          if (handle != nullptr) {
            aoti_torch_delete_tensor_object(handle);
          }
        }
        // Clean up output tensors, except those that are stored
        for (auto* handle : outputs) {
          if (handle != nullptr) {
            bool is_stored = false;
            for (const auto& pair : stored_tensors) {
              if (pair.second.handle == handle) {
                is_stored = true;
                break;
              }
            }
            if (!is_stored) {
              aoti_torch_delete_tensor_object(handle);
            }
          }
        }
      }
    };
    TensorCleanup cleanup{gpu_inputs, gpu_outputs, gpu_tensors_};

    // Track which input index was matched for D2D copy (for duplicate
    // detection)
    ssize_t matched_input_idx = -1;
    // Process input tensors: ExecuTorch provides CPU tensors, create GPU
    // copies. For stored inputs, use GPU-to-GPU copy instead of CPU-to-GPU.
    for (size_t i = 0; i < n_inputs; i++) {
      // Get tensor dimensions and properties from ExecuTorch CPU tensor
      auto cpu_tensor = &(args[i]->toTensor());
      auto sizes = cpu_tensor->sizes();
      auto scalar_type = cpu_tensor->scalar_type();

      // Create GPU tensor with same shape (always needed for AOTI format)
      std::vector<int64_t> sizes_vec(sizes.begin(), sizes.end());

      AOTITensorHandle gpu_input_handle;
@@ -246,21 +428,75 @@

      ET_CHECK_OR_RETURN_ERROR(
          create_err == Error::Ok,
          Internal,
          "Failed to create GPU tensor for input %zu",
          i);

      gpu_inputs[i] = gpu_input_handle;

      // Check if this input matches a stored GPU tensor (by size).
      if (!use_stored_input_name_.empty()) {
        auto it = gpu_tensors_.find(use_stored_input_name_);
        if (it != gpu_tensors_.end()) {
          const GpuTensorRef& ref = it->second;
          size_t numel = gpu_inputs[i]->numel();
          size_t elem_size = gpu_inputs[i]->element_size();
          size_t copy_bytes = numel * elem_size;

          // Match by size: use the stored tensor if the byte sizes match
          if (copy_bytes == ref.size_bytes) {
            if (matched_input_idx >= 0) {
              // Another input already matched - warn about ambiguity
              ET_LOG(
                  Error,
                  "Multiple inputs match stored tensor '%s' size (%zu bytes): "
                  "input %zd was used, input %zu also matches. "
                  "Consider using unique tensor sizes or a different matching strategy.",
                  use_stored_input_name_.c_str(),
                  copy_bytes,
                  matched_input_idx,
                  i);
            } else {
              // First match - perform D2D copy
              matched_input_idx = static_cast<ssize_t>(i);

              ET_LOG(
                  Debug,
                  "Using stored tensor '%s' for input %zu (%zu bytes, D2D copy)",
                  use_stored_input_name_.c_str(),
                  i,
                  copy_bytes);

              // GPU-to-GPU copy: fast DMA transfer, normalizes tensor format
              cudaError_t cuda_err = cudaMemcpy(
                  gpu_inputs[i]->data_ptr(),
                  ref.data_ptr,
                  copy_bytes,
                  cudaMemcpyDeviceToDevice);

              ET_CHECK_OR_RETURN_ERROR(
                  cuda_err == cudaSuccess,
                  Internal,
                  "Failed GPU-to-GPU copy for input %zu: %s",
                  i,
                  cudaGetErrorString(cuda_err));

              // Skip the CPU-to-GPU copy below
              continue;
            }
          }
        }
      }

      // Copy data from CPU to GPU (normal path)
      ET_CHECK_OR_RETURN_ERROR(
          aoti_torch_copy_(gpu_inputs[i], cpu_tensor, 0) == Error::Ok,
          Internal,
          "Failed to copy input %zu from CPU to GPU",
          i);
    }
    // Process output tensors: create GPU counterparts for ExecuTorch CPU
    // tensors
    for (size_t i = 0; i < n_outputs; i++) {
      // Get output tensor dimensions from ExecuTorch CPU tensor
      auto cpu_output_tensor = &(args[i + n_inputs]->toTensor());
      auto sizes = cpu_output_tensor->sizes();
@@ -282,7 +518,7 @@

      ET_CHECK_OR_RETURN_ERROR(
          create_err == Error::Ok,
          Internal,
          "Failed to create GPU tensor for output %zu",
          i);

      gpu_outputs[i] = gpu_output_handle;
@@ -303,20 +539,65 @@

        "AOTInductorModelContainerRun failed with error code %d",
        error);

    // Store a reference to the output GPU tensor if requested.
    // The tensor is kept alive for later D2D copy to decoder inputs.
    if (!store_output_name_.empty()) {
      ET_CHECK_OR_RETURN_ERROR(
          n_outputs == 1,
          InvalidArgument,
          "store_output only supports single-output methods, got %zu outputs",
          n_outputs);

      auto* gpu_tensor = gpu_outputs[0];
      size_t numel = gpu_tensor->numel();
      size_t elem_size = gpu_tensor->element_size();
      size_t size_bytes = numel * elem_size;

      // Delete the old tensor if overwriting (erase first to prevent
      // double-free)
      auto old_it = gpu_tensors_.find(store_output_name_);
      if (old_it != gpu_tensors_.end()) {
        AOTITensorHandle old_handle = old_it->second.handle;
        gpu_tensors_.erase(old_it); // Remove from map before deleting
        if (old_handle != nullptr) {
          aoti_torch_delete_tensor_object(old_handle);
        }
      }

      // Store the tensor reference (we now own this tensor)
      GpuTensorRef ref;
      ref.handle = gpu_tensor;
      ref.data_ptr = gpu_tensor->data_ptr();
      ref.size_bytes = size_bytes;
      gpu_tensors_[store_output_name_] = ref;

      // Reset the store_output name after storing
      store_output_name_.clear();
    }

    // Copy GPU output results back to CPU output tensors
    for (size_t i = 0; i < n_outputs; i++) {
      auto cpu_output_tensor = &(args[i + n_inputs]->toTensor());
      // For DYNAMIC_BOUND tensors we try to resize
      ET_CHECK_OK_OR_RETURN_ERROR(
          resize_tensor(*cpu_output_tensor, gpu_outputs[i]->sizes()),
          "Error resizing tensor at output index %zu",
          i);
      ET_CHECK_OK_OR_RETURN_ERROR(
          aoti_torch_copy_(cpu_output_tensor, gpu_outputs[i], 0),
          "Failed to copy GPU output %zu back to CPU",
          i);
    }

    // Memory management notes:
    // - GPU tensor cleanup is handled by the TensorCleanup RAII guard above.
    // - The use_stored_input setting persists across execute() calls to
    //   support decoder loops that reuse the stored encoder output.
    // - Stored GPU tensors (in gpu_tensors_) remain in memory until:
    //   (a) overwritten by a new tensor with the same name, or
    //   (b) destroy() is called, which frees all stored tensors.
    // - The "reset_stored_input" option only resets the input name setting,
    //   NOT the stored GPU tensors themselves.

    return Error::Ok;
  }
@@ -326,6 +607,9 @@

    }
    AOTIDelegateHandle* handle = (AOTIDelegateHandle*)handle_;

    // Delete stored GPU tensors
    clear_gpu_tensors();

    // Destroy the CUDA stream if it exists
    if (handle->cuda_stream != nullptr) {
      cudaStream_t cuda_stream = static_cast<cudaStream_t>(handle->cuda_stream);
Contributor
I'm curious why we still need to copy. Can you just make_tensor using the GPU data pointer?

Contributor
Yeah, I tried not copying initially and it was segfaulting. Because they're two completely different graphs, the output from the first graph and the input to the second graph had different underlying layout assumptions, so I had to copy explicitly.
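A minimal, self-contained CUDA sketch of the copy pattern discussed here: the consuming graph gets its own allocation and the bytes are moved with a device-to-device cudaMemcpy, rather than aliasing the producer's pointer. The buffer names are illustrative, and the copy is valid only because both buffers are contiguous and byte-size-matched (the same assumption the backend's size check relies on).

#include <cuda_runtime.h>

#include <cstdio>

int main() {
  // Stand-ins for the stored encoder output and the decoder's input tensor.
  const size_t size_bytes = 1024 * sizeof(float);
  float* stored = nullptr; // producer graph's output, kept on device
  float* input = nullptr;  // consumer graph's own allocation

  cudaMalloc(reinterpret_cast<void**>(&stored), size_bytes);
  cudaMalloc(reinterpret_cast<void**>(&input), size_bytes);

  // D2D copy into the consumer's buffer instead of aliasing `stored`:
  // each graph keeps the storage it allocated, so its own layout
  // assumptions continue to hold.
  cudaError_t err =
      cudaMemcpy(input, stored, size_bytes, cudaMemcpyDeviceToDevice);
  if (err != cudaSuccess) {
    std::printf("D2D copy failed: %s\n", cudaGetErrorString(err));
  }

  cudaFree(stored);
  cudaFree(input);
  return 0;
}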