Token‐Shift & Short Convolution: Cross‐Rank Cache Passing in Pipeline Training
Purpose: computes first-order differences across the time dimension. For every position t,

delta[t] = x[t-1] - x[t]

where the previous token at t = 0 is treated as zero, so the first position yields -x[0] unless a cache supplies the last token of the preceding segment.
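For intuition, a minimal pure-PyTorch sketch of the single-sequence case (not the library kernel, and ignoring the cache/varlen paths):

```python
import torch

def token_shift_reference(x: torch.Tensor) -> torch.Tensor:
    # x: [B, T, D]; the "previous token" at t = 0 is taken as zero, so delta[0] = -x[0]
    prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1)
    return prev - x
```

The library call itself: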
```python
from fla.modules.token_shift import token_shift

out = token_shift(
    x: torch.Tensor,                          # [B, T, D]
    cu_seqlens: Optional[LongTensor] = None,  # [N+1]
    cache: Optional[Tensor] = None,           # [N, D]
    output_cache: bool = False
) -> Tensor | Tuple[Tensor, Tensor]
```

| Argument | Meaning |
|---|---|
| x | Input tensor. If cu_seqlens is given, B must be 1 and T == cu_seqlens[-1]. |
| cu_seqlens | Cumulative lengths of N packed sequences. |
| cache | Last token from prior segment for each sequence ([N, D]). |
| output_cache | If True, also return the updated cache ([N, 1, D]). |
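A usage sketch (batch size, lengths, and device are illustrative; shapes follow the table above):

```python
import torch
from fla.modules.token_shift import token_shift

B, T, D = 2, 16, 64
x = torch.randn(B, T, D, device="cuda")
delta = token_shift(x)                                       # [B, T, D]; first position gives -x[:, 0]

# Packed variable-length usage: B must be 1 and T == cu_seqlens[-1].
lens = torch.tensor([5, 11], device="cuda")                  # two sequences, N = 2
cu_seqlens = torch.cat([lens.new_zeros(1), lens.cumsum(0)])  # [0, 5, 16]
x_packed = torch.randn(1, 16, D, device="cuda")
delta_packed, cache = token_shift(
    x_packed, cu_seqlens=cu_seqlens, output_cache=True
)
# cache: [N, 1, D], the last token of each sequence, ready for the next chunk/rank.
```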
ShortConvolution: a depth-wise causal 1-D convolution (kernel size W).
Two runtimes: Triton (full features) or CUDA (faster, restricted).
```python
from fla.modules.convolution import ShortConvolution

conv = ShortConvolution(
    hidden_size=D,
    kernel_size=W,
    bias=bool,
    activation="swish" | "silu" | None,
    backend="triton" | "cuda",
    device=torch.device,
    dtype=torch.dtype
)

tri, cache_out = conv(
    x: torch.Tensor,                    # [B, T, D]
    residual: Optional[Tensor] = None,  # [B, T, D]
    cache: Optional[Tensor] = None,     # [N, D, W]
    output_final_state: bool = False
) -> Tuple[Tensor, Optional[Tensor]]
```

| Return | Shape | Meaning |
|---|---|---|
| tri | [B, T, D] | Convolved sequence. |
| cache_out | [N, D, W] | Final convolution state (last W tokens per sequence). |
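A chunked forward sketch following the signature above (sizes are illustrative); the final state of one chunk seeds the next:

```python
import torch
from fla.modules.convolution import ShortConvolution

B, T, D, W = 2, 128, 512, 4
conv = ShortConvolution(hidden_size=D, kernel_size=W,
                        activation="silu", backend="triton").cuda()

x = torch.randn(B, T, D, device="cuda")
y, cache = conv(x, output_final_state=True)      # cache: [B, D, W] (N == B here)

# The next chunk (or the next pipeline rank) starts from the returned state.
x_next = torch.randn(B, T, D, device="cuda")
y_next, cache = conv(x_next, cache=cache, output_final_state=True)
```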
At runtime the layer automatically falls back to Triton with a warning if:
- backend == "cuda" and any of:
  - cu_seqlens is not None and cache is not None
  - cu_seqlens is not None and output_final_state=True
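For example, a call like the following (assuming the forward also accepts a cu_seqlens argument, as these rules imply) would emit a warning and run the Triton path even though backend="cuda" was requested:

```python
import torch
from fla.modules.convolution import ShortConvolution

D, W = 512, 4
conv_cuda = ShortConvolution(hidden_size=D, kernel_size=W, backend="cuda").cuda()

cu_seqlens = torch.tensor([0, 5, 16], dtype=torch.long, device="cuda")  # N = 2 packed sequences
x_packed = torch.randn(1, 16, D, device="cuda")                         # B = 1 in varlen mode
prev_cache = torch.zeros(2, D, W, device="cuda")                        # [N, D, W]

# varlen + cache and varlen + output_final_state both trigger the fallback:
y, cache = conv_cuda(x_packed, cache=prev_cache,
                     cu_seqlens=cu_seqlens, output_final_state=True)
```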
```python
# cache shape: [B, D, W] or [N, D, W]
y_t, cache = conv.step(
    x_t: Tensor,               # [B, 1, D] or [1, D]
    residual: Tensor | None,
    cache: Tensor              # updated in-place
)
```

| Scenario | Cache Shape | Rule |
|---|---|---|
| Fixed length | [B, D, W] | Direct pass between ranks. |
| Variable length (cu_seqlens) | [N, D, W] | Split x with B=1; each sequence owns its slice. |
| Cross-rank transfer | — | Copy last W tokens of any sequence that crosses a rank boundary; otherwise zero-init. |
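A decoding and hand-off sketch: the loop uses conv.step as documented above; the helper that builds a cross-rank cache is hypothetical and assumes the state stores the last W tokens right-aligned along the window dimension:

```python
import torch
from fla.modules.convolution import ShortConvolution

B, T, D, W = 2, 128, 512, 4
conv = ShortConvolution(hidden_size=D, kernel_size=W, backend="triton").cuda()
x = torch.randn(B, T, D, device="cuda")

# Single-token decoding: zero-init the cache when there is no prior context.
cache = x.new_zeros(B, D, W)
for t in range(T):
    y_t, cache = conv.step(x[:, t:t+1], None, cache)    # cache is updated in place

# Hypothetical helper for a sequence that crosses a rank boundary; it copies the
# last W tokens into a fresh state (assumed right-aligned layout).
def make_conv_cache(x_seq: torch.Tensor, W: int) -> torch.Tensor:
    # x_seq: [T, D] tokens already processed on the previous rank
    state = x_seq.new_zeros(x_seq.shape[-1], W)          # [D, W]
    last = x_seq[-W:]                                    # up to W most recent tokens
    state[:, W - last.shape[0]:] = last.t()              # assumed layout
    return state
```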
| Item | Recommendation |
|---|---|
| Backend | Prefer Triton: supports varlen, cache, gradient flow. |
| Activations | "swish" or "silu" only (or None). |
| Cache sizes | Token-Shift: [N, D] (single token). ShortConvolution: [N, D, W] (last W tokens). |
| Varlen mode | Always B=1 when cu_seqlens is provided. |
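Finally, a sketch of handing both caches to the next pipeline rank with plain torch.distributed point-to-point calls (assumes a process group is already initialized and both stages agree on shapes and dtypes):

```python
import torch
import torch.distributed as dist

def send_states(token_shift_cache: torch.Tensor, conv_cache: torch.Tensor, dst: int) -> None:
    dist.send(token_shift_cache.contiguous(), dst=dst)   # [N, D]
    dist.send(conv_cache.contiguous(), dst=dst)          # [N, D, W]

def recv_states(N: int, D: int, W: int, src: int, device, dtype):
    token_shift_cache = torch.empty(N, D, device=device, dtype=dtype)
    conv_cache = torch.empty(N, D, W, device=device, dtype=dtype)
    dist.recv(token_shift_cache, src=src)
    dist.recv(conv_cache, src=src)
    return token_shift_cache, conv_cache
```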
- Token-Shift → first-order diff, one-token cache.
- ShortConvolution → depth-wise causal conv, W-token cache.

Both support gradient flow and seamless state hand-off in multi-rank pipelines.