diff --git a/flashinfer/aot.py b/flashinfer/aot.py
index 2e2885a7c..3bf2f47b5 100644
--- a/flashinfer/aot.py
+++ b/flashinfer/aot.py
@@ -3,7 +3,7 @@
 import shutil
 from itertools import product
 from pathlib import Path
-from typing import List, Tuple
+from typing import List, Tuple, Iterator
 
 import torch
 import torch.version
@@ -11,7 +11,6 @@
 from .activation import act_func_def_str, gen_act_and_mul_module
 from .cascade import gen_cascade_module
-from .comm.nvshmem import gen_nvshmem_module
 from .fp4_quantization import gen_fp4_quantization_sm100_module
 from .fused_moe import gen_cutlass_fused_moe_sm100_module
 from .gemm import gen_gemm_module, gen_gemm_sm90_module, gen_gemm_sm100_module
@@ -43,63 +42,61 @@ def gen_fa2(
     head_dim_vo: int,
     use_sliding_window: bool,
     use_logits_soft_cap: bool,
-    use_attention_sink: bool,
-) -> List[JitSpec]:
+) -> Iterator[JitSpec]:
     if dtype_qo.itemsize == dtype_kv.itemsize and dtype_qo != dtype_kv:
-        return []
+        return
     if dtype_qo.itemsize == 1:
-        return []  # fp8 tensor cores not supported in fa2
+        return  # fp8 tensor cores not supported in fa2
+
+    yield gen_single_prefill_module(
+        backend="fa2",
+        dtype_q=dtype_qo,
+        dtype_kv=dtype_kv,
+        dtype_o=dtype_qo,
+        head_dim_qk=head_dim_qk,
+        head_dim_vo=head_dim_vo,
+        pos_encoding_mode=0,
+        use_sliding_window=use_sliding_window,
+        use_logits_soft_cap=use_logits_soft_cap,
+        use_fp16_qk_reduction=False,
+    )
 
-    # TODO: support for AoT sink attention.
+    yield gen_batch_prefill_module(
+        backend="fa2",
+        dtype_q=dtype_qo,
+        dtype_kv=dtype_kv,
+        dtype_o=dtype_qo,
+        dtype_idx=torch.int32,
+        head_dim_qk=head_dim_qk,
+        head_dim_vo=head_dim_vo,
+        pos_encoding_mode=0,
+        use_sliding_window=use_sliding_window,
+        use_logits_soft_cap=use_logits_soft_cap,
+        use_fp16_qk_reduction=False,
+    )
 
-    return [
-        gen_single_prefill_module(
-            backend="fa2",
-            dtype_q=dtype_qo,
-            dtype_kv=dtype_kv,
-            dtype_o=dtype_qo,
-            head_dim_qk=head_dim_qk,
-            head_dim_vo=head_dim_vo,
-            pos_encoding_mode=0,
-            use_sliding_window=use_sliding_window,
-            use_logits_soft_cap=use_logits_soft_cap,
-            use_fp16_qk_reduction=False,
-        ),
-        gen_batch_prefill_module(
-            backend="fa2",
-            dtype_q=dtype_qo,
-            dtype_kv=dtype_kv,
-            dtype_o=dtype_qo,
-            dtype_idx=torch.int32,
-            head_dim_qk=head_dim_qk,
-            head_dim_vo=head_dim_vo,
-            pos_encoding_mode=0,
-            use_sliding_window=use_sliding_window,
-            use_logits_soft_cap=use_logits_soft_cap,
-            use_fp16_qk_reduction=False,
-        ),
-        gen_single_decode_module(
-            dtype_q=dtype_qo,
-            dtype_kv=dtype_kv,
-            dtype_o=dtype_qo,
-            head_dim_qk=head_dim_qk,
-            head_dim_vo=head_dim_vo,
-            pos_encoding_mode=0,
-            use_sliding_window=use_sliding_window,
-            use_logits_soft_cap=use_logits_soft_cap,
-        ),
-        gen_batch_decode_module(
-            dtype_q=dtype_qo,
-            dtype_kv=dtype_kv,
-            dtype_o=dtype_qo,
-            dtype_idx=torch.int32,
-            head_dim_qk=head_dim_qk,
-            head_dim_vo=head_dim_vo,
-            pos_encoding_mode=0,
-            use_sliding_window=use_sliding_window,
-            use_logits_soft_cap=use_logits_soft_cap,
-        ),
-    ]
+    yield gen_single_decode_module(
+        dtype_q=dtype_qo,
+        dtype_kv=dtype_kv,
+        dtype_o=dtype_qo,
+        head_dim_qk=head_dim_qk,
+        head_dim_vo=head_dim_vo,
+        pos_encoding_mode=0,
+        use_sliding_window=use_sliding_window,
+        use_logits_soft_cap=use_logits_soft_cap,
+    )
+
+    yield gen_batch_decode_module(
+        dtype_q=dtype_qo,
+        dtype_kv=dtype_kv,
+        dtype_o=dtype_qo,
+        dtype_idx=torch.int32,
+        head_dim_qk=head_dim_qk,
+        head_dim_vo=head_dim_vo,
+        pos_encoding_mode=0,
+        use_sliding_window=use_sliding_window,
+        use_logits_soft_cap=use_logits_soft_cap,
+    )
 
 
 def gen_fa3(
@@ -110,47 +107,31 @@
     head_dim_vo: int,
     use_sliding_window: bool,
     use_logits_soft_cap: bool,
-    use_attention_sink: bool,
-) -> List[JitSpec]:
+) -> Iterator[JitSpec]:
     if dtype_q != dtype_kv:
-        return []  # fa3 template do not support mixed precision
+        return  # fa3 template do not support mixed precision
     if dtype_q.itemsize == 2:
         if dtype_q != dtype_o:
-            return []  # for fp16, dtype_o must be the same as dtype_q/dtype_kv
+            return  # for fp16, dtype_o must be the same as dtype_q/dtype_kv
     if dtype_kv.itemsize == 1:
         if head_dim_qk == 192 or head_dim_qk == 64:
-            return []  # (192, 128) & (64, 64) not supported for fp8 yet.
-
-    # TODO: support for AoT sink attention.
+            return  # (192, 128) & (64, 64) not supported for fp8 yet.
 
-    return [
-        gen_single_prefill_module(
-            backend="fa3",
-            dtype_q=dtype_q,
-            dtype_kv=dtype_kv,
-            dtype_o=dtype_o,
-            head_dim_qk=head_dim_qk,
-            head_dim_vo=head_dim_vo,
-            pos_encoding_mode=0,
-            use_sliding_window=use_sliding_window,
-            use_logits_soft_cap=use_logits_soft_cap,
-            use_fp16_qk_reduction=False,
-        ),
-        gen_batch_prefill_module(
+    for dtype_idx in [torch.int32, torch.int64]:
+        yield gen_batch_prefill_module(
             backend="fa3",
             dtype_q=dtype_q,
             dtype_kv=dtype_kv,
             dtype_o=dtype_o,
-            dtype_idx=torch.int32,
+            dtype_idx=dtype_idx,
             head_dim_qk=head_dim_qk,
             head_dim_vo=head_dim_vo,
             pos_encoding_mode=0,
             use_sliding_window=use_sliding_window,
             use_logits_soft_cap=use_logits_soft_cap,
             use_fp16_qk_reduction=False,
-        ),
-    ]
+        )
 
 
 def gen_attention(
@@ -164,10 +145,9 @@
     has_sm100: bool,
     add_gemma: bool,
     add_oai_oss: bool,
-) -> List[JitSpec]:
+) -> Iterator[JitSpec]:
     head_dim_ckv = 512
     head_dim_kpe = 64
-    jit_specs: List[JitSpec] = []
 
     # FA2 MHA / MQA / GQA
     for (
@@ -183,14 +163,13 @@
         use_sliding_window_,
         use_logits_soft_cap_,
     ):
-        jit_specs += gen_fa2(
+        yield from gen_fa2(
            dtype_qo=dtype_qo,
            dtype_kv=dtype_kv,
            head_dim_qk=head_dim_qk,
            head_dim_vo=head_dim_vo,
            use_sliding_window=use_sliding_window,
            use_logits_soft_cap=use_logits_soft_cap,
-            use_attention_sink=False,
         )
 
     # FA3 MHA / MQA / GQA
@@ -208,7 +187,7 @@
         use_sliding_window_,
         use_logits_soft_cap_,
     ):
-        jit_specs += gen_fa3(
+        yield from gen_fa3(
             dtype_q=dtype_qkv,
             dtype_kv=dtype_qkv,
             dtype_o=dtype_o,
@@ -216,7 +195,6 @@
             head_dim_vo=head_dim_vo,
             use_sliding_window=use_sliding_window,
             use_logits_soft_cap=use_logits_soft_cap,
-            use_attention_sink=False,
         )
 
     # Gemma
@@ -230,14 +208,13 @@
             f16_dtype_ + f8_dtype_,
             [(True, True)],
         ):
-            jit_specs += gen_fa2(
+            yield from gen_fa2(
                 dtype_qo=dtype_qo,
                 dtype_kv=dtype_kv,
                 head_dim_qk=256,
                 head_dim_vo=256,
                 use_sliding_window=use_sliding_window,
                 use_logits_soft_cap=use_logits_soft_cap,
-                use_attention_sink=False,
             )
         if has_sm90:
             for (
@@ -249,7 +226,7 @@
                 f16_dtype_,
                 [(True, True)],
             ):
-                jit_specs += gen_fa3(
+                yield from gen_fa3(
                     dtype_q=dtype_qkv,
                     dtype_kv=dtype_qkv,
                     dtype_o=dtype_o,
@@ -257,73 +234,61 @@
                     head_dim_vo=256,
                     use_sliding_window=use_sliding_window,
                     use_logits_soft_cap=use_logits_soft_cap,
-                    use_attention_sink=False,
                 )
 
     # OAI OSS
     if add_oai_oss:
-        for (
-            dtype_qo,
-            dtype_kv,
-            use_sliding_window,
-        ) in product(
-            f16_dtype_,
-            f16_dtype_ + f8_dtype_,
-            [True],
-        ):
-            jit_specs += gen_fa2(
-                dtype_qo=dtype_qo,
-                dtype_kv=dtype_kv,
-                head_dim_qk=64,
-                head_dim_vo=64,
-                use_sliding_window=use_sliding_window,
-                use_logits_soft_cap=False,
-                use_attention_sink=True,
-            )
+        from .jit.attention import gen_batch_prefill_attention_sink_module
+
+        for dtype in f16_dtype_:
+            for backend in ["fa2", "fa3"]:
["fa2", "fa3"]: + for use_swa in [True, False]: + yield gen_batch_prefill_attention_sink_module( + backend=backend, + dtype_q=dtype, + dtype_kv=dtype, + dtype_o=dtype, + dtype_idx=torch.int32, + head_dim_qk=64, + head_dim_vo=64, + pos_encoding_mode=0, + use_sliding_window=use_swa, + ) # fmha_cutlass_sm100a # NOTE: currently there's only one uri. if has_sm100: - jit_specs.append( - gen_fmha_cutlass_sm100a_module( - dtype_q=torch.bfloat16, - dtype_kv=torch.bfloat16, - dtype_o=torch.bfloat16, - dtype_idx=torch.int32, - head_dim_qk=128, - head_dim_vo=128, - pos_encoding_mode=0, - use_sliding_window=False, - use_logits_soft_cap=False, - ) + yield gen_fmha_cutlass_sm100a_module( + dtype_q=torch.bfloat16, + dtype_kv=torch.bfloat16, + dtype_o=torch.bfloat16, + dtype_idx=torch.int32, + head_dim_qk=128, + head_dim_vo=128, + pos_encoding_mode=0, + use_sliding_window=False, + use_logits_soft_cap=False, ) # MLA # NOTE: fp8 kv not supported in MLA - mla_backend_ = ["fa2"] - if has_sm90: - mla_backend_.append("fa3") + mla_backend_ = ["fa2"] + (["fa3"] if has_sm90 else []) for dtype_qo in f16_dtype_: - dtype_kv = dtype_qo for backend in mla_backend_: - jit_specs.append( - gen_batch_mla_module( - backend=backend, - dtype_q=dtype_qo, - dtype_kv=dtype_kv, - dtype_o=dtype_qo, - dtype_idx=torch.int32, - head_dim_ckv=head_dim_ckv, - head_dim_kpe=head_dim_kpe, - use_profiler=False, - ) + yield gen_batch_mla_module( + backend=backend, + dtype_q=dtype_qo, + dtype_kv=dtype_qo, + dtype_o=dtype_qo, + dtype_idx=torch.int32, + head_dim_ckv=head_dim_ckv, + head_dim_kpe=head_dim_kpe, + use_profiler=False, ) # MLA SM100 if has_sm100: - jit_specs.append(gen_mla_module()) - - return jit_specs + yield gen_mla_module() def gen_all_modules( @@ -344,17 +309,19 @@ def gen_all_modules( ) -> List[JitSpec]: jit_specs: List[JitSpec] = [] - jit_specs += gen_attention( - f16_dtype_, - f8_dtype_, - fa2_head_dim_, - fa3_head_dim_, - use_sliding_window_, - use_logits_soft_cap_, - has_sm90, - has_sm100, - add_gemma, - add_oai_oss, + jit_specs += list( + gen_attention( + f16_dtype_, + f8_dtype_, + fa2_head_dim_, + fa3_head_dim_, + use_sliding_window_, + use_logits_soft_cap_, + has_sm90, + has_sm100, + add_gemma, + add_oai_oss, + ) ) if add_act: @@ -371,24 +338,33 @@ def gen_all_modules( jit_specs.append(gen_gemm_sm100_module()) if add_comm: - from .comm import gen_trtllm_comm_module, gen_vllm_comm_module + try: + from .comm.nvshmem import gen_nvshmem_module - if has_sm100: - jit_specs.append(gen_trtllm_comm_module()) - jit_specs.append(gen_vllm_comm_module()) + jit_specs.append(gen_nvshmem_module()) + except ImportError: + pass + + try: + from .comm import gen_trtllm_comm_module, gen_vllm_comm_module + + if has_sm100: + jit_specs.append(gen_trtllm_comm_module()) + jit_specs.append(gen_vllm_comm_module()) + except ImportError: + pass if add_misc: jit_specs += [ gen_cascade_module(), gen_norm_module(), - gen_nvshmem_module(), gen_page_module(), gen_quantization_module(), gen_rope_module(), gen_sampling_module(), ] - if has_sm90: - jit_specs.append(get_trtllm_utils_spec()) + if has_sm90: + jit_specs.append(get_trtllm_utils_spec()) # dedup names = set() @@ -538,10 +514,10 @@ def main(): # True, ] add_comm = False - add_gemma = True + add_gemma = False add_oai_oss = True add_moe = False - add_act = True + add_act = False add_misc = True # Override diff --git a/flashinfer/jit/attention/__init__.py b/flashinfer/jit/attention/__init__.py index 70441ef05..a26b7cee4 100644 --- a/flashinfer/jit/attention/__init__.py +++ 
diff --git a/flashinfer/jit/attention/__init__.py b/flashinfer/jit/attention/__init__.py
index 70441ef05..a26b7cee4 100644
--- a/flashinfer/jit/attention/__init__.py
+++ b/flashinfer/jit/attention/__init__.py
@@ -46,6 +46,10 @@
 from .pytorch import get_single_decode_uri as get_single_decode_uri
 from .pytorch import get_single_prefill_uri as get_single_prefill_uri
 from .pytorch import trtllm_gen_fmha_module as trtllm_gen_fmha_module
+from .pytorch import (
+    gen_batch_prefill_attention_sink_module as gen_batch_prefill_attention_sink_module,
+    get_batch_prefill_attention_sink_uri as get_batch_prefill_attention_sink_uri,
+)
 from .tvm import gen_batch_mla_tvm_binding as gen_batch_mla_tvm_binding
 from .tvm import (
     gen_customize_batch_decode_tvm_binding as gen_customize_batch_decode_tvm_binding,
diff --git a/flashinfer/jit/attention/pytorch.py b/flashinfer/jit/attention/pytorch.py
index eabeb9e70..31c6f40b8 100644
--- a/flashinfer/jit/attention/pytorch.py
+++ b/flashinfer/jit/attention/pytorch.py
@@ -388,6 +388,28 @@ def get_batch_prefill_uri(
     )
 
 
+def get_batch_prefill_attention_sink_uri(
+    backend: str,
+    dtype_q: torch.dtype,
+    dtype_kv: torch.dtype,
+    dtype_o: torch.dtype,
+    dtype_idx: torch.dtype,
+    head_dim_qk: int,
+    head_dim_vo: int,
+    pos_encoding_mode: int,
+    use_sliding_window: bool,
+) -> str:
+    return (
+        f"batch_prefill_with_attention_sink_kv_cache_dtype_q_{filename_safe_dtype_map[dtype_q]}_"
+        f"dtype_kv_{filename_safe_dtype_map[dtype_kv]}_"
+        f"dtype_o_{filename_safe_dtype_map[dtype_o]}_"
+        f"dtype_idx_{filename_safe_dtype_map[dtype_idx]}_"
+        f"head_dim_qk_{head_dim_qk}_"
+        f"head_dim_vo_{head_dim_vo}_"
+        f"use_swa_{use_sliding_window}_" + ("_sm90" if backend == "fa3" else "")
+    )
+
+
 def get_batch_attention_uri(
     dtype_q: torch.dtype,
     dtype_kv: torch.dtype,
@@ -856,6 +878,54 @@ def gen_batch_prefill_module(
     )
 
 
+def gen_batch_prefill_attention_sink_module(
+    backend: str,
+    dtype_q: torch.dtype,
+    dtype_kv: torch.dtype,
+    dtype_o: torch.dtype,
+    dtype_idx: torch.dtype,
+    head_dim_qk: int,
+    head_dim_vo: int,
+    pos_encoding_mode: int,
+    use_sliding_window: bool,
+) -> JitSpec:
+    from flashinfer.jit.attention.variants import attention_sink_decl
+
+    uri = get_batch_prefill_attention_sink_uri(
+        backend,
+        dtype_q,
+        dtype_kv,
+        dtype_o,
+        dtype_idx,
+        head_dim_qk,
+        head_dim_vo,
+        pos_encoding_mode,
+        use_sliding_window,
+    )
+
+    return gen_customize_batch_prefill_module(
+        backend,
+        uri,
+        dtype_q,
+        dtype_kv,
+        dtype_o,
+        dtype_idx,
+        head_dim_qk,
+        head_dim_vo,
+        ["sink"],
+        ["float"],
+        ["sm_scale"],
+        ["double"],
+        "AttentionSink",
+        attention_sink_decl[backend],
+        pos_encoding_mode=pos_encoding_mode,
+        use_sliding_window=use_sliding_window,
+        use_logits_soft_cap=False,
+        use_fp16_qk_reduction=False,
+        fp8_enabled=False,
+    )
+
+
 def gen_batch_attention_module(
     dtype_q: torch.dtype,
     dtype_kv: torch.dtype,
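gen_batch_prefill_attention_sink_module is a thin wrapper around gen_customize_batch_prefill_module: it derives the URI, pins the extra per-head "sink" float tensor (plus sm_scale as an additional double scalar), and injects the AttentionSink variant source from variants.py (added below). A hedged usage sketch, mirroring the updated test fixture at the end of this patch:

    # Sketch, not part of the patch: pre-building sink-attention prefill modules.
    import torch
    import flashinfer
    from flashinfer.jit.attention import gen_batch_prefill_attention_sink_module

    specs = [
        gen_batch_prefill_attention_sink_module(
            backend=backend,
            dtype_q=torch.bfloat16,
            dtype_kv=torch.bfloat16,
            dtype_o=torch.bfloat16,
            dtype_idx=torch.int32,
            head_dim_qk=128,
            head_dim_vo=128,
            pos_encoding_mode=0,
            use_sliding_window=use_swa,
        )
        for backend in ["fa2", "fa3"]
        for use_swa in [True, False]
    ]
    flashinfer.jit.build_jit_specs(specs)  # same call the test fixture uses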
+""" + +attention_sink_fa2_decl = r""" +struct AttentionSink : AttentionVariantBase { + static constexpr bool use_softmax = true; + + uint32_t window_left, qo_len, kv_len; + float sm_scale_log2; + + // Create closure + template + __device__ __host__ AttentionSink(const Params& params, uint32_t batch_idx, + uint8_t* smem_ptr) { + qo_len = params.get_qo_len(batch_idx); + kv_len = params.get_kv_len(batch_idx); + window_left = (params.window_left >= 0) ? params.window_left : kv_len; + sm_scale_log2 = params.sm_scale * math::log2e; + } + + REGISTER_LOGITS_MASK(params, batch_idx, qo_idx, kv_idx, qo_head_idx, kv_head_idx, { + return (kv_idx + qo_len + window_left >= kv_len + qo_idx); + }) + + REGISTER_M_D_UPDATE(params, kv_tile_idx, qo_head_idx, m, d, scale, { + float log_sink = (kv_tile_idx == 0 && qo_head_idx < params.num_qo_heads) ? params.sink[qo_head_idx] * math::log2e : -math::inf; + float m_new = (log_sink > m) ? log_sink : m; + scale = math::ptx_exp2(m - m_new); + float d_new = math::ptx_exp2(log_sink - m_new) + d * scale; + // Update m and d + m = m_new; + d = d_new; + }) + + REGISTER_OUTPUT_TRANSFORM(params, output, batch_idx, qo_idx, qo_head_idx, m, d, scale, { + float d_rcp = (m != -math::inf) ? math::ptx_rcp(d) : 0.f; + return output * scale * d_rcp; + }); +}; +""" + +attention_sink_fa3_decl = r""" + +template +struct OnlineSoftmaxWithSink { + constexpr static float fill_value = -math::inf; + using TensorT = decltype(make_tensor(Shape>{})); + TensorT row_max, row_sum, scores_scale; + float sm_scale_log2; + float log_sink; + + CUTLASS_DEVICE OnlineSoftmaxWithSink(float sm_scale_log2, float log_sink) : sm_scale_log2(sm_scale_log2), log_sink(log_sink) { + clear(scores_scale); + }; + + __forceinline__ __device__ TensorT get_lse() const { return row_sum; } + + template + __forceinline__ __device__ void update(Tensor0& acc_s) { + // Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N)) + Tensor scores = make_tensor(acc_s.data(), convert_layout_acc_rowcol(acc_s.layout())); + + static_assert(decltype(size<0>(scores))::value == NUM_ROWS_PER_THREAD); + if constexpr (init) { + reduce_max(scores, row_max); + scale_apply_exp2(scores, row_max, sm_scale_log2); + reduce_sum(scores, row_sum); + } else { + // update row_max + Tensor scores_max_prev = make_fragment_like(row_max); + cute::copy(row_max, scores_max_prev); + reduce_max(scores, row_max); + // update scores_scale and scale row_sum +#pragma unroll + for (int mi = 0; mi < size(row_max); ++mi) { + float scores_max_cur = row_max(mi); + scores_scale(mi) = exp2f((scores_max_prev(mi) - scores_max_cur) * sm_scale_log2); + row_sum(mi) *= scores_scale(mi); + } + // perform exp2 on scores + scale_apply_exp2(scores, row_max, sm_scale_log2); + // update row_sum + reduce_sum(scores, row_sum); + } + }; + + template + __forceinline__ __device__ void finalize(Tensor0& acc_s, float pv_scale = 1.f) { + // Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N)) + // Note (Yilong): use pv_scale to dequantize the output + Tensor scores = make_tensor(acc_s.data(), convert_layout_acc_rowcol(acc_s.layout())); + static_assert(decltype(size<0>(scores))::value == NUM_ROWS_PER_THREAD); + SumOp sum_op; + quad_allreduce_(row_sum, row_sum, sum_op); + +#pragma unroll + for (int mi = 0; mi < size(row_max); ++mi) { + float m = row_max(mi) * sm_scale_log2; + float d = row_sum(mi); + + float m_new = (log_sink > m) ? 
+      float scale = math::ptx_exp2(m - m_new);
+      float d_new = math::ptx_exp2(log_sink - m_new) + d * scale;
+
+      // Update m and d
+      row_max(mi) = m_new;
+      row_sum(mi) = d_new;
+
+      scores_scale(mi) = pv_scale * scale / d_new;
+      row_sum(mi) = row_max(mi) + math::ptx_log2(d_new);
+    }
+  };
+
+  template <typename Tensor1>
+  __forceinline__ __device__ void rescale_o(Tensor1& acc_o) {
+    // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
+    Tensor acc_o_rowcol = make_tensor(acc_o.data(), convert_layout_acc_rowcol(acc_o.layout()));
+    static_assert(decltype(size<0>(acc_o_rowcol))::value == NUM_ROWS_PER_THREAD);
+#pragma unroll
+    for (int mi = 0; mi < size(row_max); ++mi) {
+#pragma unroll
+      for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) {
+        acc_o_rowcol(mi, ni) *= scores_scale(mi);
+      }
+    }
+  };
+};
+
+struct AttentionSink : AttentionVariantBase {
+  float sm_scale_log2;
+  float log_sink;
+  int qo_len, kv_len;
+
+  // Init
+  template <typename MainloopParams, typename BlockCoord>
+  __device__ __host__ AttentionSink(const MainloopParams& params, const BlockCoord& block_coord) {
+    sm_scale_log2 = params.additional_params.sm_scale * math::log2e;
+    auto [_, qo_head_idx, __, ___, ____, qo_len_, kv_len_, batch_idx] =
+        block_coord;
+    log_sink = params.additional_params.sink[qo_head_idx] * math::log2e;
+
+    qo_len = qo_len_;
+    kv_len = kv_len_;
+  }
+
+  template <int NUM_ROWS_PER_THREAD>
+  __device__ auto GetAttentionUpdater() {
+    return OnlineSoftmaxWithSink<NUM_ROWS_PER_THREAD>(sm_scale_log2, log_sink);
+  }
+};
+"""
+
+attention_sink_decl = {
+    "fa2": attention_sink_fa2_decl,
+    "fa3": attention_sink_fa3_decl,
+}
diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py
index e3f53e904..6759578a4 100644
--- a/flashinfer/prefill.py
+++ b/flashinfer/prefill.py
@@ -1768,6 +1768,7 @@ def plan(
             window_left >= 0,  # use_sliding_window
             logits_soft_cap > 0,  # use_logits_soft_cap
             use_fp16_qk_reduction,
+            False,  # use_attention_sink
         )
 
         self._cached_module = get_batch_prefill_module(
@@ -2599,6 +2600,7 @@ def plan(
             window_left >= 0,  # use_sliding_window
             logits_soft_cap > 0,  # use_logits_soft_cap
             use_fp16_qk_reduction,
+            False,  # use_attention_sink
         )
         if self._backend == "cutlass":
             self._cached_module = get_fmha_module(*get_module_args)
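For reference, both AttentionSink variants implement the same merge: the per-head sink acts as one extra logit that joins the running softmax maximum and denominator but contributes no value vector. The FA2 path applies it per KV tile via REGISTER_M_D_UPDATE, the FA3 path once in OnlineSoftmaxWithSink::finalize; both work in the log2 domain. A small Python sketch of that update (names are illustrative, not library API):

    # Sketch, not part of the patch: the sink merge in the log2 domain.
    import math

    LOG2E = math.log2(math.e)

    def merge_sink(m: float, d: float, sink: float) -> tuple[float, float]:
        # m: running max of scaled logits (log2 domain); d: running softmax denominator.
        log_sink = sink * LOG2E
        m_new = max(log_sink, m)
        scale = 2.0 ** (m - m_new)                  # rescale the accumulated denominator
        d_new = 2.0 ** (log_sink - m_new) + d * scale
        return m_new, d_new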
diff --git a/tests/test_attention_sink.py b/tests/test_attention_sink.py
index b04fbf993..254e4af23 100644
--- a/tests/test_attention_sink.py
+++ b/tests/test_attention_sink.py
@@ -22,162 +22,10 @@
 import flashinfer
 from flashinfer.jit.utils import filename_safe_dtype_map
+from flashinfer.jit.attention import gen_batch_prefill_attention_sink_module
+from flashinfer.jit.attention.variants import attention_sink_decl
 from flashinfer.utils import is_sm90a_supported
 
-attention_sink_fa2_decl = r"""
-struct AttentionSink : AttentionVariantBase {
-  static constexpr bool use_softmax = true;
-
-  uint32_t window_left, qo_len, kv_len;
-  float sm_scale_log2;
-
-  // Create closure
-  template <typename Params>
-  __device__ __host__ AttentionSink(const Params& params, uint32_t batch_idx,
-                                    uint8_t* smem_ptr) {
-    qo_len = params.get_qo_len(batch_idx);
-    kv_len = params.get_kv_len(batch_idx);
-    window_left = (params.window_left >= 0) ? params.window_left : kv_len;
-    sm_scale_log2 = params.sm_scale * math::log2e;
-  }
-
-  REGISTER_LOGITS_MASK(params, batch_idx, qo_idx, kv_idx, qo_head_idx, kv_head_idx, {
-    return (kv_idx + qo_len + window_left >= kv_len + qo_idx);
-  })
-
-  REGISTER_M_D_UPDATE(params, kv_tile_idx, qo_head_idx, m, d, scale, {
-    float log_sink = (kv_tile_idx == 0 && qo_head_idx < params.num_qo_heads) ? params.sink[qo_head_idx] * math::log2e : -math::inf;
-    float m_new = (log_sink > m) ? log_sink : m;
-    scale = math::ptx_exp2(m - m_new);
-    float d_new = math::ptx_exp2(log_sink - m_new) + d * scale;
-    // Update m and d
-    m = m_new;
-    d = d_new;
-  })
-
-  REGISTER_OUTPUT_TRANSFORM(params, output, batch_idx, qo_idx, qo_head_idx, m, d, scale, {
-    float d_rcp = (m != -math::inf) ? math::ptx_rcp(d) : 0.f;
-    return output * scale * d_rcp;
-  });
-};
-"""
-
-attention_sink_fa3_decl = r"""
-
-template <int NUM_ROWS_PER_THREAD>
-struct OnlineSoftmaxWithSink {
-  constexpr static float fill_value = -math::inf;
-  using TensorT = decltype(make_tensor<float>(Shape<Int<NUM_ROWS_PER_THREAD>>{}));
-  TensorT row_max, row_sum, scores_scale;
-  float sm_scale_log2;
-  float log_sink;
-
-  CUTLASS_DEVICE OnlineSoftmaxWithSink(float sm_scale_log2, float log_sink) : sm_scale_log2(sm_scale_log2), log_sink(log_sink) {
-    clear(scores_scale);
-  };
-
-  __forceinline__ __device__ TensorT get_lse() const { return row_sum; }
-
-  template <bool init, typename Tensor0>
-  __forceinline__ __device__ void update(Tensor0& acc_s) {
-    // Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
-    Tensor scores = make_tensor(acc_s.data(), convert_layout_acc_rowcol(acc_s.layout()));
-
-    static_assert(decltype(size<0>(scores))::value == NUM_ROWS_PER_THREAD);
-    if constexpr (init) {
-      reduce_max(scores, row_max);
-      scale_apply_exp2(scores, row_max, sm_scale_log2);
-      reduce_sum(scores, row_sum);
-    } else {
-      // update row_max
-      Tensor scores_max_prev = make_fragment_like(row_max);
-      cute::copy(row_max, scores_max_prev);
-      reduce_max(scores, row_max);
-      // update scores_scale and scale row_sum
-#pragma unroll
-      for (int mi = 0; mi < size(row_max); ++mi) {
-        float scores_max_cur = row_max(mi);
-        scores_scale(mi) = exp2f((scores_max_prev(mi) - scores_max_cur) * sm_scale_log2);
-        row_sum(mi) *= scores_scale(mi);
-      }
-      // perform exp2 on scores
-      scale_apply_exp2(scores, row_max, sm_scale_log2);
-      // update row_sum
-      reduce_sum(scores, row_sum);
-    }
-  };
-
-  template <typename Tensor0>
-  __forceinline__ __device__ void finalize(Tensor0& acc_s, float pv_scale = 1.f) {
-    // Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
-    // Note (Yilong): use pv_scale to dequantize the output
-    Tensor scores = make_tensor(acc_s.data(), convert_layout_acc_rowcol(acc_s.layout()));
-    static_assert(decltype(size<0>(scores))::value == NUM_ROWS_PER_THREAD);
-    SumOp<float> sum_op;
-    quad_allreduce_(row_sum, row_sum, sum_op);
-
-#pragma unroll
-    for (int mi = 0; mi < size(row_max); ++mi) {
-      float m = row_max(mi) * sm_scale_log2;
-      float d = row_sum(mi);
-
-      float m_new = (log_sink > m) ? log_sink : m;
-      float scale = math::ptx_exp2(m - m_new);
-      float d_new = math::ptx_exp2(log_sink - m_new) + d * scale;
-
-      // Update m and d
-      row_max(mi) = m_new;
-      row_sum(mi) = d_new;
-
-      scores_scale(mi) = pv_scale * scale / d_new;
-      row_sum(mi) = row_max(mi) + math::ptx_log2(d_new);
-    }
-  };
-
-  template <typename Tensor1>
-  __forceinline__ __device__ void rescale_o(Tensor1& acc_o) {
-    // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
-    Tensor acc_o_rowcol = make_tensor(acc_o.data(), convert_layout_acc_rowcol(acc_o.layout()));
-    static_assert(decltype(size<0>(acc_o_rowcol))::value == NUM_ROWS_PER_THREAD);
-#pragma unroll
-    for (int mi = 0; mi < size(row_max); ++mi) {
-#pragma unroll
-      for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) {
-        acc_o_rowcol(mi, ni) *= scores_scale(mi);
-      }
-    }
-  };
-};
-
-struct AttentionSink : AttentionVariantBase {
-  float sm_scale_log2;
-  float log_sink;
-  int qo_len, kv_len;
-
-  // Init
-  template <typename MainloopParams, typename BlockCoord>
-  __device__ __host__ AttentionSink(const MainloopParams& params, const BlockCoord& block_coord) {
-    sm_scale_log2 = params.additional_params.sm_scale * math::log2e;
-    auto [_, qo_head_idx, __, ___, ____, qo_len_, kv_len_, batch_idx] =
-        block_coord;
-    log_sink = params.additional_params.sink[qo_head_idx] * math::log2e;
-
-    qo_len = qo_len_;
-    kv_len = kv_len_;
-  }
-
-  template <int NUM_ROWS_PER_THREAD>
-  __device__ auto GetAttentionUpdater() {
-    return OnlineSoftmaxWithSink<NUM_ROWS_PER_THREAD>(sm_scale_log2, log_sink);
-  }
-};
-"""
-
-attention_sink_decl = {
-    "fa2": attention_sink_fa2_decl,
-    "fa3": attention_sink_fa3_decl,
-}
-
 
 @pytest.fixture(autouse=True, scope="module")
 def warmup_jit():
@@ -186,28 +34,19 @@
         for backend in ["fa2", "fa3"]:
             for use_swa in [True, False]:
                 for head_dim in [128]:
-                    jit_args = (
-                        f"batch_prefill_attention_sink_{filename_safe_dtype_map[dtype]}_swa_{use_swa}_{backend}",
-                        dtype,
-                        dtype,
-                        dtype,
-                        torch.int32,
-                        head_dim,
-                        head_dim,
-                        ["sink"],
-                        ["float"],
-                        ["sm_scale"],
-                        ["double"],
-                        "AttentionSink",
-                        attention_sink_decl[backend],
-                    )
-                    jit_kwargs = {
-                        "use_sliding_window": use_swa,
-                    }
-                    jit_spec = flashinfer.jit.gen_customize_batch_prefill_module(
-                        backend, *jit_args, **jit_kwargs
+                    jit_specs.append(
+                        gen_batch_prefill_attention_sink_module(
+                            backend=backend,
+                            dtype_q=dtype,
+                            dtype_kv=dtype,
+                            dtype_o=dtype,
+                            dtype_idx=torch.int32,
+                            head_dim_qk=head_dim,
+                            head_dim_vo=head_dim,
+                            pos_encoding_mode=0,
+                            use_sliding_window=use_swa,
+                        )
                     )
-                    jit_specs.append(jit_spec)
 
     flashinfer.jit.build_jit_specs(jit_specs)
     yield
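Taken together, the variant is numerically equivalent to appending one value-less sink logit per head before the softmax. A hypothetical dense reference for a single query position (an assumption for illustration, not the library API and not necessarily the oracle used by these tests):

    # Sketch, not part of the patch: dense reference for attention with a per-head sink.
    import torch

    def sink_attention_reference(q, k, v, sink, sm_scale):
        # q: [num_heads, head_dim]; k, v: [kv_len, num_heads, head_dim]; sink: [num_heads]
        logits = torch.einsum("hd,khd->hk", q.float(), k.float()) * sm_scale
        logits = torch.cat([logits, sink[:, None].float()], dim=1)  # sink joins the softmax
        probs = torch.softmax(logits, dim=1)[:, :-1]                # ...but carries no value
        return torch.einsum("hk,khd->hd", probs, v.float())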