diff --git a/flashinfer/decode.py b/flashinfer/decode.py index 04bb7a1112..cc88a7b966 100644 --- a/flashinfer/decode.py +++ b/flashinfer/decode.py @@ -2672,8 +2672,6 @@ def xqa_batch_decode_with_kv_cache( query_new, k_cache, v_cache, - k_cache_sf, - v_cache_sf, block_tables, seq_lens_new, out_4d, @@ -2681,6 +2679,8 @@ semaphore, num_kv_heads, page_size, + k_sf_cache=k_cache_sf, + v_sf_cache=v_cache_sf, sinks=sinks_new, q_scale=q_scale_value, kv_scale=kv_scale_value, diff --git a/flashinfer/norm/__init__.py b/flashinfer/norm/__init__.py index af16a5d7be..c04781e558 100644 --- a/flashinfer/norm/__init__.py +++ b/flashinfer/norm/__init__.py @@ -26,7 +26,8 @@ import functools import os -from typing import Optional +import warnings +from typing import Optional, Union import torch @@ -62,11 +63,17 @@ def get_norm_module(): def _normalize_scale_tensor( - scale: torch.Tensor, ref_tensor: torch.Tensor + scale: Union[float, torch.Tensor], ref_tensor: torch.Tensor ) -> torch.Tensor: - """Normalize quantization scale tensor to 1D shape (1,) on target device.""" + """Normalize quantization scale to 1D tensor of shape (1,) on target device.""" if not isinstance(scale, torch.Tensor): - raise TypeError(f"scale must be torch.Tensor, got {type(scale)}") + warnings.warn( + "Passing scale as a float is deprecated and will be removed in a future " + "release. Use a torch.Tensor of shape (1,) instead.", + FutureWarning, + stacklevel=3, + ) + scale = torch.tensor([scale], dtype=torch.float32, device=ref_tensor.device) if scale.device != ref_tensor.device: scale = scale.to(ref_tensor.device) if scale.dtype != torch.float32: @@ -159,7 +166,7 @@ def rmsnorm_quant( out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, - scale: torch.Tensor, + scale: Union[float, torch.Tensor], eps: float = 1e-6, enable_pdl: Optional[bool] = None, ) -> None: @@ -268,7 +275,7 @@ def fused_add_rmsnorm_quant( input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, - scale: torch.Tensor, + scale: Union[float, torch.Tensor], eps: float = 1e-6, enable_pdl: Optional[bool] = None, ) -> None: diff --git a/flashinfer/xqa.py b/flashinfer/xqa.py index a447107c1d..f11944c5e2 100755 --- a/flashinfer/xqa.py +++ b/flashinfer/xqa.py @@ -155,8 +155,6 @@ def xqa( q: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor, - k_sf_cache: Optional[torch.Tensor], - v_sf_cache: Optional[torch.Tensor], page_table: torch.Tensor, seq_lens: torch.Tensor, output: torch.Tensor, @@ -174,6 +172,9 @@ def xqa( rcp_out_scale: float = 1.0, q_seq_len: int = 1, mask: Optional[torch.Tensor] = None, + *, + k_sf_cache: Optional[torch.Tensor] = None, + v_sf_cache: Optional[torch.Tensor] = None, ) -> None: r"""Apply attention with paged KV cache using XQA kernel. Parameters