diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/_triton_attention_internal.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/_triton_attention_internal.py
index 18452d3b417..9b018f5a19f 100644
--- a/tensorrt_llm/_torch/auto_deploy/custom_ops/_triton_attention_internal.py
+++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/_triton_attention_internal.py
@@ -554,6 +554,7 @@ def fused_mha_with_cache(
     k_cache: torch.Tensor,
     v_cache: torch.Tensor,
     freqs_cis: Optional[torch.Tensor],
+    logit_cap: Optional[float] = None,
 ) -> torch.Tensor:
     """Fused MHA with cache that takes raw input from q, k, v GEMMs."""
     # b, s info
@@ -593,6 +594,7 @@ def fused_mha_fake(
     k_cache: torch.Tensor,
     v_cache: torch.Tensor,
     freqs_cis: torch.Tensor,
+    logit_cap: Optional[float] = None,
 ):
     return torch.empty_like(q.contiguous())
 
diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py
index 414039a5065..82749e72fb0 100644
--- a/tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py
+++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py
@@ -37,6 +37,7 @@ class PlanParams:
     q_dtype: torch.dtype
     kv_dtype: torch.dtype
     sm_scale: Optional[float] = None
+    logit_cap: Optional[float] = None
     causal: bool = True
 
 
@@ -107,6 +108,7 @@ def _plan_decode(wrapper: flashinfer.BatchDecodeWithPagedKVCacheWrapper):
                 q_data_type=plan_params.q_dtype,
                 kv_data_type=plan_params.kv_dtype,
                 sm_scale=plan_params.sm_scale,
+                logits_soft_cap=plan_params.logit_cap,
             )
 
         # we want to plan during warm-up of cuda graph capture to ensure we have the plan cached
@@ -143,6 +145,7 @@ def _plan_decode(wrapper: flashinfer.BatchDecodeWithPagedKVCacheWrapper):
                 q_data_type=plan_params.q_dtype,
                 kv_data_type=plan_params.kv_dtype,
                 sm_scale=plan_params.sm_scale,
+                logits_soft_cap=plan_params.logit_cap,
             )
             self.plan_params = plan_params
 
@@ -250,6 +253,7 @@ def flashinfer_mha_with_cache(
     scale: Optional[float],
     k_scale: float,
     v_scale: float,
+    logit_cap: Optional[float],
 ) -> torch.Tensor:
     # reshape to standard [b*s, n_heads, head_dim] layout
     head_dim = k_cache.shape[-1]
@@ -273,6 +277,7 @@ def flashinfer_mha_with_cache(
         q_dtype=q.dtype,
         kv_dtype=k_cache.dtype,
         sm_scale=scale,
+        logit_cap=logit_cap,
     )
 
     # Assuming k_scale = v_scale = 1.0, we just have to cast k and v to fp8 before appending to kv cache
@@ -327,6 +332,7 @@ def flashinfer_mha_with_cache_fake(
     scale: Optional[float],
     k_scale: float,
     v_scale: float,
+    logit_cap: Optional[float],
 ) -> torch.Tensor:
     return torch.empty_like(q.contiguous())
 
@@ -419,8 +425,17 @@ def get_constants(cls, source_attn_node: Node) -> List[Constant]:
             ad_logger.warning("Provided scale is not a float. Using default scale instead.")
             scale = None
 
+        logit_cap = source_attn_node.kwargs.get("logit_cap", None)
+        if not (isinstance(logit_cap, float) or logit_cap is None):
+            ad_logger.debug("Provided logit_cap is not a float or None. Disabling soft-capping.")
+            logit_cap = None
+        elif isinstance(logit_cap, float) and logit_cap <= 0:
+            ad_logger.warning("Provided logit_cap is not positive. Disabling soft-capping.")
+            logit_cap = None
+
         return [
             scale,  # softmax scale
             1.0,  # k_scale
             1.0,  # v_scale
+            logit_cap,
         ]
diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/triton_kernels/attention_with_kv_cache.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/triton_kernels/attention_with_kv_cache.py
index 9a59a363dc4..af9ec93a672 100644
--- a/tensorrt_llm/_torch/auto_deploy/custom_ops/triton_kernels/attention_with_kv_cache.py
+++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/triton_kernels/attention_with_kv_cache.py
@@ -2,6 +2,7 @@
 import triton
 from triton import language as tl
+from triton.language.extra.libdevice import tanh
 
 
 @triton.jit
@@ -112,6 +113,7 @@ def gqa_attention_kv_stage1(
     V_D_HEAD: tl.constexpr,  # Dimension of each key/value head
     SEQ_BLOCK_SIZE: tl.constexpr,  # Block size used for tiling the sequence dim.
     HEAD_BLOCK_SIZE: tl.constexpr,  # pad to 16 if HEAD_RATIO is < 16 to invoke tensor cores.
+    LOGIT_CAP: tl.constexpr = None,  # soft-capping introduced in the Gemma 2 paper
 ):
     """Attention kernel to be used for generate-only batches.
 
@@ -200,6 +202,8 @@ def gqa_attention_kv_stage1(
     attn = tl.dot(q, k.trans())  # [N, seq_block]
     attn = attn.to(tl.float32)
     attn *= SCALE
+    if LOGIT_CAP is not None:
+        attn = LOGIT_CAP * tanh(attn / LOGIT_CAP)
     # Set to -inf attn values where mask is not set. This forces exp(attn) to 0.
     attn = tl.where(head_mask[:, None] * seq_mask[None, :], attn, float("-inf"))
     # compute max_attn only when invalid attn values are masked out.
diff --git a/tensorrt_llm/_torch/auto_deploy/transformations/library/kvcache.py b/tensorrt_llm/_torch/auto_deploy/transformations/library/kvcache.py
index 97a4ef3fdac..c6af8788c88 100644
--- a/tensorrt_llm/_torch/auto_deploy/transformations/library/kvcache.py
+++ b/tensorrt_llm/_torch/auto_deploy/transformations/library/kvcache.py
@@ -68,6 +68,7 @@ def insert_cached_attention(
 
     if not source_attn_nodes:
         # If there are no nodes for kv cache insertion found, return current graph
+        ad_logger.info("No source attention nodes found, skipping cache insertion.")
         return egm
 
     # Sanity check
diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_attention_op.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_attention_op.py
index cfc5ac1891c..a763852bd10 100644
--- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_attention_op.py
+++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_attention_op.py
@@ -100,6 +100,90 @@ def test_gqa_op(device, dtype, n_heads, group_size, seq_len):
     )
 
 
+@pytest.mark.parametrize("logit_cap", [50.0])
+@pytest.mark.parametrize("group_size", [1, 4])
+@pytest.mark.parametrize("n_heads", [8])
+@pytest.mark.parametrize("dtype", ["float16", "float32"])
+@pytest.mark.parametrize("device", ["cuda"])
+def test_gqa_op_with_logit_cap(device, dtype, n_heads, group_size, logit_cap):
+    # This test is for the generation phase, so seq_len is 1.
+    seq_len = 1
+    BATCH_SIZE = 2
+    D_HEAD = 16
+    CACHE_SEQ_LEN = 8
+
+    dtype = getattr(torch, dtype)
+    n_kv_heads = n_heads // group_size
+
+    offset = 4  # some offset
+    input_positions = torch.zeros(BATCH_SIZE, device=device, dtype=torch.int) + offset
+
+    q = torch.randn(BATCH_SIZE, seq_len, n_heads, D_HEAD, dtype=dtype, device=device)
+    k = torch.randn(BATCH_SIZE, seq_len, n_kv_heads, D_HEAD, dtype=dtype, device=device)
+    v = torch.randn(BATCH_SIZE, seq_len, n_kv_heads, D_HEAD, dtype=dtype, device=device)
+
+    # setup kv-cache
+    k_cache = torch.randn(BATCH_SIZE, CACHE_SEQ_LEN, n_kv_heads, D_HEAD, dtype=dtype, device=device)
+    v_cache = torch.randn(BATCH_SIZE, CACHE_SEQ_LEN, n_kv_heads, D_HEAD, dtype=dtype, device=device)
+
+    # Store k,v in cache for op
+    k_cache_op = k_cache.clone()
+    v_cache_op = v_cache.clone()
+
+    # run custom op
+    output = torch.ops.attention.fused_mha_with_cache(
+        q, k, v, input_positions, k_cache_op, v_cache_op, None, logit_cap
+    )
+
+    # for reference, we manually update the cache
+    k_cache[:, input_positions[0] : input_positions[0] + seq_len] = k
+    v_cache[:, input_positions[0] : input_positions[0] + seq_len] = v
+
+    k_cache_ref = torch.repeat_interleave(k_cache, group_size, dim=2)  # [b,s,n,d]
+    v_cache_ref = torch.repeat_interleave(v_cache, group_size, dim=2)  # [b,s,n,d]
+
+    # Reference implementation
+    q_ref = q.transpose(1, 2)
+    # up to `offset + 1`
+    k_ref = k_cache_ref[:, : offset + seq_len].transpose(1, 2)
+    v_ref = v_cache_ref[:, : offset + seq_len].transpose(1, 2)
+
+    scale = 1.0 / (D_HEAD**0.5)
+    attn = torch.matmul(q_ref, k_ref.transpose(-2, -1)) * scale
+
+    if logit_cap is not None:
+        attn = logit_cap * torch.tanh(attn / logit_cap)
+
+    # For seq_len=1 there is no causal mask; we attend to all keys in the cache up to the current position.
+
+    attn = torch.nn.functional.softmax(attn, dim=-1)
+    ref_out = torch.matmul(attn, v_ref)
+
+    ref = ref_out.transpose(1, 2).contiguous().view(BATCH_SIZE, seq_len, n_heads * D_HEAD)
+
+    # Check that op output and reference are close
+    assert torch.allclose(
+        ref.cpu().to(torch.float32),
+        output.cpu().to(torch.float32),
+        atol=1e-2,
+        rtol=1e-2,
+    )
+
+    # Check that cache is updated correctly by the op
+    assert torch.allclose(
+        k_cache_op.cpu(),
+        k_cache.cpu(),
+        atol=1e-5,
+        rtol=1e-5,
+    )
+    assert torch.allclose(
+        v_cache_op.cpu(),
+        v_cache.cpu(),
+        atol=1e-5,
+        rtol=1e-5,
+    )
+
+
 @pytest.mark.parametrize("num_generate_ratio", [0.0, 0.5, 1.0])
 @pytest.mark.parametrize("max_seq_len", [0, 1, 16])
 @pytest.mark.parametrize("group_size", [1, 4])
diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_flashinfer_attention_op.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_flashinfer_attention_op.py
index 4872aef2210..cf1b41bc6e2 100644
--- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_flashinfer_attention_op.py
+++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_flashinfer_attention_op.py
@@ -109,6 +109,7 @@ def test_flashinfer_attention_op_context(seq_length, n_heads, batch_size, dtype,
         None,
         1.0,
         1.0,
+        None,  # logit_cap
     )
 
     ref = torch.nn.functional.scaled_dot_product_attention(
@@ -234,6 +235,7 @@ def test_flashinfer_attention_op_decode(
         None,
         1.0,
         1.0,
+        None,  # logit_cap
     )
 
     assert torch.allclose(
@@ -350,6 +352,7 @@ def test_flashinfer_attention_context_and_generate(
         None,
         1.0,
         1.0,
+        None,  # logit_cap
     )
 
     # Generate reference outputs
@@ -425,6 +428,7 @@ def test_flashinfer_attention_context_and_generate(
         None,
         1.0,
         1.0,
+        None,  # logit_cap
     )
 
     # Generate reference outputs
@@ -534,6 +538,7 @@ def test_flashinfer_attention_op_context_input_pos(seq, batch_size, n_heads, dty
         None,
         1.0,
         1.0,
+        None,  # logit_cap
     )
 
     # Generate ref
@@ -681,6 +686,7 @@ def test_flashinfer_attention_with_fp8_cache(
         None,
         K_SCALE,
         V_SCALE,
+        None,  # logit_cap
     )
 
     y = flashinfer_output.view(BATCH_SIZE, SEQ_LEN, N_HEADS, D_HEAD)
@@ -778,6 +784,7 @@ def test_flashinfer_attention_with_paged_kvcache(seq_lengths, n_heads, dtype, de
         None,
         1.0,
         1.0,
+        None,  # logit_cap
     )
 
     # Compute reference
@@ -861,6 +868,7 @@ def test_flashinfer_attention_with_paged_kvcache(seq_lengths, n_heads, dtype, de
         None,
         1.0,
         1.0,
+        None,  # logit_cap
     )
 
     # Compute reference
@@ -886,3 +894,114 @@ def test_flashinfer_attention_with_paged_kvcache(seq_lengths, n_heads, dtype, de
         atol=1e-2,
         rtol=1e-2,
     )
+
+
+@pytest.mark.parametrize("seq_length", [64])
+@pytest.mark.parametrize("n_heads", [8])
+@pytest.mark.parametrize("batch_size", [4])
+@pytest.mark.parametrize("logit_cap", [10.0, 30.0])
+@pytest.mark.parametrize("dtype", [torch.float16])
+@pytest.mark.parametrize("device", ["cuda"])
+def test_flashinfer_attention_op_context_with_logit_cap(
+    seq_length, n_heads, batch_size, logit_cap, dtype, device
+):
+    """
+    Tests the context phase of flashinfer attention with logit soft-capping.
+    """
+    D_HEAD = 64
+    MAX_SEQ_LEN = 2048
+    MAX_BATCH_SIZE = 32
+    DTYPE = dtype
+    BATCH_SIZE = batch_size
+    N_HEADS = n_heads
+    SEQ_LEN = seq_length
+
+    # metadata
+    seq_len_tensor = torch.tensor([SEQ_LEN] * BATCH_SIZE, dtype=torch.int32, device=device)
+    offsets = torch.zeros(BATCH_SIZE, device=device, dtype=torch.int)
+
+    qo_indptr = torch.cat(
+        (torch.zeros_like(seq_len_tensor[:1]), torch.cumsum(seq_len_tensor, 0))
+    ).to(torch.int32)
+    paged_kv_indptr = torch.arange(0, batch_size + 1, dtype=torch.int32, device=device)
+    paged_kv_indices = torch.arange(BATCH_SIZE).int().to(device)
+    paged_kv_last_page_len = offsets + seq_len_tensor
+
+    # Q, K, V are computed using GEMM.
+    q = torch.randn(BATCH_SIZE, SEQ_LEN, N_HEADS * D_HEAD, dtype=DTYPE).to(device)
+    k = torch.randn(BATCH_SIZE, SEQ_LEN, N_HEADS * D_HEAD, dtype=DTYPE).to(device)
+    v = torch.randn(BATCH_SIZE, SEQ_LEN, N_HEADS * D_HEAD, dtype=DTYPE).to(device)
+
+    # Setup KV cache. The cache is empty in the context phase.
+    k_cache = torch.zeros(
+        (MAX_BATCH_SIZE, MAX_SEQ_LEN, N_HEADS, D_HEAD), dtype=DTYPE, device=device
+    )
+    v_cache = torch.zeros(
+        (MAX_BATCH_SIZE, MAX_SEQ_LEN, N_HEADS, D_HEAD), dtype=DTYPE, device=device
+    )
+
+    # make sure the planner is initialized
+    workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device)
+    _GlobalFlashInferPlanner.init_workspace(workspace)
+
+    batch_indices, positions = flashinfer.get_batch_indices_positions(
+        qo_indptr,
+        flashinfer.get_seq_lens(
+            paged_kv_indptr, paged_kv_last_page_len, page_size=k_cache.shape[1]
+        ),
+        BATCH_SIZE * SEQ_LEN,
+    )
+    flashinfer_output = torch.ops.attention.flashinfer_mha_with_cache(
+        # Q, K, V
+        q,
+        k,
+        v,
+        # METADATA
+        qo_indptr,
+        paged_kv_indptr,
+        paged_kv_indices,
+        paged_kv_last_page_len,
+        batch_indices,
+        positions,
+        # CACHES
+        k_cache,
+        v_cache,
+        # BUFFERS
+        workspace,
+        # CONSTANTS
+        None,
+        1.0,
+        1.0,
+        logit_cap,
+    )
+
+    # Reference implementation
+    q_ref = q.view(BATCH_SIZE, SEQ_LEN, N_HEADS, D_HEAD).transpose(1, 2)
+    k_ref = k.view(BATCH_SIZE, SEQ_LEN, N_HEADS, D_HEAD).transpose(1, 2)
+    v_ref = v.view(BATCH_SIZE, SEQ_LEN, N_HEADS, D_HEAD).transpose(1, 2)
+
+    scale = D_HEAD**-0.5
+    logits = torch.matmul(q_ref, k_ref.transpose(-2, -1)) * scale
+
+    # Apply logit soft-capping
+    if logit_cap > 0.0:
+        logits = logit_cap * torch.tanh(logits / logit_cap)
+
+    # Apply causal mask
+    causal_mask = torch.triu(
+        torch.ones(SEQ_LEN, SEQ_LEN, device=device, dtype=torch.bool), diagonal=1
+    )
+    logits.masked_fill_(causal_mask, -float("inf"))
+
+    # Apply softmax
+    attn_weights = torch.softmax(logits, dim=-1).to(v_ref.dtype)
+
+    # Compute output
+    ref = (attn_weights @ v_ref).transpose(1, 2).reshape(BATCH_SIZE, SEQ_LEN, -1)
+
+    assert torch.allclose(
+        flashinfer_output.cpu().to(torch.float32),
+        ref.cpu().to(torch.float32),
+        atol=1e-2,
+        rtol=1e-2,
+    )
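For context on the math threaded through this change: the `cap * tanh(logits / cap)` transform (Gemma 2 style soft-capping) is the same expression the Triton kernel applies as `LOGIT_CAP * tanh(attn / LOGIT_CAP)` and that FlashInfer receives via `logits_soft_cap`. It bounds attention logits to the open interval (-cap, cap) while leaving small values nearly unchanged. A standalone PyTorch sketch of that behavior (the `soft_cap` helper and the toy values below are illustrative only, not part of the patch):

import torch


def soft_cap(logits: torch.Tensor, cap: float) -> torch.Tensor:
    """Bound logits smoothly to (-cap, cap): ~identity for |x| << cap, saturating at +/-cap."""
    return cap * torch.tanh(logits / cap)


logits = torch.tensor([-200.0, -5.0, 0.0, 5.0, 200.0])
print(soft_cap(logits, cap=50.0))
# ~ tensor([-49.97, -4.98, 0.00, 4.98, 49.97])
# Small logits pass through almost unchanged; outliers are squashed toward +/-cap,
# bounding the dynamic range of the attention scores before the softmax.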