Commit 37c47d4

[CUDA]: GPU Device Caching for Encoder Output in CUDA Backend

Summary: In encoder-decoder models like Whisper, the encoder output tensor is fed to every decoder iteration, which previously forced unnecessary CPU->GPU->CPU->GPU copies. Implemented a "keep on device" caching mechanism in the CUDA backend that:
- Caches the encoder output in persistent GPU memory after the encoder runs
- Uses fast GPU-to-GPU copies for decoder iterations instead of slow CPU-to-GPU copies

Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:
1 parent 33ec615 commit 37c47d4
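
The runner wires this up purely through runtime backend options, so the exported program and the AOTI interface are unchanged. A minimal sketch of the overall call pattern, condensed from the two runner hunks below (the decoder method name, `done`, and `decoder_inputs` are illustrative placeholders; the option keys, slot indices, and set_option API are the ones this commit uses):

  // Once, before the encoder runs: cache encoder output slot 0 on the GPU
  // under the name "encoder_output".
  ::executorch::runtime::BackendOptions<1> cache_opts;
  cache_opts.set_option("cache_output", "0:encoder_output");
  auto err = ::executorch::runtime::set_option("CudaBackend", cache_opts.view());
  // (a failure only means the backend does not support caching)
  module_->execute(kEncoderMethodName, preprocessed_features);

  // Once, before the decoder loop: feed decoder input slot 2 from the cache
  // with a device-to-device copy instead of a CPU-to-GPU upload.
  ::executorch::runtime::BackendOptions<1> use_opts;
  use_opts.set_option("use_cache_input", "2:encoder_output");
  err = ::executorch::runtime::set_option("CudaBackend", use_opts.view());
  while (!done) {
    module_->execute(kDecoderMethodName, decoder_inputs);
  }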

File tree

2 files changed: +181 -3 lines changed


backends/cuda/runtime/cuda_backend.cpp

Lines changed: 155 additions & 3 deletions
@@ -8,6 +8,7 @@

 #include <cuda_runtime.h>
 #include <executorch/runtime/backend/interface.h>
+#include <executorch/runtime/backend/options.h>
 #include <executorch/runtime/core/error.h>
 #include <executorch/runtime/core/evalue.h>
 #include <executorch/runtime/core/exec_aten/util/tensor_util.h>
@@ -16,6 +17,7 @@
 #include <filesystem>
 #include <fstream>
 #include <string>
+#include <unordered_map>
 #include <vector>

 // Include our shim layer headers
@@ -46,9 +48,27 @@ using executorch::runtime::Result;
 using executorch::runtime::Span;
 using executorch::runtime::etensor::Tensor;

+// Structure to hold cached GPU tensor data for "keep on device" optimization
+struct CachedGpuData {
+  void* data_ptr; // GPU memory pointer
+  size_t size_bytes; // Total size in bytes
+  int32_t scalar_type; // Data type
+  std::vector<int64_t> sizes; // Original shape
+};
+
+// Global device cache - maps name to cached GPU data
+// Using raw GPU pointers instead of tensor handles for format independence
+static std::unordered_map<std::string, CachedGpuData> g_device_cache;
+
 class ET_EXPERIMENTAL CudaBackend final
     : public ::executorch::runtime::BackendInterface {
  private:
+  // Cache control options (set via set_option before execute)
+  mutable int cache_output_slot_ = -1; // Which output slot to cache (-1 = none)
+  mutable std::string cache_output_name_; // Name to cache output under
+  mutable int use_cache_input_slot_ = -1; // Which input slot to use cache for (-1 = none)
+  mutable std::string use_cache_input_name_; // Name of cached tensor to use
+
   Error load_function_pointers_into_handle(
       void* so_handle,
       AOTIDelegateHandle* handle) const {
@@ -91,6 +111,51 @@ class ET_EXPERIMENTAL CudaBackend final
     return 1;
   }

+  Error set_option(
+      __ET_UNUSED executorch::runtime::BackendOptionContext& context,
+      const executorch::runtime::Span<executorch::runtime::BackendOption>&
+          backend_options) override {
+    for (size_t i = 0; i < backend_options.size(); i++) {
+      const auto& option = backend_options[i];
+      // Handle cache_output: "slot:name" format (e.g., "0:encoder_output")
+      if (strcmp(option.key, "cache_output") == 0) {
+        if (auto* arr = std::get_if<
+                std::array<char, executorch::runtime::kMaxOptionValueLength>>(
+                &option.value)) {
+          std::string val(arr->data());
+          auto colon_pos = val.find(':');
+          if (colon_pos != std::string::npos) {
+            cache_output_slot_ = std::stoi(val.substr(0, colon_pos));
+            cache_output_name_ = val.substr(colon_pos + 1);
+          }
+        }
+      }
+      // Handle use_cache_input: "slot:name" format (e.g., "1:encoder_output")
+      else if (strcmp(option.key, "use_cache_input") == 0) {
+        if (auto* arr = std::get_if<
+                std::array<char, executorch::runtime::kMaxOptionValueLength>>(
+                &option.value)) {
+          std::string val(arr->data());
+          auto colon_pos = val.find(':');
+          if (colon_pos != std::string::npos) {
+            use_cache_input_slot_ = std::stoi(val.substr(0, colon_pos));
+            use_cache_input_name_ = val.substr(colon_pos + 1);
+          }
+        }
+      }
+      // Handle clear_cache_input: reset input cache settings
+      else if (strcmp(option.key, "clear_cache_input") == 0) {
+        if (auto* val = std::get_if<bool>(&option.value)) {
+          if (*val) {
+            use_cache_input_slot_ = -1;
+            use_cache_input_name_.clear();
+          }
+        }
+      }
+    }
+    return Error::Ok;
+  }
+
   // Once per loaded binary blob
   Result<DelegateHandle*> init(
       BackendInitContext& context,
@@ -223,14 +288,14 @@ class ET_EXPERIMENTAL CudaBackend final
         n_outputs); // GPU tensors for kernel output

     // Process input tensors: ExecuTorch provides CPU tensors, create GPU
-    // copies
+    // copies. For cached inputs, use GPU-to-GPU copy instead of CPU-to-GPU.
     for (int i = 0; i < n_inputs; i++) {
       // Get tensor dimensions and properties from ExecuTorch CPU tensor
       auto cpu_tensor = &(args[i]->toTensor());
       auto sizes = cpu_tensor->sizes();
       auto scalar_type = cpu_tensor->scalar_type();

-      // Create GPU tensor with same shape
+      // Create GPU tensor with same shape (always needed for AOTI format)
       std::vector<int64_t> sizes_vec(sizes.begin(), sizes.end());

       AOTITensorHandle gpu_input_handle;
@@ -251,7 +316,43 @@ class ET_EXPERIMENTAL CudaBackend final

       gpu_inputs[i] = gpu_input_handle;

-      // Copy data from CPU to GPU
+      // Check if this input slot should use cached GPU data
+      if (i == use_cache_input_slot_ && !use_cache_input_name_.empty()) {
+        auto cache_it = g_device_cache.find(use_cache_input_name_);
+        if (cache_it != g_device_cache.end()) {
+          const CachedGpuData& cached = cache_it->second;
+          // GPU-to-GPU copy: fast DMA transfer, normalizes tensor format
+          size_t numel = gpu_inputs[i]->numel();
+          size_t elem_size = gpu_inputs[i]->element_size();
+          size_t copy_bytes = numel * elem_size;
+
+          ET_CHECK_OR_RETURN_ERROR(
+              copy_bytes == cached.size_bytes,
+              Internal,
+              "Cached tensor size mismatch: expected %zu bytes, got %zu",
+              copy_bytes,
+              cached.size_bytes);
+
+          cudaError_t cuda_err = cudaMemcpy(
+              gpu_inputs[i]->data_ptr(),
+              cached.data_ptr,
+              copy_bytes,
+              cudaMemcpyDeviceToDevice);
+
+          ET_CHECK_OR_RETURN_ERROR(
+              cuda_err == cudaSuccess,
+              Internal,
+              "Failed GPU-to-GPU copy for cached input %d: %s",
+              i,
+              cudaGetErrorString(cuda_err));
+
+          // Skip the CPU-to-GPU copy below
+          continue;
+        }
+        // Cache miss: fall through to normal CPU-to-GPU copy
+      }
+
+      // Copy data from CPU to GPU (normal path)
       ET_CHECK_OR_RETURN_ERROR(
           aoti_torch_copy_(gpu_inputs[i], cpu_tensor, 0) == Error::Ok,
           Internal,
@@ -303,6 +404,57 @@ class ET_EXPERIMENTAL CudaBackend final
         "AOTInductorModelContainerRun failed with error code %d",
         error);

+    // Cache output GPU tensor data if requested
+    // We store the raw GPU pointer for later GPU-to-GPU copy
+    if (cache_output_slot_ >= 0 && cache_output_slot_ < static_cast<int>(n_outputs) &&
+        !cache_output_name_.empty()) {
+      auto* gpu_tensor = gpu_outputs[cache_output_slot_];
+      size_t numel = gpu_tensor->numel();
+      size_t elem_size = gpu_tensor->element_size();
+      size_t size_bytes = numel * elem_size;
+
+      // Allocate persistent GPU memory for the cache
+      void* cache_ptr = nullptr;
+      cudaError_t alloc_err = cudaMalloc(&cache_ptr, size_bytes);
+      ET_CHECK_OR_RETURN_ERROR(
+          alloc_err == cudaSuccess,
+          Internal,
+          "Failed to allocate GPU cache memory: %s",
+          cudaGetErrorString(alloc_err));
+
+      // Copy from tensor to cache (GPU-to-GPU)
+      cudaError_t copy_err = cudaMemcpy(
+          cache_ptr,
+          gpu_tensor->data_ptr(),
+          size_bytes,
+          cudaMemcpyDeviceToDevice);
+      ET_CHECK_OR_RETURN_ERROR(
+          copy_err == cudaSuccess,
+          Internal,
+          "Failed to copy output to GPU cache: %s",
+          cudaGetErrorString(copy_err));
+
+      // Free old cache if exists
+      auto old_it = g_device_cache.find(cache_output_name_);
+      if (old_it != g_device_cache.end()) {
+        cudaFree(old_it->second.data_ptr);
+        g_device_cache.erase(old_it);
+      }
+
+      // Store in cache
+      CachedGpuData cached;
+      cached.data_ptr = cache_ptr;
+      cached.size_bytes = size_bytes;
+      cached.scalar_type = static_cast<int32_t>(gpu_tensor->scalar_type());
+      auto sizes = gpu_tensor->sizes();
+      cached.sizes.assign(sizes.begin(), sizes.end());
+      g_device_cache[cache_output_name_] = std::move(cached);
+
+      // Reset cache_output settings after caching
+      cache_output_slot_ = -1;
+      cache_output_name_.clear();
+    }
+
     // Copy GPU output results back to CPU output tensors
     for (int i = 0; i < n_outputs; i++) {
       auto cpu_output_tensor = &(args[i + n_inputs]->toTensor());

extension/asr/runner/runner.cpp

Lines changed: 26 additions & 0 deletions
@@ -16,6 +16,8 @@
 #include <executorch/extension/llm/runner/util.h>
 #include <executorch/extension/llm/sampler/util.h>
 #include <executorch/extension/tensor/tensor_ptr_maker.h>
+#include <executorch/runtime/backend/interface.h>
+#include <executorch/runtime/backend/options.h>
 #include <executorch/runtime/core/evalue.h>
 #include <executorch/runtime/platform/assert.h>
 #include <executorch/runtime/platform/log.h>
@@ -196,6 +198,17 @@ Result<std::vector<int64_t>> AsrRunner::transcribe(
     }
   }

+  // Tell CUDA backend to cache encoder output (slot 0) as "encoder_output"
+  {
+    ::executorch::runtime::BackendOptions<1> opts;
+    opts.set_option("cache_output", "0:encoder_output");
+    auto err =
+        ::executorch::runtime::set_option("CudaBackend", opts.view());
+    if (err != ::executorch::runtime::Error::Ok) {
+      ET_LOG(Info, "Failed to set cache_output option (backend may not support caching)");
+    }
+  }
+
   auto encoder_result =
       module_->execute(kEncoderMethodName, preprocessed_features);
   ET_CHECK_OK_OR_RETURN_ERROR(encoder_result.error());
@@ -249,6 +262,19 @@ Result<std::vector<int64_t>> AsrRunner::transcribe(
   decoder_inputs.emplace_back(decoder_input_ptr);
   decoder_inputs.emplace_back(encoder_output_ptr);
   decoder_inputs.emplace_back(cache_position_ptr);
+
+  // Tell CUDA backend to use cached encoder output for decoder input slot 2
+  // Note: Decoder input order in AOTI is: input_ids[0], cache_position[1], encoder_output[2]
+  {
+    ::executorch::runtime::BackendOptions<1> opts;
+    opts.set_option("use_cache_input", "2:encoder_output");
+    auto err =
+        ::executorch::runtime::set_option("CudaBackend", opts.view());
+    if (err != ::executorch::runtime::Error::Ok) {
+      ET_LOG(Info, "Failed to set use_cache_input option (backend may not support caching)");
+    }
+  }
+
   // Add some green coloring for the first generated token
   // token_callback("\033[1;32m");
   while (generated_tokens < config.max_new_tokens) {
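
The backend also accepts a boolean "clear_cache_input" option (handled in set_option above) that resets the input-cache redirection, although this runner change does not exercise it. A sketch of how a caller could reset the option once generation finishes; the placement after the loop and the exact boolean set_option usage are assumptions, not part of this commit:

  // Hypothetical cleanup after the generation loop: stop redirecting decoder
  // input slot 2 to the cached tensor so later executes use the normal
  // CPU-to-GPU copy path.
  ::executorch::runtime::BackendOptions<1> opts;
  opts.set_option("clear_cache_input", true);
  auto err = ::executorch::runtime::set_option("CudaBackend", opts.view());
  if (err != ::executorch::runtime::Error::Ok) {
    ET_LOG(Info, "Failed to clear use_cache_input option");
  }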
