
Token‐Shift & Short Convolution: Cross‐Rank Cache Passing in Pipeline Training


1. Token-Shift Operator

Purpose: computes first-order differences across the time dimension:

For every position t: delta[t] = x[t-1] - x[t], where x[-1] is taken as 0, so the first token yields -x[0] unless a cache supplies the last token of the previous segment.
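For intuition, a minimal PyTorch sketch of this operation for a plain fixed-length batch (not the fused kernel, and ignoring the packed cu_seqlens path):

```python
import torch

def token_shift_reference(x: torch.Tensor, cache: torch.Tensor | None = None) -> torch.Tensor:
    # x: [B, T, D]; returns delta with delta[:, t] = x[:, t-1] - x[:, t].
    # The "previous token" at t=0 comes from `cache` ([B, D], last token of the
    # prior segment) or is zero, so delta[:, 0] = -x[:, 0] without a cache.
    prev_first = torch.zeros_like(x[:, :1]) if cache is None else cache.unsqueeze(1)
    prev = torch.cat([prev_first, x[:, :-1]], dim=1)   # x shifted right by one position
    return prev - x
```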

1.1 API

```python
from fla.modules.token_shift import token_shift

out = token_shift(
    x: torch.Tensor,                            # [B, T, D]
    cu_seqlens: Optional[LongTensor] = None,    # [N+1]
    cache: Optional[Tensor] = None,             # [N, D]
    output_cache: bool = False
) -> Tensor | Tuple[Tensor, Tensor]
```
| Argument | Meaning |
|---|---|
| `x` | Input tensor. If `cu_seqlens` is given, `B` must be 1 and `T == cu_seqlens[-1]`. |
| `cu_seqlens` | Cumulative lengths of `N` packed sequences. |
| `cache` | Last token of the prior segment for each sequence (`[N, D]`). |
| `output_cache` | If `True`, also return the updated cache (`[N, 1, D]`). |
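A hedged usage sketch of the packed (varlen) path; the call signature and cache shape follow the API above, while the concrete sizes and device are illustrative:

```python
import torch
from fla.modules.token_shift import token_shift

# Three sequences of lengths 5, 3, and 7 packed into a single batch (B=1).
cu_seqlens = torch.tensor([0, 5, 8, 15], dtype=torch.long, device="cuda")
x = torch.randn(1, 15, 512, device="cuda")                # T == cu_seqlens[-1], D = 512

out, cache = token_shift(x, cu_seqlens=cu_seqlens, output_cache=True)
# out:   [1, 15, 512] -- per-sequence differences, no leakage across boundaries
# cache: [3, 1, 512]  -- last token of each of the N = 3 sequences, ready to hand
#                        to the rank that processes the next segment
```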

2. Short Convolution (ShortConvolution)

A depth-wise causal 1-D convolution (kernel size W). Two backends are available: Triton (full feature set) or CUDA (faster, but restricted; see 2.3).
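Conceptually (ignoring bias, activation, and caching), the layer applies one length-W filter per channel with left padding, so the output at position t depends only on inputs up to t. A minimal sketch of that core operation:

```python
import torch
import torch.nn.functional as F

def depthwise_causal_conv1d(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # x: [B, T, D], weight: [D, W] -- one causal filter per channel.
    D, W = weight.shape
    x = x.transpose(1, 2)                             # [B, D, T]
    x = F.pad(x, (W - 1, 0))                          # left-pad so position t never sees the future
    y = F.conv1d(x, weight.unsqueeze(1), groups=D)    # depth-wise: groups == number of channels
    return y.transpose(1, 2)                          # [B, T, D]
```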

2.1 Construction

```python
from fla.modules.convolution import ShortConvolution

conv = ShortConvolution(
    hidden_size=D,
    kernel_size=W,
    bias=bool,
    activation="swish" | "silu" | None,
    backend="triton" | "cuda",
    device=torch.device,
    dtype=torch.dtype
)
```
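An illustrative instantiation (the hyperparameter values are arbitrary choices, not recommendations):

```python
import torch
from fla.modules.convolution import ShortConvolution

conv = ShortConvolution(
    hidden_size=512,          # D
    kernel_size=4,            # W
    bias=False,
    activation="silu",
    backend="triton",         # full feature set: varlen, cache, gradient flow
    device=torch.device("cuda"),
    dtype=torch.bfloat16,
)
```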

2.2 Forward Pass

```python
tri, cache_out = conv(
    x: torch.Tensor,                    # [B, T, D]
    residual: Optional[Tensor] = None,  # [B, T, D]
    cache: Optional[Tensor] = None,     # [N, D, W]
    output_final_state: bool = False
) -> Tuple[Tensor, Optional[Tensor]]
```
| Return | Shape | Meaning |
|---|---|---|
| `tri` | `[B, T, D]` | Convolved sequence. |
| `cache_out` | `[N, D, W]` | Final convolution state (last `W` tokens per sequence). |
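A hedged fixed-length usage sketch that also requests the final state; with fixed lengths the cache's N equals B (see 2.5), and the tensor sizes are illustrative:

```python
import torch

x = torch.randn(2, 1024, 512, device="cuda", dtype=torch.bfloat16)  # [B, T, D]
y, cache = conv(x, output_final_state=True)
# y:     [2, 1024, 512] -- convolved (and activated) sequence
# cache: [2, 512, 4]    -- last W = 4 tokens per sequence, e.g. to send to the next rank
```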

2.3 Backend Compatibility Check

At runtime the layer automatically falls back to Triton with a warning if:

  • backend == "cuda" and any of:
      – cu_seqlens is not None and cache is not None
      – cu_seqlens is not None and output_final_state=True

2.4 Incremental Decoding

```python
# cache shape: [B, D, W] or [N, D, W]
y_t, cache = conv.step(
    x_t: Tensor,              # [B, 1, D] or [1, D]
    residual: Tensor | None,
    cache: Tensor             # updated in-place
)
```
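A hedged decoding sketch built on the step API above; `prompt`, `embed`, `next_token`, and `max_new_tokens` are placeholders, not library API:

```python
# Prefill on the full prompt, then decode one token at a time.
y, cache = conv(prompt, output_final_state=True)      # prompt: [B, T, D] -> cache: [B, D, W]
for _ in range(max_new_tokens):
    x_t = embed(next_token)                           # [B, 1, D], produced by the surrounding model
    y_t, cache = conv.step(x_t, residual=None, cache=cache)   # cache is updated in-place
```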

2.5 Cache & Variable-Length Notes

| Scenario | Cache shape | Rule |
|---|---|---|
| Fixed length | `[B, D, W]` | Pass directly between ranks. |
| Variable length (`cu_seqlens`) | `[N, D, W]` | Pack `x` with `B=1`; each sequence owns its slice. |
| Cross-rank transfer | – | Copy the last `W` tokens of any sequence that crosses a rank boundary; otherwise zero-init. |
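Roughly, the sending rank can build the `[N, D, W]` cache from its packed input and ship it to the next rank. This is only a sketch: the time ordering inside the cache (most recent token last) and the way the receiving forward consumes it are assumptions to verify against the implementation.

```python
import torch
import torch.distributed as dist

def build_conv_cache(x: torch.Tensor, cu_seqlens: torch.Tensor, W: int) -> torch.Tensor:
    # x: [1, T, D] packed input on the sending rank; returns an [N, D, W] cache.
    N = cu_seqlens.numel() - 1
    D = x.shape[-1]
    cache = x.new_zeros(N, D, W)                      # zero-init for sequences with nothing to carry
    for i in range(N):
        start, end = int(cu_seqlens[i]), int(cu_seqlens[i + 1])
        tail = x[0, max(start, end - W):end]          # last (up to) W tokens of sequence i: [<=W, D]
        cache[i, :, W - tail.shape[0]:] = tail.t()    # right-align; assumed "most recent last" layout
    return cache

# sending rank:
#   dist.send(build_conv_cache(x, cu_seqlens, W), dst=next_rank)
# receiving rank:
#   cache = torch.empty(N, D, W, device=device, dtype=dtype)
#   dist.recv(cache, src=prev_rank)
#   y, _ = conv(x_next, cache=cache)   # continue the convolution across the boundary
```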

3. Practical Checklist

| Item | Recommendation |
|---|---|
| Backend | Prefer Triton: supports varlen, cache, and gradient flow. |
| Activations | `"swish"` or `"silu"` only (or `None`). |
| Cache sizes | Token-Shift: `[N, D]` (single token); ShortConvolution: `[N, D, W]` (last `W` tokens). |
| Varlen mode | Always `B=1` when `cu_seqlens` is provided. |

4. One-Line Summary

  • Token-Shift → first-order diff, one-token cache.
  • ShortConvolution → depth-wise causal conv, W-token cache.

Both support gradient flow and seamless state hand-off in multi-rank pipelines.