diff --git a/flashinfer/attention.py b/flashinfer/attention.py index 23c5b1f84..2c8413f49 100644 --- a/flashinfer/attention.py +++ b/flashinfer/attention.py @@ -130,6 +130,8 @@ def run( kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], out: Optional[torch.Tensor] = None, lse: Optional[torch.Tensor] = None, + k_scale: Optional[torch.Tensor] = None, + v_scale: Optional[torch.Tensor] = None, profiler_buffer: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: if profiler_buffer is None: @@ -147,6 +149,8 @@ def run( head_dim_qk = q.shape[2] if self._sm_scale is None: self._sm_scale = 1.0 / math.sqrt(head_dim_qk) + if k_scale is not None: + self._sm_scale *= k_scale # profiler_buffer is optional profiler_args = (profiler_buffer,) if self._use_profiler else () @@ -169,5 +173,11 @@ def run( self._sm_scale, *profiler_args, ) + if v_scale is not None: + # TODO(Zihao): fused into kernel + if out.itemsize == 1: + out = (out.to(float) * v_scale).to(out.dtype) + else: + out *= v_scale return out, lse