-
Notifications
You must be signed in to change notification settings - Fork 57
Open
Description
Hi there,
Below is a simple (vibe-coded) reproduction script that we put together during some studies of larger sequence lengths with the NATTEN kernel. Most of these configurations fail on all my Blackwell machines; some only fail occasionally (it is unclear what environment conditions are necessary to trigger this).
I'm happy to provide more data/help pinpoint this, but initially, are you guys able to see this on your end as well?
where `repros/natten_large.py` is just the script pasted here:
"""Reproduce NATTEN illegal memory access on large 3D inputs.
Configs: hidden_size=8192, head_dim=128, dtype=bf16, kernel_size=(-1,-1,-1) (full attention),
temporally_causal=True, is_causal=(True, False, False).
Two modes:
1. Full-sequence (reference): na3d(q, k, v) with q/k/v = [B, T, H, W, heads, dim]
2. Chunked (contrast): na3d(q, k, v, additional_keys=..., additional_values=...)
with q/k/v = [B, ct, H, W, heads, dim], additional_keys = [B, past_tokens, heads, dim]
Failing invocations (all CUDA illegal memory access):
# Full-sequence failures (kernel_size = (T, H, W)), sorted by seq_len:
python repros/natten_large.py --mode full --vf 240 --ih 1024 --nh 64 # seq=245760
python repros/natten_large.py --mode full --vf 1080 --ih 512 --nh 64 # seq=276480
python repros/natten_large.py --mode full --vf 360 --ih 1024 --nh 64 # seq=368640
python repros/natten_large.py --mode full --vf 1440 --ih 512 --nh 64 # seq=368640
python repros/natten_large.py --mode full --vf 720 --ih 768 --nh 64 # seq=414720
python repros/natten_large.py --mode full --vf 1080 --ih 768 --nh 32 # seq=622080
python repros/natten_large.py --mode full --vf 1080 --ih 768 --nh 64 # seq=622080
python repros/natten_large.py --mode full --vf 720 --ih 1024 --nh 32 # seq=737280
python repros/natten_large.py --mode full --vf 720 --ih 1024 --nh 64 # seq=737280
python repros/natten_large.py --mode full --vf 1440 --ih 768 --nh 32 # seq=829440
python repros/natten_large.py --mode full --vf 1440 --ih 768 --nh 64 # seq=829440
python repros/natten_large.py --mode full --vf 1080 --ih 1024 --nh 32 # seq=1105920
python repros/natten_large.py --mode full --vf 1080 --ih 1024 --nh 64 # seq=1105920
python repros/natten_large.py --mode full --vf 1440 --ih 1024 --nh 32 # seq=1474560
python repros/natten_large.py --mode full --vf 1440 --ih 1024 --nh 64 # seq=1474560
# Chunked failures (chunk_time=6, kernel_size = (6, H, W)), sorted by seq_len:
python repros/natten_large.py --mode chunked --vf 1080 --ih 512 --nh 32 # seq=276480
python repros/natten_large.py --mode chunked --vf 1080 --ih 512 --nh 64 # seq=276480
python repros/natten_large.py --mode chunked --vf 360 --ih 1024 --nh 32 # seq=368640
python repros/natten_large.py --mode chunked --vf 360 --ih 1024 --nh 64 # seq=368640
python repros/natten_large.py --mode chunked --vf 1440 --ih 512 --nh 32 # seq=368640
python repros/natten_large.py --mode chunked --vf 1440 --ih 512 --nh 64 # seq=368640
python repros/natten_large.py --mode chunked --vf 720 --ih 768 --nh 32 # seq=414720
python repros/natten_large.py --mode chunked --vf 720 --ih 768 --nh 64 # seq=414720
python repros/natten_large.py --mode chunked --vf 1080 --ih 768 --nh 32 # seq=622080
python repros/natten_large.py --mode chunked --vf 1080 --ih 768 --nh 64 # seq=622080
python repros/natten_large.py --mode chunked --vf 720 --ih 1024 --nh 32 # seq=737280
python repros/natten_large.py --mode chunked --vf 720 --ih 1024 --nh 64 # seq=737280
python repros/natten_large.py --mode chunked --vf 1080 --ih 1024 --nh 32 # seq=1105920
python repros/natten_large.py --mode chunked --vf 1080 --ih 1024 --nh 64 # seq=1105920
python repros/natten_large.py --mode chunked --vf 1440 --ih 1024 --nh 32 # seq=1474560
python repros/natten_large.py --mode chunked --vf 1440 --ih 1024 --nh 64 # seq=1474560
python repros/natten_large.py --mode chunked --vf 1440 --ih 768 --nh 64 # seq=829440
# Passing at boundary (for comparison):
python repros/natten_large.py --mode full --vf 240 --ih 1024 --nh 32 # seq=245760 OK
python repros/natten_large.py --mode full --vf 360 --ih 1024 --nh 32 # seq=368640 OK
python repros/natten_large.py --mode full --vf 1080 --ih 512 --nh 32 # seq=276480 OK
python repros/natten_large.py --mode full --vf 1440 --ih 512 --nh 32 # seq=368640 OK
python repros/natten_large.py --mode full --vf 720 --ih 768 --nh 32 # seq=414720 OK
python repros/natten_large.py --mode chunked --vf 240 --ih 1024 --nh 64 # seq=245760 OK
python repros/natten_large.py --mode chunked --vf 360 --ih 768 --nh 32 # seq=207360 OK
python repros/natten_large.py --mode chunked --vf 360 --ih 768 --nh 64 # seq=207360 OK
"""
import argparse
import math
import torch
from natten.functional import na3d
def run_full_sequence(
    B: int, T: int, H: int, W: int, num_heads: int, head_dim: int,
) -> None:
    """Run a single full-sequence na3d call.

    q/k/v are laid out as [B, T, H, W, heads, dim] with kernel_size=(T, H, W),
    i.e. attention over the entire 3D volume, causal along the time axis only.
    """
    total_tokens = T * H * W
    print(f"Full-sequence: B={B} T={T} H={H} W={W} heads={num_heads} dim={head_dim} seq={total_tokens:,}")
    print(f" kernel_size=({T}, {H}, {W}), is_causal=(True, False, False)")

    # Standard scaled-dot-product scale factor.
    softmax_scale = 1.0 / math.sqrt(head_dim)

    query = torch.randn(B, T, H, W, num_heads, head_dim, device="cuda", dtype=torch.bfloat16)
    key = torch.randn_like(query)
    value = torch.randn_like(query)

    allocated_gb = torch.cuda.memory_allocated() / 1e9
    print(f" Allocated before na3d: {allocated_gb:.2f} GB")
    torch.cuda.synchronize()
    print(" Running na3d...")

    result = na3d(
        query, key, value,
        kernel_size=(T, H, W),
        dilation=(1, 1, 1),
        stride=(1, 1, 1),
        scale=softmax_scale,
        is_causal=(True, False, False),
    )
    # Synchronize so any asynchronous CUDA fault surfaces here, inside the repro.
    torch.cuda.synchronize()
    print(f" OK — output shape: {result.shape}")
def run_chunked(
    B: int, T: int, H: int, W: int, num_heads: int, head_dim: int,
    chunk_time: int = 6,
) -> None:
    """Run na3d chunk by chunk, replaying earlier chunks as additional KV context.

    Each chunk's q/k/v is [B, chunk_time, H, W, heads, dim]; keys/values of all
    previous chunks are passed flattened via additional_keys/additional_values.
    """
    num_chunks = T // chunk_time
    tokens_per_chunk = chunk_time * H * W
    total_seq = T * H * W
    softmax_scale = 1.0 / math.sqrt(head_dim)
    print(f"Chunked: B={B} T={T} H={H} W={W} heads={num_heads} dim={head_dim} seq={total_seq:,}")
    print(f" chunk_time={chunk_time}, num_chunks={num_chunks}, tokens/chunk={tokens_per_chunk:,}")
    print(f" kernel_size=({chunk_time}, {H}, {W}), is_causal=(True, False, False)")

    # Pre-allocated flat KV cache covering the full sequence.
    key_cache = torch.empty(B, total_seq, num_heads, head_dim, device="cuda", dtype=torch.bfloat16)
    value_cache = torch.empty_like(key_cache)
    mem_gb = torch.cuda.memory_allocated() / 1e9
    print(f" KV cache allocated: {mem_gb:.2f} GB")

    for chunk_idx in range(num_chunks):
        q = torch.randn(B, chunk_time, H, W, num_heads, head_dim, device="cuda", dtype=torch.bfloat16)
        k = torch.randn_like(q)
        v = torch.randn_like(q)

        # Append this chunk's keys/values into the flat cache.
        lo = chunk_idx * tokens_per_chunk
        hi = lo + tokens_per_chunk
        key_cache[:, lo:hi] = k.reshape(B, tokens_per_chunk, num_heads, head_dim)
        value_cache[:, lo:hi] = v.reshape(B, tokens_per_chunk, num_heads, head_dim)

        # Everything written by earlier chunks becomes extra (non-local) context;
        # the first chunk has no past and passes None.
        add_k = key_cache[:, :lo] if chunk_idx > 0 else None
        add_v = value_cache[:, :lo] if chunk_idx > 0 else None
        past_tokens = lo if chunk_idx > 0 else 0

        print(f" Chunk {chunk_idx+1}/{num_chunks}: past_tokens={past_tokens:,}", end="", flush=True)
        out = na3d(
            q, k, v,
            kernel_size=(chunk_time, H, W),
            dilation=(1, 1, 1),
            stride=(1, 1, 1),
            scale=softmax_scale,
            is_causal=(True, False, False),
            additional_keys=add_k,
            additional_values=add_v,
        )
        # Per-chunk sync so an asynchronous CUDA fault is attributed to the right chunk.
        torch.cuda.synchronize()
        print(f" — OK shape={out.shape}")
    print(f" All {num_chunks} chunks passed.")
def main() -> None:
    """Parse CLI arguments, derive latent dims, and dispatch to the chosen repro mode.

    Invalid/non-divisible arguments are rejected via ``parser.error`` (exits with
    a usage message) rather than ``assert``, which is silently stripped under
    ``python -O`` and would let bad configs through.
    """
    parser = argparse.ArgumentParser(description="Reproduce NATTEN large 3D input failures")
    parser.add_argument("--mode", choices=["full", "chunked"], required=True,
                        help="full = full-sequence na3d, chunked = chunked with KV cache")
    parser.add_argument("--vf", type=int, required=True, help="video_frames (user-facing)")
    parser.add_argument("--ih", type=int, required=True, help="image_height (user-facing)")
    parser.add_argument("--iw", type=int, default=None, help="image_width (default: same as --ih)")
    parser.add_argument("--nh", type=int, default=64, help="num_heads (default: 64)")
    parser.add_argument("--hd", type=int, default=128, help="head_dim (default: 128)")
    parser.add_argument("--batch", type=int, default=1, help="batch size (default: 1)")
    parser.add_argument("--chunk-time", type=int, default=6, help="chunk_time for chunked mode (default: 6)")
    parser.add_argument("--temporal-compression", type=int, default=4, help="temporal compression (default: 4)")
    parser.add_argument("--spatial-compression", type=int, default=16, help="spatial compression (default: 16)")
    args = parser.parse_args()
    if args.iw is None:
        args.iw = args.ih

    # Validate divisibility BEFORE deriving latent dims (the original computed
    # T/H/W with // first and asserted afterwards).
    if args.vf % args.temporal_compression != 0:
        parser.error(f"vf={args.vf} not divisible by {args.temporal_compression}")
    if args.ih % args.spatial_compression != 0:
        parser.error(f"ih={args.ih} not divisible by {args.spatial_compression}")
    if args.iw % args.spatial_compression != 0:
        parser.error(f"iw={args.iw} not divisible by {args.spatial_compression}")

    # Latent dims after temporal/spatial compression (user-facing -> model-facing).
    T = args.vf // args.temporal_compression
    H = args.ih // args.spatial_compression
    W = args.iw // args.spatial_compression
    if args.mode == "chunked" and T % args.chunk_time != 0:
        parser.error(f"T={T} not divisible by chunk_time={args.chunk_time}")

    print(f"Latent dims: T={T} H={H} W={W} (from vf={args.vf} ih={args.ih} iw={args.iw})")
    print(f"Total seq_len: {T * H * W:,}")
    print()
    if args.mode == "full":
        run_full_sequence(args.batch, T, H, W, args.nh, args.hd)
    else:
        run_chunked(args.batch, T, H, W, args.nh, args.hd, args.chunk_time)
if __name__ == "__main__":
    main()
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
No labels