From e35c19efb6c75e98a2c1f35c97fc1aed31439255 Mon Sep 17 00:00:00 2001
From: Alex Yang
Date: Thu, 19 Mar 2026 19:04:43 -0700
Subject: [PATCH 1/2] proposed api fixes

---
 flashinfer/decode.py        |  4 ++--
 flashinfer/norm/__init__.py | 20 ++++++++++++++------
 flashinfer/xqa.py           |  5 +++--
 3 files changed, 19 insertions(+), 10 deletions(-)

diff --git a/flashinfer/decode.py b/flashinfer/decode.py
index 04bb7a1112..cc88a7b966 100644
--- a/flashinfer/decode.py
+++ b/flashinfer/decode.py
@@ -2672,8 +2672,6 @@ def xqa_batch_decode_with_kv_cache(
         query_new,
         k_cache,
         v_cache,
-        k_cache_sf,
-        v_cache_sf,
         block_tables,
         seq_lens_new,
         out_4d,
@@ -2681,6 +2679,8 @@
         semaphore,
         num_kv_heads,
         page_size,
+        k_sf_cache=k_cache_sf,
+        v_sf_cache=v_cache_sf,
         sinks=sinks_new,
         q_scale=q_scale_value,
         kv_scale=kv_scale_value,
diff --git a/flashinfer/norm/__init__.py b/flashinfer/norm/__init__.py
index af16a5d7be..b5dd61553c 100644
--- a/flashinfer/norm/__init__.py
+++ b/flashinfer/norm/__init__.py
@@ -26,7 +26,7 @@
 import functools
 import os
-from typing import Optional
+from typing import Optional, Union
 
 import torch
 
 
@@ -62,11 +62,19 @@ def get_norm_module():
 
 
 def _normalize_scale_tensor(
-    scale: torch.Tensor, ref_tensor: torch.Tensor
+    scale: Union[float, torch.Tensor], ref_tensor: torch.Tensor
 ) -> torch.Tensor:
-    """Normalize quantization scale tensor to 1D shape (1,) on target device."""
+    """Normalize quantization scale to 1D tensor of shape (1,) on target device."""
     if not isinstance(scale, torch.Tensor):
-        raise TypeError(f"scale must be torch.Tensor, got {type(scale)}")
+        import warnings
+
+        warnings.warn(
+            "Passing scale as a float is deprecated and will be removed in a future "
+            "release. Use a torch.Tensor of shape (1,) instead.",
+            FutureWarning,
+            stacklevel=3,
+        )
+        scale = torch.tensor([scale], dtype=torch.float32, device=ref_tensor.device)
     if scale.device != ref_tensor.device:
         scale = scale.to(ref_tensor.device)
     if scale.dtype != torch.float32:
@@ -159,7 +167,7 @@ def rmsnorm_quant(
     out: torch.Tensor,
     input: torch.Tensor,
     weight: torch.Tensor,
-    scale: torch.Tensor,
+    scale: Union[float, torch.Tensor],
     eps: float = 1e-6,
     enable_pdl: Optional[bool] = None,
 ) -> None:
@@ -268,7 +276,7 @@ def fused_add_rmsnorm_quant(
     input: torch.Tensor,
     residual: torch.Tensor,
     weight: torch.Tensor,
-    scale: torch.Tensor,
+    scale: Union[float, torch.Tensor],
     eps: float = 1e-6,
     enable_pdl: Optional[bool] = None,
 ) -> None:
diff --git a/flashinfer/xqa.py b/flashinfer/xqa.py
index a447107c1d..f11944c5e2 100755
--- a/flashinfer/xqa.py
+++ b/flashinfer/xqa.py
@@ -155,8 +155,6 @@ def xqa(
     q: torch.Tensor,
     k_cache: torch.Tensor,
     v_cache: torch.Tensor,
-    k_sf_cache: Optional[torch.Tensor],
-    v_sf_cache: Optional[torch.Tensor],
     page_table: torch.Tensor,
     seq_lens: torch.Tensor,
     output: torch.Tensor,
@@ -174,6 +172,9 @@
     rcp_out_scale: float = 1.0,
     q_seq_len: int = 1,
     mask: Optional[torch.Tensor] = None,
+    *,
+    k_sf_cache: Optional[torch.Tensor] = None,
+    v_sf_cache: Optional[torch.Tensor] = None,
 ) -> None:
     r"""Apply attention with paged KV cache using XQA kernel.
 
     Parameters

From 1c64dee4aeec4cff85ef3db61df34b010bd2271c Mon Sep 17 00:00:00 2001
From: Alex Yang
Date: Fri, 20 Mar 2026 10:55:04 -0700
Subject: [PATCH 2/2] address gemini

---
 flashinfer/norm/__init__.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/flashinfer/norm/__init__.py b/flashinfer/norm/__init__.py
index b5dd61553c..c04781e558 100644
--- a/flashinfer/norm/__init__.py
+++ b/flashinfer/norm/__init__.py
@@ -26,6 +26,7 @@
 import functools
 import os
+import warnings
 from typing import Optional, Union
 
 import torch
@@ -66,8 +67,6 @@ def _normalize_scale_tensor(
 ) -> torch.Tensor:
     """Normalize quantization scale to 1D tensor of shape (1,) on target device."""
     if not isinstance(scale, torch.Tensor):
-        import warnings
-
         warnings.warn(
             "Passing scale as a float is deprecated and will be removed in a future "
             "release. Use a torch.Tensor of shape (1,) instead.",