Skip to content

TMA (and potentially other?) bugs on large sequence lengths [blackwell] #315

@coconutruben

Description

@coconutruben

Hi there,

Below is a simple (vibe-coded) reproduction script that we put together while studying larger sequence lengths with the NATTEN kernel. Most of these configurations fail on all of my Blackwell machines; some fail only occasionally (it's unclear which environment conditions trigger this).

I'm happy to provide more data or help pinpoint this, but first: are you able to reproduce this on your end as well?

  • where repros/natten_large.py is just the script pasted here
"""Reproduce NATTEN illegal memory access on large 3D inputs.

Configs: hidden_size=8192, head_dim=128, dtype=bf16, kernel_size=(-1,-1,-1) (full attention),
temporally_causal=True, is_causal=(True, False, False).

Two modes:
  1. Full-sequence (reference): na3d(q, k, v) with q/k/v = [B, T, H, W, heads, dim]
  2. Chunked (contrast): na3d(q, k, v, additional_keys=..., additional_values=...)
     with q/k/v = [B, ct, H, W, heads, dim], additional_keys = [B, past_tokens, heads, dim]

Failing invocations (all CUDA illegal memory access):
  # Full-sequence failures (kernel_size = (T, H, W)), sorted by seq_len:
  python repros/natten_large.py --mode full --vf 240  --ih 1024 --nh 64  # seq=245760
  python repros/natten_large.py --mode full --vf 1080 --ih 512  --nh 64  # seq=276480
  python repros/natten_large.py --mode full --vf 360  --ih 1024 --nh 64  # seq=368640
  python repros/natten_large.py --mode full --vf 1440 --ih 512  --nh 64  # seq=368640
  python repros/natten_large.py --mode full --vf 720  --ih 768  --nh 64  # seq=414720
  python repros/natten_large.py --mode full --vf 1080 --ih 768  --nh 32  # seq=622080
  python repros/natten_large.py --mode full --vf 1080 --ih 768  --nh 64  # seq=622080
  python repros/natten_large.py --mode full --vf 720  --ih 1024 --nh 32  # seq=737280
  python repros/natten_large.py --mode full --vf 720  --ih 1024 --nh 64  # seq=737280
  python repros/natten_large.py --mode full --vf 1440 --ih 768  --nh 32  # seq=829440
  python repros/natten_large.py --mode full --vf 1440 --ih 768  --nh 64  # seq=829440
  python repros/natten_large.py --mode full --vf 1080 --ih 1024 --nh 32  # seq=1105920
  python repros/natten_large.py --mode full --vf 1080 --ih 1024 --nh 64  # seq=1105920
  python repros/natten_large.py --mode full --vf 1440 --ih 1024 --nh 32  # seq=1474560
  python repros/natten_large.py --mode full --vf 1440 --ih 1024 --nh 64  # seq=1474560

  # Chunked failures (chunk_time=6, kernel_size = (6, H, W)), sorted by seq_len:
  python repros/natten_large.py --mode chunked --vf 1080 --ih 512  --nh 32  # seq=276480
  python repros/natten_large.py --mode chunked --vf 1080 --ih 512  --nh 64  # seq=276480
  python repros/natten_large.py --mode chunked --vf 360  --ih 1024 --nh 32  # seq=368640
  python repros/natten_large.py --mode chunked --vf 360  --ih 1024 --nh 64  # seq=368640
  python repros/natten_large.py --mode chunked --vf 1440 --ih 512  --nh 32  # seq=368640
  python repros/natten_large.py --mode chunked --vf 1440 --ih 512  --nh 64  # seq=368640
  python repros/natten_large.py --mode chunked --vf 720  --ih 768  --nh 32  # seq=414720
  python repros/natten_large.py --mode chunked --vf 720  --ih 768  --nh 64  # seq=414720
  python repros/natten_large.py --mode chunked --vf 1080 --ih 768  --nh 32  # seq=622080
  python repros/natten_large.py --mode chunked --vf 1080 --ih 768  --nh 64  # seq=622080
  python repros/natten_large.py --mode chunked --vf 720  --ih 1024 --nh 32  # seq=737280
  python repros/natten_large.py --mode chunked --vf 720  --ih 1024 --nh 64  # seq=737280
  python repros/natten_large.py --mode chunked --vf 1080 --ih 1024 --nh 32  # seq=1105920
  python repros/natten_large.py --mode chunked --vf 1080 --ih 1024 --nh 64  # seq=1105920
  python repros/natten_large.py --mode chunked --vf 1440 --ih 1024 --nh 32  # seq=1474560
  python repros/natten_large.py --mode chunked --vf 1440 --ih 1024 --nh 64  # seq=1474560
  python repros/natten_large.py --mode chunked --vf 1440 --ih 768  --nh 64  # seq=829440

  # Passing at boundary (for comparison):
  python repros/natten_large.py --mode full    --vf 240  --ih 1024 --nh 32  # seq=245760  OK
  python repros/natten_large.py --mode full    --vf 360  --ih 1024 --nh 32  # seq=368640  OK
  python repros/natten_large.py --mode full    --vf 1080 --ih 512  --nh 32  # seq=276480  OK
  python repros/natten_large.py --mode full    --vf 1440 --ih 512  --nh 32  # seq=368640  OK
  python repros/natten_large.py --mode full    --vf 720  --ih 768  --nh 32  # seq=414720  OK
  python repros/natten_large.py --mode chunked --vf 240  --ih 1024 --nh 64  # seq=245760  OK
  python repros/natten_large.py --mode chunked --vf 360  --ih 768  --nh 32  # seq=207360  OK
  python repros/natten_large.py --mode chunked --vf 360  --ih 768  --nh 64  # seq=207360  OK
"""

import argparse
import math

import torch
from natten.functional import na3d


def run_full_sequence(
    B: int, T: int, H: int, W: int, num_heads: int, head_dim: int,
) -> None:
    """Run a single na3d call over the whole [B, T, H, W] volume.

    q, k and v are random bf16 CUDA tensors laid out as
    [B, T, H, W, heads, dim]. The kernel spans the entire volume
    (``kernel_size=(T, H, W)``) and is causal along the first (T) axis only.
    The device is synchronized before and after the call, so an error raised
    asynchronously by the kernel is reported at this call site.
    """
    total_tokens = T * H * W
    print(f"Full-sequence: B={B} T={T} H={H} W={W} heads={num_heads} dim={head_dim} seq={total_tokens:,}")
    print(f"  kernel_size=({T}, {H}, {W}), is_causal=(True, False, False)")

    # Generate q first, then k and v with the same shape/dtype/device.
    shape = (B, T, H, W, num_heads, head_dim)
    query = torch.randn(*shape, device="cuda", dtype=torch.bfloat16)
    key = torch.randn_like(query)
    value = torch.randn_like(query)

    print(f"  Allocated before na3d: {torch.cuda.memory_allocated() / 1e9:.2f} GB")

    options = dict(
        kernel_size=(T, H, W),
        dilation=(1, 1, 1),
        stride=(1, 1, 1),
        scale=1.0 / math.sqrt(head_dim),
        is_causal=(True, False, False),
    )

    torch.cuda.synchronize()
    print("  Running na3d...")
    result = na3d(query, key, value, **options)
    torch.cuda.synchronize()
    print(f"  OK — output shape: {result.shape}")


def run_chunked(
    B: int, T: int, H: int, W: int, num_heads: int, head_dim: int,
    chunk_time: int = 6,
) -> None:
    """Chunked na3d: iterate over temporal chunks, attending to a growing KV cache.

    Each chunk of ``chunk_time`` frames is processed with q/k/v of shape
    [B, chunk_time, H, W, heads, dim]. Keys/values of all *previous* chunks are
    passed as ``additional_keys``/``additional_values`` (shape
    [B, past_tokens, heads, dim]), sliced from a pre-allocated flat cache.

    Raises:
        ValueError: if ``chunk_time`` is not positive, or ``T`` is not a
            positive multiple of ``chunk_time``. Without this check, trailing
            frames would be silently skipped, or no chunk would run at all
            while still reporting success.
    """
    if chunk_time <= 0:
        raise ValueError(f"chunk_time must be positive, got {chunk_time}")
    if T < chunk_time or T % chunk_time != 0:
        raise ValueError(f"T={T} must be a positive multiple of chunk_time={chunk_time}")

    num_chunks = T // chunk_time
    tokens_per_chunk = chunk_time * H * W
    total_seq = T * H * W
    scale = 1.0 / math.sqrt(head_dim)

    print(f"Chunked: B={B} T={T} H={H} W={W} heads={num_heads} dim={head_dim} seq={total_seq:,}")
    print(f"  chunk_time={chunk_time}, num_chunks={num_chunks}, tokens/chunk={tokens_per_chunk:,}")
    print(f"  kernel_size=({chunk_time}, {H}, {W}), is_causal=(True, False, False)")

    # Pre-allocate the flat KV cache [B, total_seq, heads, dim]. Chunk i writes
    # its keys/values into tokens [i * tokens_per_chunk, (i + 1) * tokens_per_chunk).
    cache_k = torch.empty(B, total_seq, num_heads, head_dim, device="cuda", dtype=torch.bfloat16)
    cache_v = torch.empty_like(cache_k)

    mem_gb = torch.cuda.memory_allocated() / 1e9
    print(f"  KV cache allocated: {mem_gb:.2f} GB")

    for i in range(num_chunks):
        q = torch.randn(B, chunk_time, H, W, num_heads, head_dim, device="cuda", dtype=torch.bfloat16)
        k = torch.randn_like(q)
        v = torch.randn_like(q)

        # Write this chunk's keys/values into the cache.
        start = i * tokens_per_chunk
        end = start + tokens_per_chunk
        cache_k[:, start:end] = k.reshape(B, tokens_per_chunk, num_heads, head_dim)
        cache_v[:, start:end] = v.reshape(B, tokens_per_chunk, num_heads, head_dim)

        # Additional context = past chunks only (tokens [0, start)). The first
        # chunk has no past context, so pass None rather than an empty slice.
        if i > 0:
            add_k, add_v = cache_k[:, :start], cache_v[:, :start]
        else:
            add_k = add_v = None

        # start == 0 for the first chunk, so it is exactly the past-token count.
        print(f"  Chunk {i+1}/{num_chunks}: past_tokens={start:,}", end="", flush=True)

        out = na3d(
            q, k, v,
            kernel_size=(chunk_time, H, W),
            dilation=(1, 1, 1),
            stride=(1, 1, 1),
            scale=scale,
            is_causal=(True, False, False),
            additional_keys=add_k,
            additional_values=add_v,
        )
        # Synchronize so an asynchronous kernel failure is attributed to this chunk.
        torch.cuda.synchronize()
        print(f" — OK shape={out.shape}")

        # Release this chunk's tensors before the next iteration allocates new
        # ones, so peak memory does not hold two chunks' worth of q/k/v/out.
        del q, k, v, out, add_k, add_v

    print(f"  All {num_chunks} chunks passed.")


def main() -> None:
    """Parse CLI arguments, derive latent dims, and run the selected repro mode.

    Latent dims are the user-facing video/image sizes divided by the
    temporal/spatial compression factors. Invalid argument combinations are
    reported with ``parser.error``, which prints a usage message and exits with
    status 2. The original used ``assert``, which is stripped under ``python -O``.
    """
    parser = argparse.ArgumentParser(description="Reproduce NATTEN large 3D input failures")
    parser.add_argument("--mode", choices=["full", "chunked"], required=True,
                        help="full = full-sequence na3d, chunked = chunked with KV cache")
    parser.add_argument("--vf", type=int, required=True, help="video_frames (user-facing)")
    parser.add_argument("--ih", type=int, required=True, help="image_height (user-facing)")
    parser.add_argument("--iw", type=int, default=None, help="image_width (default: same as --ih)")
    parser.add_argument("--nh", type=int, default=64, help="num_heads (default: 64)")
    parser.add_argument("--hd", type=int, default=128, help="head_dim (default: 128)")
    parser.add_argument("--batch", type=int, default=1, help="batch size (default: 1)")
    parser.add_argument("--chunk-time", type=int, default=6, help="chunk_time for chunked mode (default: 6)")
    parser.add_argument("--temporal-compression", type=int, default=4, help="temporal compression (default: 4)")
    parser.add_argument("--spatial-compression", type=int, default=16, help="spatial compression (default: 16)")
    args = parser.parse_args()

    if args.iw is None:
        args.iw = args.ih

    # Validate before dividing, so bad input yields a usage error rather than a
    # ZeroDivisionError or zero-sized tensors.
    must_be_positive = [
        ("--vf", args.vf),
        ("--ih", args.ih),
        ("--iw", args.iw),
        ("--nh", args.nh),
        ("--hd", args.hd),
        ("--batch", args.batch),
        ("--temporal-compression", args.temporal_compression),
        ("--spatial-compression", args.spatial_compression),
    ]
    if args.mode == "chunked":
        must_be_positive.append(("--chunk-time", args.chunk_time))
    for flag, value in must_be_positive:
        if value <= 0:
            parser.error(f"{flag} must be positive, got {value}")

    if args.vf % args.temporal_compression != 0:
        parser.error(f"vf={args.vf} not divisible by {args.temporal_compression}")
    if args.ih % args.spatial_compression != 0:
        parser.error(f"ih={args.ih} not divisible by {args.spatial_compression}")
    if args.iw % args.spatial_compression != 0:
        parser.error(f"iw={args.iw} not divisible by {args.spatial_compression}")

    T = args.vf // args.temporal_compression
    H = args.ih // args.spatial_compression
    W = args.iw // args.spatial_compression

    if args.mode == "chunked" and T % args.chunk_time != 0:
        parser.error(f"T={T} not divisible by chunk_time={args.chunk_time}")

    seq_len = T * H * W
    print(f"Latent dims: T={T} H={H} W={W} (from vf={args.vf} ih={args.ih} iw={args.iw})")
    print(f"Total seq_len: {seq_len:,}")
    # Triage aid: makes it easy to compare failing configs against the 32-bit
    # index range. This is a hypothesis about the failures, not a confirmed cause.
    print(f"seq_len * heads * head_dim: {seq_len * args.nh * args.hd:,} (int32 max: {2**31 - 1:,})")
    print()

    if args.mode == "full":
        run_full_sequence(args.batch, T, H, W, args.nh, args.hd)
    else:
        run_chunked(args.batch, T, H, W, args.nh, args.hd, args.chunk_time)


if __name__ == "__main__":
    main()

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions