Conversation

@mergennachin
Contributor

@mergennachin mergennachin commented Dec 3, 2025

Summary

This PR implements a "keep on device" optimization for the CUDA backend that stores encoder output tensors on the GPU and reuses them via fast device-to-device (D2D) copies
during decoder iterations. This avoids redundant CPU→GPU transfers in encoder-decoder models like Whisper.

Motivation

In encoder-decoder architectures like Whisper, the encoder runs once and produces an output tensor that the decoder consumes on every iteration. Without this optimization,
the flow is:

Encoder: CPU input → [H2D copy] → GPU compute → [D2H copy] → CPU output
Decoder (×N tokens): CPU inputs (including encoder output) → [H2D copy] → GPU compute → [D2H copy] → CPU output

The encoder output (~2.3 MB for Whisper) is copied from CPU→GPU on every decoder iteration, even though it never changes. With N=109 tokens, that's 109 redundant H2D copies.

Design

The optimization introduces a simple GPU tensor storage mechanism:

Encoder: CPU input → [H2D copy] → GPU compute → [store on GPU] → [D2H copy] → CPU output
Decoder (×N tokens): CPU inputs → [D2D copy for encoder output, H2D for others] → GPU compute → [D2H copy] → CPU output

Key Design Decisions

  1. Name-based storage: Tensors are stored by name (e.g., "encoder_output") rather than by slot index, making the API more intuitive and less fragile.
  2. Size-based matching: When looking for a stored tensor to use as input, the backend matches by tensor size rather than requiring explicit slot mapping.
  3. Explicit opt-in: The optimization is controlled via set_option() calls, making it backwards compatible and non-intrusive to existing code paths.
  4. RAII cleanup: A TensorCleanup guard ensures GPU tensors are freed on all exit paths, preventing memory leaks when errors occur during execution.
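
To make these decisions concrete, here is a minimal sketch of what the storage struct and RAII guard could look like. The names GpuTensorRef, TensorCleanup, and gpu_tensors_ come from the change list below; every other detail (fields, layout) is an assumption, not the actual code in cuda_backend.cpp.

```cpp
#include <cuda_runtime.h>

#include <string>
#include <unordered_map>

// Hypothetical sketch only; field layout is assumed.
struct GpuTensorRef {
  void* device_ptr = nullptr;  // owning pointer to cudaMalloc'd device memory
  size_t nbytes = 0;           // byte size, used for size-based input matching
};

// Named storage, e.g. gpu_tensors_["encoder_output"].
std::unordered_map<std::string, GpuTensorRef> gpu_tensors_;

// RAII guard: frees a freshly allocated device buffer on every exit path
// unless ownership has been handed over to gpu_tensors_.
class TensorCleanup {
 public:
  explicit TensorCleanup(void* ptr) : ptr_(ptr) {}
  ~TensorCleanup() {
    if (ptr_ != nullptr) {
      cudaFree(ptr_);  // runs on error/early-return paths
    }
  }
  void release() { ptr_ = nullptr; }  // call once the map owns the pointer

 private:
  void* ptr_;
};
```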

API

Three backend options control the behavior:

Option              Type    Description
store_output        string  Store the first output tensor under this name
use_stored_input    string  Use the stored tensor for inputs matching by size
reset_stored_input  bool    Clear the input setting (the tensor remains in GPU memory)
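
A simplified sketch of how these options might be recorded inside the backend follows. The helper name, the member variables pending_store_name_ and active_input_name_, and the exact signature are assumptions; the real set_option() plumbing and BackendOption types are omitted.

```cpp
#include <executorch/runtime/core/error.h>

#include <string>

using executorch::runtime::Error;

// Assumed member-style state, shown as statics so the sketch is self-contained.
static std::string pending_store_name_;  // name to store the next output under
static std::string active_input_name_;   // name of the stored tensor to reuse as input

Error handle_backend_option(
    const std::string& key,
    const std::string& str_value,
    bool bool_value) {
  if (key == "store_output") {
    // The next execute() copies its (single) output into
    // gpu_tensors_[str_value] before the usual D2H copy.
    pending_store_name_ = str_value;
    return Error::Ok;
  }
  if (key == "use_stored_input") {
    // Subsequent execute() calls satisfy any input whose byte size matches
    // gpu_tensors_[str_value] with a D2D copy instead of an H2D copy.
    active_input_name_ = str_value;
    return Error::Ok;
  }
  if (key == "reset_stored_input") {
    if (bool_value) {
      active_input_name_.clear();  // the stored tensor stays resident on the GPU
    }
    return Error::Ok;
  }
  return Error::InvalidArgument;  // unknown option
}
```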

Lifecycle

  1. Before encoder: set_option("store_output", "encoder_output")
  2. Run encoder: output tensor stored on GPU
  3. Before decoder loop: set_option("use_stored_input", "encoder_output")
  4. Run decoder ×N: each iteration uses D2D copy for encoder output
  5. After decoder loop: set_option("reset_stored_input", true)
  6. On destroy(): all stored tensors freed
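
Seen from the runner's side, the same lifecycle might look roughly like the sketch below. All names here (the set-option helpers, the Encoder/Decoder/Audio types) are hypothetical stand-ins; the actual runner.cpp changes wire this through the real ExecuTorch module APIs.

```cpp
#include <string>

// Hypothetical wrappers around whatever mechanism the runner uses to reach
// the CUDA backend's set_option().
void cuda_backend_set_option(const std::string& key, const std::string& value);
void cuda_backend_set_bool_option(const std::string& key, bool value);

template <typename Encoder, typename Decoder, typename Audio>
void transcribe(Encoder& encoder, Decoder& decoder, const Audio& audio) {
  // 1. Ask the backend to keep the encoder's output resident on the GPU.
  cuda_backend_set_option("store_output", "encoder_output");

  // 2. Run the encoder once; its output is stored on device and still
  //    copied back to the CPU as usual.
  auto encoder_output = encoder.run(audio);

  // 3. Redirect size-matching decoder inputs to the stored tensor (D2D copy).
  cuda_backend_set_option("use_stored_input", "encoder_output");

  // 4. Decode token by token; each step reuses the on-device encoder output.
  while (!decoder.done()) {
    decoder.step(encoder_output);
  }

  // 5. Stop redirecting inputs; the stored tensor itself is freed later in
  //    the backend's destroy().
  cuda_backend_set_bool_option("reset_stored_input", true);
}
```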

Changes

backends/cuda/runtime/cuda_backend.cpp (+212 lines)

  • Added GpuTensorRef struct to hold GPU tensor references with ownership
  • Added gpu_tensors_ map for named tensor storage with documented lifetime contract
  • Implemented set_option() with validation for option types
  • Added RAII TensorCleanup guard to prevent memory leaks on error paths
  • Added validation for single-output constraint
  • Cleanup of stored tensors in destroy()
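
As an illustration of how the size-based matching and D2D path could fit together at execute() time, here is a self-contained sketch. It reuses the hypothetical names from the earlier sketch (gpu_tensors_, active_input_name_) and omits the real AOTI/ExecuTorch tensor handling; it is not the code added in this PR.

```cpp
#include <cuda_runtime.h>

#include <string>
#include <unordered_map>

// Same hypothetical struct as in the earlier sketch.
struct GpuTensorRef {
  void* device_ptr = nullptr;
  size_t nbytes = 0;
};
extern std::unordered_map<std::string, GpuTensorRef> gpu_tensors_;
extern std::string active_input_name_;

// Fill one device-side input buffer, preferring a D2D copy from the stored
// tensor when the input's byte size matches it (size-based matching).
cudaError_t fill_device_input(
    void* device_dst,
    const void* host_src,
    size_t nbytes) {
  if (!active_input_name_.empty()) {
    auto it = gpu_tensors_.find(active_input_name_);
    if (it != gpu_tensors_.end() && it->second.nbytes == nbytes) {
      // Device-to-device copy of the stored encoder output.
      return cudaMemcpy(
          device_dst, it->second.device_ptr, nbytes, cudaMemcpyDeviceToDevice);
    }
  }
  // All other inputs take the normal host-to-device path.
  return cudaMemcpy(device_dst, host_src, nbytes, cudaMemcpyHostToDevice);
}
```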

extension/asr/runner/runner.cpp (+42 lines)

  • Set store_output before encoder execution
  • Set use_stored_input before decoder loop
  • Reset after decoder loop completes
  • Consistent error logging at Warning level

Performance

Profiling with nsys confirms the optimization works:

Operation   Count  Total Size  Avg Time
H2D copies  722    521.7 MB    125 µs
D2D copies  109    251.1 MB    13.6 µs

The 109 D2D copies (one per decoder token) each transfer the 2.304 MB encoder output in ~13.6 µs, roughly 9x faster than the equivalent H2D copy (~125 µs) would take.

Test plan

  • Build succeeds
  • Whisper transcription produces correct output
  • nsys profile confirms D2D copies are occurring
  • No memory leaks (RAII cleanup on all paths)

Summary:

In encoder-decoder models like Whisper, the encoder output tensor is
used as input to every decoder iteration, which otherwise incurs
unnecessary CPU->GPU->CPU->GPU round-trip copies.

Implemented a "keep on device" caching mechanism in the CUDA backend
that:

-  Caches the encoder output in persistent GPU memory after the encoder runs
-  Uses fast GPU-to-GPU copies during decoder iterations instead of slow CPU-to-GPU copies

@pytorch-bot

pytorch-bot bot commented Dec 3, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/16060

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure

As of commit 137e6da with merge base 33ec615:

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Dec 3, 2025
@github-actions

github-actions bot commented Dec 3, 2025

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

Contributor

Copilot AI left a comment

Pull request overview

This PR implements GPU device caching for encoder output in the CUDA backend to optimize ASR (Automatic Speech Recognition) model inference. The caching mechanism avoids redundant CPU-to-GPU memory copies of encoder output during decoder iterations by keeping the encoder output on the GPU and using fast GPU-to-GPU copies instead.

Key Changes:

  • Added a global GPU tensor cache (g_gpu_tensors) to store encoder outputs on the GPU across multiple execute() calls
  • Implemented backend options API (cache_output, use_cache_input, clear_cache_input) to control caching behavior
  • Modified the ASR runner to set caching options before encoder execution and reuse cached tensors during the decoder loop

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 16 comments.

File                                    Description
backends/cuda/runtime/cuda_backend.cpp  Implements GPU tensor caching infrastructure with the set_option API, memory management for cached tensors, and GPU-to-GPU copy logic for cached inputs
extension/asr/runner/runner.cpp         Adds cache control flow: sets cache_output before encoder execution, sets use_cache_input before the decoder loop, and clears settings after completion


@mergennachin mergennachin changed the title [WIP][CUDA]: GPU Device Caching for Encoder Output in CUDA Backend [CUDA]: GPU Device Caching for Encoder Output in CUDA Backend Dec 3, 2025
@mergennachin mergennachin requested review from Gasoonjia, JacobSzwejbka and larryliu0820 and removed request for Copilot December 3, 2025 14:31
Copilot AI left a comment

Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.

Copilot AI left a comment

Copilot reviewed 2 out of 2 changed files in this pull request and generated 5 comments.

Copilot AI left a comment

Copilot reviewed 2 out of 2 changed files in this pull request and generated 6 comments.

Comment on lines +66 to +67
// This backend supports storing GPU tensors between execute() calls to enable
// device-to-device (D2D) copies instead of slower host-to-device (H2D)
Contributor

I'm curious, why do we still need to copy? Can you just make_tensor using the GPU data pointer?

Contributor Author

Yeah, I tried not copying initially and it was segfaulting. Because they're two completely different graphs, the output from the first graph and the input to the second graph had different underlying layout assumptions, so I had to copy explicitly.

Comment on lines +92 to +102
// TYPICAL USAGE PATTERN (encoder-decoder model):
//
// 1. Before encoder: set_option("store_output", "encoder_output")
// 2. Execute encoder (output is stored on GPU)
// 3. Before decoder loop: set_option("use_stored_input", "encoder_output")
// 4. Execute decoder N times (D2D copies for encoder output input)
// 5. After decoder loop:
// set_option("reset_stored_input", true)
// set_option("clear_stored_tensor", "encoder_output")
//
// ============================================================================
Contributor

Trying to understand the intention: is this using a backend option to let the encode/decode methods share the output memory? In an ideal world, if the encode/decode methods can share memory planning, does that mean we don't have to use this?

Contributor

@JacobSzwejbka JacobSzwejbka Dec 8, 2025

is this using a backend option to let the encode/decode methods share the output memory?

It's trying to avoid CPU->GPU copies. If we had a device tensor we wouldn't need this, but that's WIP and perf here is time-sensitive, so Mergen is hacking around it until it's properly fixed upstream.

Contributor

@cccclai cccclai Dec 8, 2025

Ah I see, that seems fine to me. Maybe worth adding this as part of the comment because I can't tell from the PR
