[file: path not captured in this view]
@@ -554,6 +554,7 @@ def fused_mha_with_cache(
k_cache: torch.Tensor,
v_cache: torch.Tensor,
freqs_cis: Optional[torch.Tensor],
logit_cap: Optional[float] = None,
) -> torch.Tensor:
"""Fused MHA with cache that takes raw input from q, k, v GEMMs."""
# b, s info
@@ -593,6 +594,7 @@ def fused_mha_fake(
k_cache: torch.Tensor,
v_cache: torch.Tensor,
freqs_cis: torch.Tensor,
logit_cap: Optional[float] = None,
):
return torch.empty_like(q.contiguous())
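
For context, the new `logit_cap` argument enables logit soft-capping in the style of Gemma 2: the raw attention logits are squashed through `cap * tanh(logits / cap)` before the softmax, which bounds them to the open interval (-cap, cap) while preserving their ordering. A minimal standalone PyTorch sketch of that transform (the `soft_cap_logits` helper name is illustrative, not code from this PR):

from typing import Optional

import torch


def soft_cap_logits(logits: torch.Tensor, logit_cap: Optional[float]) -> torch.Tensor:
    """Illustrative helper: soft-cap attention logits to (-logit_cap, logit_cap)."""
    if logit_cap is None:
        return logits  # capping disabled
    return logit_cap * torch.tanh(logits / logit_cap)


x = torch.tensor([-200.0, -1.0, 0.0, 1.0, 200.0])
print(soft_cap_logits(x, 50.0))  # ~[-49.97, -1.00, 0.00, 1.00, 49.97]: large logits saturate near +/-50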

15 changes: 15 additions & 0 deletions tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py
@@ -37,6 +37,7 @@ class PlanParams:
q_dtype: torch.dtype
kv_dtype: torch.dtype
sm_scale: Optional[float] = None
logit_cap: Optional[float] = None

causal: bool = True

@@ -107,6 +108,7 @@ def _plan_decode(wrapper: flashinfer.BatchDecodeWithPagedKVCacheWrapper):
q_data_type=plan_params.q_dtype,
kv_data_type=plan_params.kv_dtype,
sm_scale=plan_params.sm_scale,
logits_soft_cap=plan_params.logit_cap,
)

# we want to plan during warm-up of cuda graph capture to ensure we have the plan cached
@@ -143,6 +145,7 @@ def _plan_decode(wrapper: flashinfer.BatchDecodeWithPagedKVCacheWrapper):
q_data_type=plan_params.q_dtype,
kv_data_type=plan_params.kv_dtype,
sm_scale=plan_params.sm_scale,
logits_soft_cap=plan_params.logit_cap,
)
self.plan_params = plan_params
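
`PlanParams` appears to serve as the cache key for flashinfer plans (the wrapper re-plans when the incoming `plan_params` differs from the stored `self.plan_params`), so adding `logit_cap` to the dataclass means a changed soft-cap value forces a fresh plan instead of silently reusing one planned with a different `logits_soft_cap`. A simplified, hypothetical stand-in for `PlanParams` to illustrate the comparison (field set trimmed; not the actual definition):

from dataclasses import dataclass
from typing import Optional

import torch


@dataclass(frozen=True)  # hypothetical, trimmed-down stand-in for PlanParams
class MiniPlanParams:
    q_dtype: torch.dtype
    kv_dtype: torch.dtype
    sm_scale: Optional[float] = None
    logit_cap: Optional[float] = None


cached = MiniPlanParams(torch.float16, torch.float16, sm_scale=0.125, logit_cap=None)
incoming = MiniPlanParams(torch.float16, torch.float16, sm_scale=0.125, logit_cap=50.0)
print(incoming != cached)  # True: dataclass equality now includes logit_cap, so a re-plan is triggered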

@@ -250,6 +253,7 @@ def flashinfer_mha_with_cache(
scale: Optional[float],
k_scale: float,
v_scale: float,
logit_cap: Optional[float],
) -> torch.Tensor:
# reshape to standard [b*s, n_heads, head_dim] layout
head_dim = k_cache.shape[-1]
@@ -273,6 +277,7 @@ def flashinfer_mha_with_cache(
q_dtype=q.dtype,
kv_dtype=k_cache.dtype,
sm_scale=scale,
logit_cap=logit_cap,
)

# Assuming k_scale = v_scale = 1.0, we just have to cast k and v to fp8 before appending to kv cache
@@ -327,6 +332,7 @@ def flashinfer_mha_with_cache_fake(
scale: Optional[float],
k_scale: float,
v_scale: float,
logit_cap: Optional[float],
) -> torch.Tensor:
return torch.empty_like(q.contiguous())
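
The fake (meta) variant has to accept the same trailing `logit_cap` argument as the real op so the registered signatures stay in sync for tracing and shape propagation; the cap has no effect on the output shape or dtype. How these ops are registered is not shown in this capture, but the general pattern looks roughly like the following sketch using `torch.library.custom_op` / `register_fake` (namespace, op name, and toy math are all illustrative):

from typing import Optional

import torch


@torch.library.custom_op("demo::capped_sdpa", mutates_args=())
def capped_sdpa(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                logit_cap: Optional[float] = None) -> torch.Tensor:
    # Toy eager implementation: scaled dot-product attention with optional soft-capping.
    logits = (q @ k.transpose(-2, -1)) * q.shape[-1] ** -0.5
    if logit_cap is not None:
        logits = logit_cap * torch.tanh(logits / logit_cap)
    return torch.softmax(logits, dim=-1) @ v


@capped_sdpa.register_fake
def _(q, k, v, logit_cap=None):
    # The fake only needs the right shape/dtype; logit_cap is unused here but must be accepted.
    return torch.empty_like(q)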

@@ -419,8 +425,17 @@ def get_constants(cls, source_attn_node: Node) -> List[Constant]:
ad_logger.warning("Provided scale is not a float. Using default scale instead.")
scale = None

logit_cap = source_attn_node.kwargs.get("logit_cap", None)
if not (isinstance(logit_cap, float) or logit_cap is None):
    ad_logger.debug("Provided logit_cap is not a float or None. Disabling soft-capping.")
    logit_cap = None
elif isinstance(logit_cap, float) and logit_cap <= 0:
    ad_logger.warning("Provided logit_cap is not positive. Disabling soft-capping.")
    logit_cap = None

return [
    scale,  # softmax scale
    1.0,  # k_scale
    1.0,  # v_scale
    logit_cap,
]
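
Net effect of the validation above: only a positive Python float reaches the op as a constant; a missing kwarg, a non-float value (including an int), zero, or a negative cap all degrade to `None`, i.e. soft-capping disabled. A standalone mirror of that rule (hypothetical `normalize_logit_cap` helper, for illustration only):

from typing import Any, Optional


def normalize_logit_cap(value: Any) -> Optional[float]:
    """Keep positive floats; map everything else to None (soft-capping disabled)."""
    if not (isinstance(value, float) or value is None):
        return None  # e.g. "50", 50 (int), or a tensor
    if isinstance(value, float) and value <= 0:
        return None  # non-positive caps are rejected too
    return value


assert normalize_logit_cap(50.0) == 50.0
assert normalize_logit_cap(None) is None
assert normalize_logit_cap(0.0) is None
assert normalize_logit_cap(-1.0) is None
assert normalize_logit_cap("50") is None
assert normalize_logit_cap(50) is None  # plain ints are dropped as well
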
[file: path not captured in this view]
@@ -2,6 +2,7 @@

import triton
from triton import language as tl
from triton.language.extra.libdevice import tanh


@triton.jit
@@ -112,6 +113,7 @@ def gqa_attention_kv_stage1(
V_D_HEAD: tl.constexpr, # Dimension of each key/value head
SEQ_BLOCK_SIZE: tl.constexpr, # Block size used for tiling the sequence dim.
HEAD_BLOCK_SIZE: tl.constexpr, # pad to 16 if HEAD_RATIO is < 16 to invoke tensor cores.
LOGIT_CAP: tl.constexpr = None, # softcapping introduced in the Gemma 2 paper
):
"""Attention kernel to be used for generate-only batches.

@@ -200,6 +202,8 @@ def gqa_attention_kv_stage1(
attn = tl.dot(q, k.trans()) # [N, seq_block]
attn = attn.to(tl.float32)
attn *= SCALE
if LOGIT_CAP is not None:
    attn = LOGIT_CAP * tanh(attn / LOGIT_CAP)
# Set to -inf attn values where mask is not set. This forces exp(attn) to 0.
attn = tl.where(head_mask[:, None] * seq_mask[None, :], attn, float("-inf"))
# compute max_attn only when invalid attn values are masked out.
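
The kernel applies the cap to the scaled logits before masking and before the running-max/softmax bookkeeping, so the capped scores are bounded by LOGIT_CAP in magnitude while their ordering is unchanged (tanh is strictly monotonic). A quick PyTorch sanity sketch of the same transform (illustrative only, not one of the tests in this PR):

import torch

torch.manual_seed(0)
cap = 30.0
attn = torch.randn(8, 128) * 10.0      # moderately large "scaled" logits
capped = cap * torch.tanh(attn / cap)  # same formula as the Triton branch above

assert capped.abs().max() < cap  # bounded to (-cap, cap)
assert torch.equal(torch.argsort(capped, dim=-1, stable=True),
                   torch.argsort(attn, dim=-1, stable=True))  # ordering preserved
small = torch.full((4,), 0.5)
assert torch.allclose(cap * torch.tanh(small / cap), small, atol=1e-3)  # ~identity for |x| << cap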
[file: path not captured in this view]
@@ -68,6 +68,7 @@ def insert_cached_attention(

if not source_attn_nodes:
# If there are no nodes for kv cache insertion found, return current graph
ad_logger.info("No source attention nodes found, skipping cache insertion.")
return egm

# Sanity check
[file: path not captured in this view]
@@ -100,6 +100,90 @@ def test_gqa_op(device, dtype, n_heads, group_size, seq_len):
)


@pytest.mark.parametrize("logit_cap", [50.0])
@pytest.mark.parametrize("group_size", [1, 4])
@pytest.mark.parametrize("n_heads", [8])
@pytest.mark.parametrize("dtype", ["float16", "float32"])
@pytest.mark.parametrize("device", ["cuda"])
def test_gqa_op_with_logit_cap(device, dtype, n_heads, group_size, logit_cap):
# This test is for generation phase, so seq_len is 1.
seq_len = 1
BATCH_SIZE = 2
D_HEAD = 16
CACHE_SEQ_LEN = 8

dtype = getattr(torch, dtype)
n_kv_heads = n_heads // group_size

offset = 4 # some offset
input_positions = torch.zeros(BATCH_SIZE, device=device, dtype=torch.int) + offset

q = torch.randn(BATCH_SIZE, seq_len, n_heads, D_HEAD, dtype=dtype, device=device)
k = torch.randn(BATCH_SIZE, seq_len, n_kv_heads, D_HEAD, dtype=dtype, device=device)
v = torch.randn(BATCH_SIZE, seq_len, n_kv_heads, D_HEAD, dtype=dtype, device=device)

# setup kv-cache
k_cache = torch.randn(BATCH_SIZE, CACHE_SEQ_LEN, n_kv_heads, D_HEAD, dtype=dtype, device=device)
v_cache = torch.randn(BATCH_SIZE, CACHE_SEQ_LEN, n_kv_heads, D_HEAD, dtype=dtype, device=device)

# Store k,v in cache for op
k_cache_op = k_cache.clone()
v_cache_op = v_cache.clone()

# run custom op
output = torch.ops.attention.fused_mha_with_cache(
q, k, v, input_positions, k_cache_op, v_cache_op, None, logit_cap
)

# for reference, we manually update the cache
k_cache[:, input_positions[0] : input_positions[0] + seq_len] = k
v_cache[:, input_positions[0] : input_positions[0] + seq_len] = v

k_cache_ref = torch.repeat_interleave(k_cache, group_size, dim=2) # [b,s,n,d]
v_cache_ref = torch.repeat_interleave(v_cache, group_size, dim=2) # [b,s,n,d]

# Reference implementation
q_ref = q.transpose(1, 2)
# up to `offset + 1`
k_ref = k_cache_ref[:, : offset + seq_len].transpose(1, 2)
v_ref = v_cache_ref[:, : offset + seq_len].transpose(1, 2)

scale = 1.0 / (D_HEAD**0.5)
attn = torch.matmul(q_ref, k_ref.transpose(-2, -1)) * scale

if logit_cap is not None:
    attn = logit_cap * torch.tanh(attn / logit_cap)

# For seq_len=1, there is no causal mask. We attend to all keys in cache up to current position.

attn = torch.nn.functional.softmax(attn, dim=-1)
ref_out = torch.matmul(attn, v_ref)

ref = ref_out.transpose(1, 2).contiguous().view(BATCH_SIZE, seq_len, n_heads * D_HEAD)

# Check that op output and reference are close
assert torch.allclose(
ref.cpu().to(torch.float32),
output.cpu().to(torch.float32),
atol=1e-2,
rtol=1e-2,
)

# Check that cache is updated correctly by the op
assert torch.allclose(
k_cache_op.cpu(),
k_cache.cpu(),
atol=1e-5,
rtol=1e-5,
)
assert torch.allclose(
v_cache_op.cpu(),
v_cache.cpu(),
atol=1e-5,
rtol=1e-5,
)
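
Note that `torch.nn.functional.scaled_dot_product_attention`, which the flashinfer tests below use as a reference, has no soft-capping hook (as of current PyTorch), which is why the capped reference math is written out by hand here and again in the new flashinfer test further down. If more capped variants get added, that math could be factored into a small helper along these lines (hypothetical `ref_attention_with_cap`, not part of this PR):

from typing import Optional

import torch


def ref_attention_with_cap(q, k, v, scale: float, logit_cap: Optional[float],
                           causal: bool = False) -> torch.Tensor:
    """Reference attention with optional soft-capping; q/k/v are [b, n_heads, seq, head_dim]."""
    attn = torch.matmul(q, k.transpose(-2, -1)) * scale
    if logit_cap is not None:
        attn = logit_cap * torch.tanh(attn / logit_cap)  # cap before masking
    if causal:
        s_q, s_kv = attn.shape[-2], attn.shape[-1]
        mask = torch.triu(torch.ones(s_q, s_kv, dtype=torch.bool, device=attn.device),
                          diagonal=1 + (s_kv - s_q))
        attn = attn.masked_fill(mask, float("-inf"))
    return torch.matmul(torch.softmax(attn, dim=-1), v)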


@pytest.mark.parametrize("num_generate_ratio", [0.0, 0.5, 1.0])
@pytest.mark.parametrize("max_seq_len", [0, 1, 16])
@pytest.mark.parametrize("group_size", [1, 4])
[file: path not captured in this view]
@@ -109,6 +109,7 @@ def test_flashinfer_attention_op_context(seq_length, n_heads, batch_size, dtype,
None,
1.0,
1.0,
None, # logit_cap
)

ref = torch.nn.functional.scaled_dot_product_attention(
@@ -234,6 +235,7 @@ def test_flashinfer_attention_op_decode(
None,
1.0,
1.0,
None, # logit_cap
)

assert torch.allclose(
@@ -350,6 +352,7 @@ def test_flashinfer_attention_context_and_generate(
None,
1.0,
1.0,
None, # logit_cap
)

# Generate reference outputs
@@ -425,6 +428,7 @@ def test_flashinfer_attention_context_and_generate(
None,
1.0,
1.0,
None, # logit_cap
)

# Generate reference outputs
@@ -534,6 +538,7 @@ def test_flashinfer_attention_op_context_input_pos(seq, batch_size, n_heads, dty
None,
1.0,
1.0,
None, # logit_cap
)

# Generate ref
@@ -681,6 +686,7 @@ def test_flashinfer_attention_with_fp8_cache(
None,
K_SCALE,
V_SCALE,
None, # logit_cap
)

y = flashinfer_output.view(BATCH_SIZE, SEQ_LEN, N_HEADS, D_HEAD)
@@ -778,6 +784,7 @@ def test_flashinfer_attention_with_paged_kvcache(seq_lengths, n_heads, dtype, de
None,
1.0,
1.0,
None, # logit_cap
)

# Compute reference
@@ -861,6 +868,7 @@ def test_flashinfer_attention_with_paged_kvcache(seq_lengths, n_heads, dtype, de
None,
1.0,
1.0,
None, # logit_cap
)

# Compute reference
Expand All @@ -886,3 +894,114 @@ def test_flashinfer_attention_with_paged_kvcache(seq_lengths, n_heads, dtype, de
atol=1e-2,
rtol=1e-2,
)


@pytest.mark.parametrize("seq_length", [64])
@pytest.mark.parametrize("n_heads", [8])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("logit_cap", [10.0, 30.0])
@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("device", ["cuda"])
def test_flashinfer_attention_op_context_with_logit_cap(
seq_length, n_heads, batch_size, logit_cap, dtype, device
):
"""
Tests the context phase of flashinfer attention with logit soft-capping.
"""
D_HEAD = 64
MAX_SEQ_LEN = 2048
MAX_BATCH_SIZE = 32
DTYPE = dtype
BATCH_SIZE = batch_size
N_HEADS = n_heads
SEQ_LEN = seq_length

# metadata
seq_len_tensor = torch.tensor([SEQ_LEN] * BATCH_SIZE, dtype=torch.int32, device=device)
offsets = torch.zeros(BATCH_SIZE, device=device, dtype=torch.int)

qo_indptr = torch.cat(
(torch.zeros_like(seq_len_tensor[:1]), torch.cumsum(seq_len_tensor, 0))
).to(torch.int32)
paged_kv_indptr = torch.arange(0, batch_size + 1, dtype=torch.int32, device=device)
paged_kv_indices = torch.arange(BATCH_SIZE).int().to(device)
paged_kv_last_page_len = offsets + seq_len_tensor

# Q,K,V are computed using GEMM.
q = torch.randn(BATCH_SIZE, SEQ_LEN, N_HEADS * D_HEAD, dtype=DTYPE).to(device)
k = torch.randn(BATCH_SIZE, SEQ_LEN, N_HEADS * D_HEAD, dtype=DTYPE).to(device)
v = torch.randn(BATCH_SIZE, SEQ_LEN, N_HEADS * D_HEAD, dtype=DTYPE).to(device)

# Setup KV Cache. KV cache is empty, context phase
k_cache = torch.zeros(
(MAX_BATCH_SIZE, MAX_SEQ_LEN, N_HEADS, D_HEAD), dtype=DTYPE, device=device
)
v_cache = torch.zeros(
(MAX_BATCH_SIZE, MAX_SEQ_LEN, N_HEADS, D_HEAD), dtype=DTYPE, device=device
)

# make sure planner is initialized
workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device)
_GlobalFlashInferPlanner.init_workspace(workspace)

batch_indices, positions = flashinfer.get_batch_indices_positions(
qo_indptr,
flashinfer.get_seq_lens(
paged_kv_indptr, paged_kv_last_page_len, page_size=k_cache.shape[1]
),
BATCH_SIZE * SEQ_LEN,
)
flashinfer_output = torch.ops.attention.flashinfer_mha_with_cache(
# Q, K, V
q,
k,
v,
# METADATA
qo_indptr,
paged_kv_indptr,
paged_kv_indices,
paged_kv_last_page_len,
batch_indices,
positions,
# CACHES
k_cache,
v_cache,
# BUFFERS
workspace,
# CONSTANTS
None,
1.0,
1.0,
logit_cap,
)

# Reference implementation
q_ref = q.view(BATCH_SIZE, SEQ_LEN, N_HEADS, D_HEAD).transpose(1, 2)
k_ref = k.view(BATCH_SIZE, SEQ_LEN, N_HEADS, D_HEAD).transpose(1, 2)
v_ref = v.view(BATCH_SIZE, SEQ_LEN, N_HEADS, D_HEAD).transpose(1, 2)

scale = D_HEAD**-0.5
logits = torch.matmul(q_ref, k_ref.transpose(-2, -1)) * scale

# Apply logit softcapping
if logit_cap > 0.0:
    logits = logit_cap * torch.tanh(logits / logit_cap)

# Apply causal mask
causal_mask = torch.triu(
torch.ones(SEQ_LEN, SEQ_LEN, device=device, dtype=torch.bool), diagonal=1
)
logits.masked_fill_(causal_mask, -float("inf"))

# Apply softmax
attn_weights = torch.softmax(logits, dim=-1).to(v_ref.dtype)

# Compute output
ref = (attn_weights @ v_ref).transpose(1, 2).reshape(BATCH_SIZE, SEQ_LEN, -1)

assert torch.allclose(
flashinfer_output.cpu().to(torch.float32),
ref.cpu().to(torch.float32),
atol=1e-2,
rtol=1e-2,
)
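
One subtlety the new flashinfer test handles correctly: the cap is applied to the raw logits and the causal mask afterwards. Reversing the order would push `-inf` through `tanh`, turning masked positions into a finite `-logit_cap` that then receives non-zero softmax weight. A tiny sketch of the difference (illustrative, not from this PR):

import torch

cap = 10.0
logits = torch.tensor([[4.0, 2.0, 100.0]])
mask = torch.tensor([[False, False, True]])  # last position should be excluded

# Order used by the test above: cap first, then mask.
good = (cap * torch.tanh(logits / cap)).masked_fill(mask, float("-inf"))
print(torch.softmax(good, dim=-1))  # masked position gets exactly zero weight

# Reversed order: -inf passes through tanh and becomes a finite -cap.
bad = cap * torch.tanh(logits.masked_fill(mask, float("-inf")) / cap)
print(torch.softmax(bad, dim=-1))   # masked position leaks a tiny non-zero weight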