190 changes: 187 additions & 3 deletions backends/cuda/runtime/cuda_backend.cpp
@@ -8,6 +8,7 @@

#include <cuda_runtime.h>
#include <executorch/runtime/backend/interface.h>
#include <executorch/runtime/backend/options.h>
#include <executorch/runtime/core/error.h>
#include <executorch/runtime/core/evalue.h>
#include <executorch/runtime/core/exec_aten/util/tensor_util.h>
@@ -16,6 +17,7 @@
#include <filesystem>
#include <fstream>
#include <string>
#include <unordered_map>
#include <vector>

// Include our shim layer headers
@@ -46,9 +48,37 @@ using executorch::runtime::Result;
using executorch::runtime::Span;
using executorch::runtime::etensor::Tensor;

// Structure to hold a reference to a GPU tensor for "keep on device" optimization.
// Owns the tensor handle; the handle must be deleted when the entry is no longer needed.
struct GpuTensorRef {
AOTITensorHandle handle; // Tensor handle (owned, for later deletion)
void* data_ptr; // GPU memory pointer (for D2D copy)
size_t size_bytes; // Total size in bytes
};

// Global map of named GPU tensor references.
// Note: NOT thread-safe. Callers must ensure execute() is called from a single thread.
static std::unordered_map<std::string, GpuTensorRef> g_gpu_tensors;

// Helper to clear stored GPU tensors and free their memory
static void clear_gpu_tensors() {
for (auto& pair : g_gpu_tensors) {
if (pair.second.handle != nullptr) {
aoti_torch_delete_tensor_object(pair.second.handle);
}
}
g_gpu_tensors.clear();
}
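
// Illustrative caller sequence for the caching options handled below (a
// sketch of the assumed usage; "encoder"/"decoder" are placeholders, and the
// ASR runner changes later in this PR follow the same pattern):
//
//   set_option("cache_output", "0:encoder_output");    // next execute() keeps output slot 0 on GPU
//   execute(encoder, ...);                              // output stored under "encoder_output"
//   set_option("use_cache_input", "2:encoder_output");  // input slot 2 filled via GPU-to-GPU copy
//   execute(decoder, ...);                              // repeated for each generated token
//   set_option("clear_cache_input", true);              // stop reusing the cached tensor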

class ET_EXPERIMENTAL CudaBackend final
: public ::executorch::runtime::BackendInterface {
private:
// Cache control options (set via set_option before execute)
mutable int cache_output_slot_ = -1; // Which output slot to cache (-1 = none)
mutable std::string cache_output_name_; // Name to cache output under
mutable int use_cache_input_slot_ = -1; // Which input slot to use cache for (-1 = none)
mutable std::string use_cache_input_name_; // Name of cached tensor to use

Error load_function_pointers_into_handle(
void* so_handle,
AOTIDelegateHandle* handle) const {
@@ -91,6 +121,69 @@ class ET_EXPERIMENTAL CudaBackend final
return 1;
}

Error set_option(
__ET_UNUSED executorch::runtime::BackendOptionContext& context,
const executorch::runtime::Span<executorch::runtime::BackendOption>&
backend_options) override {
for (size_t i = 0; i < backend_options.size(); i++) {
const auto& option = backend_options[i];
// Handle cache_output: "slot:name" format (e.g., "0:encoder_output")
if (strcmp(option.key, "cache_output") == 0) {
if (auto* arr = std::get_if<
std::array<char, executorch::runtime::kMaxOptionValueLength>>(
&option.value)) {
std::string val(arr->data());
auto colon_pos = val.find(':');
if (colon_pos != std::string::npos) {
try {
cache_output_slot_ = std::stoi(val.substr(0, colon_pos));
cache_output_name_ = val.substr(colon_pos + 1);
} catch (const std::exception& e) {
ET_LOG(
Error,
"Invalid cache_output format '%s': %s",
val.c_str(),
e.what());
return Error::InvalidArgument;
}
}
}
}
// Handle use_cache_input: "slot:name" format (e.g., "1:encoder_output")
else if (strcmp(option.key, "use_cache_input") == 0) {
if (auto* arr = std::get_if<
std::array<char, executorch::runtime::kMaxOptionValueLength>>(
&option.value)) {
std::string val(arr->data());
auto colon_pos = val.find(':');
if (colon_pos != std::string::npos) {
try {
use_cache_input_slot_ = std::stoi(val.substr(0, colon_pos));
use_cache_input_name_ = val.substr(colon_pos + 1);
} catch (const std::exception& e) {
ET_LOG(
Error,
"Invalid use_cache_input format '%s': %s",
val.c_str(),
e.what());
return Error::InvalidArgument;
}
}
}
}
// Handle clear_cache_input: reset input cache settings
else if (strcmp(option.key, "clear_cache_input") == 0) {
if (auto* val = std::get_if<bool>(&option.value)) {
if (*val) {
use_cache_input_slot_ = -1;
use_cache_input_name_.clear();
}
}
}
}
return Error::Ok;
}
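
// The "slot:name" parsing above is repeated for cache_output and
// use_cache_input. A minimal sketch of a shared helper (hypothetical; not part
// of this change, shown only to make the option-value format explicit):
//
//   static bool parse_slot_name(const std::string& val, int& slot, std::string& name) {
//     auto colon_pos = val.find(':');
//     if (colon_pos == std::string::npos) {
//       return false; // no ':' separator
//     }
//     try {
//       slot = std::stoi(val.substr(0, colon_pos)); // numeric slot index, e.g. "0"
//     } catch (const std::exception&) {
//       return false; // non-numeric slot
//     }
//     name = val.substr(colon_pos + 1); // cache key, e.g. "encoder_output"
//     return !name.empty();
//   }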

// Once per loaded binary blob
Result<DelegateHandle*> init(
BackendInitContext& context,
@@ -223,14 +316,14 @@ class ET_EXPERIMENTAL CudaBackend final
n_outputs); // GPU tensors for kernel output

// Process input tensors: ExecuTorch provides CPU tensors, create GPU
// copies
// copies. For cached inputs, use GPU-to-GPU copy instead of CPU-to-GPU.
for (int i = 0; i < n_inputs; i++) {
// Get tensor dimensions and properties from ExecuTorch CPU tensor
auto cpu_tensor = &(args[i]->toTensor());
auto sizes = cpu_tensor->sizes();
auto scalar_type = cpu_tensor->scalar_type();

// Create GPU tensor with same shape
// Create GPU tensor with same shape (always needed for AOTI format)
std::vector<int64_t> sizes_vec(sizes.begin(), sizes.end());

AOTITensorHandle gpu_input_handle;
@@ -251,7 +344,43 @@

gpu_inputs[i] = gpu_input_handle;

// Copy data from CPU to GPU
// Check if this input slot should use a stored GPU tensor
if (i == use_cache_input_slot_ && !use_cache_input_name_.empty()) {
auto it = g_gpu_tensors.find(use_cache_input_name_);
if (it != g_gpu_tensors.end()) {
const GpuTensorRef& ref = it->second;
// GPU-to-GPU copy: fast DMA transfer, normalizes tensor format
size_t numel = gpu_inputs[i]->numel();
size_t elem_size = gpu_inputs[i]->element_size();
size_t copy_bytes = numel * elem_size;

ET_CHECK_OR_RETURN_ERROR(
copy_bytes == ref.size_bytes,
Internal,
"Stored tensor size mismatch: expected %zu bytes, got %zu",
copy_bytes,
ref.size_bytes);

cudaError_t cuda_err = cudaMemcpy(
gpu_inputs[i]->data_ptr(),
ref.data_ptr,
copy_bytes,
cudaMemcpyDeviceToDevice);

ET_CHECK_OR_RETURN_ERROR(
cuda_err == cudaSuccess,
Internal,
"Failed GPU-to-GPU copy for input %d: %s",
i,
cudaGetErrorString(cuda_err));

// Skip the CPU-to-GPU copy below
continue;
}
// Not found: fall through to normal CPU-to-GPU copy
}

// Copy data from CPU to GPU (normal path)
ET_CHECK_OR_RETURN_ERROR(
aoti_torch_copy_(gpu_inputs[i], cpu_tensor, 0) == Error::Ok,
Internal,
Expand Down Expand Up @@ -303,6 +432,33 @@ class ET_EXPERIMENTAL CudaBackend final
"AOTInductorModelContainerRun failed with error code %d",
error);

// Store reference to output GPU tensor if requested.
// The tensor will be kept alive for later D2D copy to decoder inputs.
if (cache_output_slot_ >= 0 && cache_output_slot_ < static_cast<int>(n_outputs) &&
!cache_output_name_.empty()) {
auto* gpu_tensor = gpu_outputs[cache_output_slot_];
size_t numel = gpu_tensor->numel();
size_t elem_size = gpu_tensor->element_size();
size_t size_bytes = numel * elem_size;

// Delete old tensor if overwriting
auto old_it = g_gpu_tensors.find(cache_output_name_);
if (old_it != g_gpu_tensors.end() && old_it->second.handle != nullptr) {
aoti_torch_delete_tensor_object(old_it->second.handle);
}

// Store tensor reference (we now own this tensor)
GpuTensorRef ref;
ref.handle = gpu_tensor;
ref.data_ptr = gpu_tensor->data_ptr();
ref.size_bytes = size_bytes;
g_gpu_tensors[cache_output_name_] = ref;

// Reset cache_output settings after caching
cache_output_slot_ = -1;
cache_output_name_.clear();
}

// Copy GPU output results back to CPU output tensors
for (int i = 0; i < n_outputs; i++) {
auto cpu_output_tensor = &(args[i + n_inputs]->toTensor());
@@ -317,6 +473,31 @@ class ET_EXPERIMENTAL CudaBackend final
i);
}

// Note: use_cache_input settings are intentionally NOT reset here.
// They persist across execute() calls to support decoder loops that
// reuse cached encoder output. The caller should explicitly clear
// these settings using the "clear_cache_input" option when done.

// Cleanup: delete GPU tensors to avoid memory leak across execute() calls.
// Input tensors are no longer needed after AOTI execution.
for (size_t i = 0; i < n_inputs; i++) {
aoti_torch_delete_tensor_object(gpu_inputs[i]);
}
// Output tensors are no longer needed after copying to CPU,
// EXCEPT for tensors stored in g_gpu_tensors (for later D2D copy).
for (size_t i = 0; i < n_outputs; i++) {
bool is_stored = false;
for (const auto& pair : g_gpu_tensors) {
if (pair.second.handle == gpu_outputs[i]) {
is_stored = true;
break;
}
}
if (!is_stored) {
aoti_torch_delete_tensor_object(gpu_outputs[i]);
}
}

return Error::Ok;
}

@@ -326,6 +507,9 @@ class ET_EXPERIMENTAL CudaBackend final
}
AOTIDelegateHandle* handle = (AOTIDelegateHandle*)handle_;

// Delete stored GPU tensors
clear_gpu_tensors();

// Destroy the CUDA stream if it exists
if (handle->cuda_stream != nullptr) {
cudaStream_t cuda_stream = static_cast<cudaStream_t>(handle->cuda_stream);
42 changes: 42 additions & 0 deletions extension/asr/runner/runner.cpp
@@ -16,6 +16,8 @@
#include <executorch/extension/llm/runner/util.h>
#include <executorch/extension/llm/sampler/util.h>
#include <executorch/extension/tensor/tensor_ptr_maker.h>
#include <executorch/runtime/backend/interface.h>
#include <executorch/runtime/backend/options.h>
#include <executorch/runtime/core/evalue.h>
#include <executorch/runtime/platform/assert.h>
#include <executorch/runtime/platform/log.h>
@@ -196,6 +198,17 @@ Result<std::vector<int64_t>> AsrRunner::transcribe(
}
}

// Tell CUDA backend to cache encoder output (slot 0) as "encoder_output"
{
::executorch::runtime::BackendOptions<1> opts;
opts.set_option("cache_output", "0:encoder_output");
auto err =
::executorch::runtime::set_option("CudaBackend", opts.view());
if (err != ::executorch::runtime::Error::Ok) {
ET_LOG(Info, "Failed to set cache_output option (backend may not support caching)");
}
}

auto encoder_result =
module_->execute(kEncoderMethodName, preprocessed_features);
ET_CHECK_OK_OR_RETURN_ERROR(encoder_result.error());
@@ -249,6 +262,26 @@ Result<std::vector<int64_t>> AsrRunner::transcribe(
decoder_inputs.emplace_back(decoder_input_ptr);
decoder_inputs.emplace_back(encoder_output_ptr);
decoder_inputs.emplace_back(cache_position_ptr);

// Tell CUDA backend to use cached encoder output for decoder input slot 2.
//
// Why slot 2? The AOTI-compiled decoder receives inputs in a different order
// than we pass them in decoder_inputs above. The AOTI input order was
// determined empirically by examining tensor shapes during execution.
//
// The "2:encoder_output" format tells the backend to use the stored GPU
// tensor named "encoder_output" for AOTI input slot 2. This avoids redundant
// CPU->GPU copies on each decoder iteration.
{
::executorch::runtime::BackendOptions<1> opts;
opts.set_option("use_cache_input", "2:encoder_output");
auto err =
::executorch::runtime::set_option("CudaBackend", opts.view());
if (err != ::executorch::runtime::Error::Ok) {
ET_LOG(Info, "Failed to set use_cache_input option (backend may not support caching)");
}
}

// Add some green coloring for the first generated token
// token_callback("\033[1;32m");
while (generated_tokens < config.max_new_tokens) {
Expand Down Expand Up @@ -304,6 +337,15 @@ Result<std::vector<int64_t>> AsrRunner::transcribe(
break;
}
}

// Clear cache input settings after decoder loop completes
// This prevents stale cache from being used in subsequent transcribe() calls
{
::executorch::runtime::BackendOptions<1> opts;
opts.set_option("clear_cache_input", true);
::executorch::runtime::set_option("CudaBackend", opts.view());
}

// Reset coloring
// token_callback("\033[0m");
// Update stats and print report