Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions flashinfer/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -2672,15 +2672,15 @@ def xqa_batch_decode_with_kv_cache(
query_new,
k_cache,
v_cache,
k_cache_sf,
v_cache_sf,
block_tables,
seq_lens_new,
out_4d,
scratch,
semaphore,
num_kv_heads,
page_size,
k_sf_cache=k_cache_sf,
v_sf_cache=v_cache_sf,
sinks=sinks_new,
q_scale=q_scale_value,
kv_scale=kv_scale_value,
Expand Down
20 changes: 14 additions & 6 deletions flashinfer/norm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

import functools
import os
from typing import Optional
from typing import Optional, Union

import torch

Expand Down Expand Up @@ -62,11 +62,19 @@ def get_norm_module():


def _normalize_scale_tensor(
scale: torch.Tensor, ref_tensor: torch.Tensor
scale: Union[float, torch.Tensor], ref_tensor: torch.Tensor
) -> torch.Tensor:
"""Normalize quantization scale tensor to 1D shape (1,) on target device."""
"""Normalize quantization scale to 1D tensor of shape (1,) on target device."""
if not isinstance(scale, torch.Tensor):
raise TypeError(f"scale must be torch.Tensor, got {type(scale)}")
import warnings

warnings.warn(
"Passing scale as a float is deprecated and will be removed in a future "
"release. Use a torch.Tensor of shape (1,) instead.",
FutureWarning,
stacklevel=3,
)
scale = torch.tensor([scale], dtype=torch.float32, device=ref_tensor.device)
if scale.device != ref_tensor.device:
scale = scale.to(ref_tensor.device)
if scale.dtype != torch.float32:
Expand Down Expand Up @@ -159,7 +167,7 @@ def rmsnorm_quant(
out: torch.Tensor,
input: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
scale: Union[float, torch.Tensor],
eps: float = 1e-6,
enable_pdl: Optional[bool] = None,
) -> None:
Expand Down Expand Up @@ -268,7 +276,7 @@ def fused_add_rmsnorm_quant(
input: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
scale: Union[float, torch.Tensor],
eps: float = 1e-6,
enable_pdl: Optional[bool] = None,
) -> None:
Expand Down
5 changes: 3 additions & 2 deletions flashinfer/xqa.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,6 @@ def xqa(
q: torch.Tensor,
k_cache: torch.Tensor,
v_cache: torch.Tensor,
k_sf_cache: Optional[torch.Tensor],
v_sf_cache: Optional[torch.Tensor],
page_table: torch.Tensor,
seq_lens: torch.Tensor,
output: torch.Tensor,
Expand All @@ -174,6 +172,9 @@ def xqa(
rcp_out_scale: float = 1.0,
q_seq_len: int = 1,
mask: Optional[torch.Tensor] = None,
*,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do these parameters need to be keyword-only?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's an optional feature (guessing).

To that end, the rationale is documented (at the end of this page):

https://github.com/flashinfer-ai/flashinfer/blob/main/CONTRIBUTING.md

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd like to put a `*` in as soon as the basic features of an API are done. The extra parameters that pile on later, if passed positionally, just make things worse and worse for API stability.

IMO, positional args shouldn't exceed 10, or the API becomes harder to maintain.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great questions! Keep them coming.

k_sf_cache: Optional[torch.Tensor] = None,
v_sf_cache: Optional[torch.Tensor] = None,
) -> None:
r"""Apply attention with paged KV cache using XQA kernel.
Parameters
Expand Down
Loading