Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions flash_sparse_attn/ops/cute/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from importlib.metadata import PackageNotFoundError, version

try:
__version__ = version("fa4")
__version__ = version("flash-sparse-attn")
except PackageNotFoundError:
__version__ = "0.0.0"

Expand All @@ -14,7 +14,7 @@
flash_attn_varlen_func,
)

from flash_attn.cute.cute_dsl_utils import cute_compile_patched
from flash_sparse_attn.ops.cute.cute_dsl_utils import cute_compile_patched

# Patch cute.compile to optionally dump SASS
cute.compile = cute_compile_patched
Expand Down
2 changes: 1 addition & 1 deletion flash_sparse_attn/ops/cute/blackwell_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from cutlass.cute.nvgpu import tcgen05
from cutlass._mlir.dialects import llvm

import flash_attn.cute.mma_sm100_desc as sm100_desc
import flash_sparse_attn.ops.cute.mma_sm100_desc as sm100_desc


@cute.jit
Expand Down
2 changes: 1 addition & 1 deletion flash_sparse_attn/ops/cute/block_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import cutlass.cute as cute
from cutlass import Int32, const_expr

from flash_attn.cute.seqlen_info import SeqlenInfoQK, SeqlenInfoQKNewK
from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK, SeqlenInfoQKNewK


@dataclass(frozen=True)
Expand Down
4 changes: 2 additions & 2 deletions flash_sparse_attn/ops/cute/block_sparse_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
from quack import copy_utils

# Import data structures from block_sparsity
from flash_attn.cute.block_sparsity import BlockSparseTensors
from flash_attn.cute.named_barrier import NamedBarrierBwd
from flash_sparse_attn.ops.cute.block_sparsity import BlockSparseTensors
from flash_sparse_attn.ops.cute.named_barrier import NamedBarrierBwd


# NOTE [SM100 block-sparse empty tiles: mbarrier contract]
Expand Down
2 changes: 1 addition & 1 deletion flash_sparse_attn/ops/cute/block_sparsity.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import cutlass.cute as cute
import torch

from flash_attn.cute.cute_dsl_utils import get_broadcast_dims, to_cute_tensor
from flash_sparse_attn.ops.cute.cute_dsl_utils import get_broadcast_dims, to_cute_tensor


def ceildiv(a: int, b: int) -> int:
Expand Down
6 changes: 3 additions & 3 deletions flash_sparse_attn/ops/cute/compute_block_sparsity.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@
import torch
from cutlass import Boolean, Int8, Int32, const_expr

from flash_attn.cute.block_sparsity import (
from flash_sparse_attn.ops.cute.block_sparsity import (
BlockSparseTensors,
BlockSparseTensorsTorch,
to_cute_block_sparse_tensors,
)
from flash_attn.cute.utils import hash_callable, scalar_to_ssa, ssa_to_scalar
from flash_attn.cute.seqlen_info import SeqlenInfoQK
from flash_sparse_attn.ops.cute.utils import hash_callable, scalar_to_ssa, ssa_to_scalar
from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK


class BlockSparsityKernel:
Expand Down
14 changes: 7 additions & 7 deletions flash_sparse_attn/ops/cute/flash_bwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@
import cutlass.utils as utils_basic

from quack import layout_utils
from flash_attn.cute import ampere_helpers as sm80_utils
from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned
from flash_attn.cute import utils
from flash_attn.cute.mask import AttentionMask
from flash_attn.cute.seqlen_info import SeqlenInfoQK
from flash_sparse_attn.ops.cute import ampere_helpers as sm80_utils
from flash_sparse_attn.ops.cute.cute_dsl_utils import assume_tensor_aligned
from flash_sparse_attn.ops.cute import utils
from flash_sparse_attn.ops.cute.mask import AttentionMask
from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK
from quack.cute_dsl_utils import ParamsBase
from flash_attn.cute.tile_scheduler import SingleTileScheduler, SingleTileVarlenScheduler, TileSchedulerArguments
from flash_attn.cute.block_sparsity import BlockSparseTensors
from flash_sparse_attn.ops.cute.tile_scheduler import SingleTileScheduler, SingleTileVarlenScheduler, TileSchedulerArguments
from flash_sparse_attn.ops.cute.block_sparsity import BlockSparseTensors


class FlashAttentionBackwardSm80:
Expand Down
10 changes: 5 additions & 5 deletions flash_sparse_attn/ops/cute/flash_bwd_postprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@
from quack import layout_utils
from quack import sm90_utils

from flash_attn.cute import utils
from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned
from flash_attn.cute import ampere_helpers as sm80_utils
from flash_attn.cute.seqlen_info import SeqlenInfoQK
from flash_sparse_attn.ops.cute import utils
from flash_sparse_attn.ops.cute.cute_dsl_utils import assume_tensor_aligned
from flash_sparse_attn.ops.cute import ampere_helpers as sm80_utils
from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK
import cutlass.cute.nvgpu.tcgen05 as tcgen05
from quack.cute_dsl_utils import ParamsBase
from flash_attn.cute.tile_scheduler import (
from flash_sparse_attn.ops.cute.tile_scheduler import (
SingleTileScheduler,
SingleTileVarlenScheduler,
TileSchedulerArguments,
Expand Down
6 changes: 3 additions & 3 deletions flash_sparse_attn/ops/cute/flash_bwd_preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@

from quack import copy_utils, layout_utils

from flash_attn.cute import utils
from flash_attn.cute.seqlen_info import SeqlenInfo
from flash_sparse_attn.ops.cute import utils
from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfo
from quack.cute_dsl_utils import ParamsBase
from flash_attn.cute.tile_scheduler import (
from flash_sparse_attn.ops.cute.tile_scheduler import (
SingleTileScheduler,
SingleTileVarlenScheduler,
TileSchedulerArguments,
Expand Down
28 changes: 14 additions & 14 deletions flash_sparse_attn/ops/cute/flash_bwd_sm100.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,27 +16,27 @@

import quack.activation
from quack import layout_utils
from flash_attn.cute import utils
from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned
from flash_attn.cute import copy_utils
from flash_attn.cute import pipeline
from flash_attn.cute.blackwell_helpers import gemm_w_idx, gemm_ptx_w_idx # noqa
from flash_attn.cute.mask import AttentionMask
from flash_attn.cute.seqlen_info import SeqlenInfoQK
from flash_attn.cute.block_info import BlockInfo
from flash_sparse_attn.ops.cute import utils
from flash_sparse_attn.ops.cute.cute_dsl_utils import assume_tensor_aligned
from flash_sparse_attn.ops.cute import copy_utils
from flash_sparse_attn.ops.cute import pipeline
from flash_sparse_attn.ops.cute.blackwell_helpers import gemm_w_idx, gemm_ptx_w_idx # noqa
from flash_sparse_attn.ops.cute.mask import AttentionMask
from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK
from flash_sparse_attn.ops.cute.block_info import BlockInfo
from quack.cute_dsl_utils import ParamsBase
from flash_attn.cute.tile_scheduler import (
from flash_sparse_attn.ops.cute.tile_scheduler import (
TileSchedulerArguments,
SingleTileScheduler,
SingleTileLPTBwdScheduler, # noqa
SingleTileVarlenScheduler,
)

from flash_attn.cute import barrier
from flash_attn.cute.named_barrier import NamedBarrierBwdSm100
from flash_attn.cute.softmax import apply_score_mod_inner, apply_score_mod_bwd_inner
from flash_attn.cute.block_sparsity import BlockSparseTensors
from flash_attn.cute.block_sparse_utils import (
from flash_sparse_attn.ops.cute import barrier
from flash_sparse_attn.ops.cute.named_barrier import NamedBarrierBwdSm100
from flash_sparse_attn.ops.cute.softmax import apply_score_mod_inner, apply_score_mod_bwd_inner
from flash_sparse_attn.ops.cute.block_sparsity import BlockSparseTensors
from flash_sparse_attn.ops.cute.block_sparse_utils import (
get_total_q_block_count_bwd,
get_block_sparse_iteration_info_bwd,
get_m_block_from_iter_bwd,
Expand Down
2 changes: 1 addition & 1 deletion flash_sparse_attn/ops/cute/flash_bwd_sm120.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import cutlass
import cutlass.utils as utils_basic

from flash_attn.cute.flash_bwd import FlashAttentionBackwardSm80
from flash_sparse_attn.ops.cute.flash_bwd import FlashAttentionBackwardSm80


class FlashAttentionBackwardSm120(FlashAttentionBackwardSm80):
Expand Down
24 changes: 12 additions & 12 deletions flash_sparse_attn/ops/cute/flash_bwd_sm90.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,24 +17,24 @@
from quack import sm90_utils
from quack.sm90_utils import gemm_zero_init, gemm_w_idx

from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned
from flash_attn.cute import utils
from flash_attn.cute.mask import AttentionMask
from flash_attn.cute.seqlen_info import SeqlenInfoQK
from flash_attn.cute.block_info import BlockInfo
from flash_attn.cute import pipeline
from flash_sparse_attn.ops.cute.cute_dsl_utils import assume_tensor_aligned
from flash_sparse_attn.ops.cute import utils
from flash_sparse_attn.ops.cute.mask import AttentionMask
from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK
from flash_sparse_attn.ops.cute.block_info import BlockInfo
from flash_sparse_attn.ops.cute import pipeline
from quack.cute_dsl_utils import ParamsBase
from flash_attn.cute.tile_scheduler import (
from flash_sparse_attn.ops.cute.tile_scheduler import (
TileSchedulerArguments,
SingleTileScheduler,
SingleTileLPTBwdScheduler,
SingleTileVarlenScheduler,
)
from flash_attn.cute import barrier
from flash_attn.cute.named_barrier import NamedBarrierBwd
from flash_attn.cute.softmax import apply_score_mod_inner, apply_score_mod_bwd_inner
from flash_attn.cute.block_sparsity import BlockSparseTensors
from flash_attn.cute.block_sparse_utils import (
from flash_sparse_attn.ops.cute import barrier
from flash_sparse_attn.ops.cute.named_barrier import NamedBarrierBwd
from flash_sparse_attn.ops.cute.softmax import apply_score_mod_inner, apply_score_mod_bwd_inner
from flash_sparse_attn.ops.cute.block_sparsity import BlockSparseTensors
from flash_sparse_attn.ops.cute.block_sparse_utils import (
get_total_q_block_count_bwd,
produce_block_sparse_q_loads_bwd_sm90,
consume_block_sparse_mma_bwd_sm90,
Expand Down
24 changes: 12 additions & 12 deletions flash_sparse_attn/ops/cute/flash_fwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,17 @@
from quack import copy_utils
from quack import layout_utils

from flash_attn.cute import ampere_helpers as sm80_utils
from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned
from flash_attn.cute import utils
from flash_attn.cute.mask import AttentionMask
from flash_attn.cute.softmax import Softmax
from flash_attn.cute.seqlen_info import SeqlenInfoQK
from flash_attn.cute.block_info import BlockInfo
from flash_attn.cute.pack_gqa import PackGQA
from flash_attn.cute.named_barrier import NamedBarrierFwd
from flash_attn.cute.block_sparsity import BlockSparseTensors
from flash_attn.cute.tile_scheduler import SingleTileScheduler, SingleTileVarlenScheduler, TileSchedulerArguments
from flash_sparse_attn.ops.cute import ampere_helpers as sm80_utils
from flash_sparse_attn.ops.cute.cute_dsl_utils import assume_tensor_aligned
from flash_sparse_attn.ops.cute import utils
from flash_sparse_attn.ops.cute.mask import AttentionMask
from flash_sparse_attn.ops.cute.softmax import Softmax
from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK
from flash_sparse_attn.ops.cute.block_info import BlockInfo
from flash_sparse_attn.ops.cute.pack_gqa import PackGQA
from flash_sparse_attn.ops.cute.named_barrier import NamedBarrierFwd
from flash_sparse_attn.ops.cute.block_sparsity import BlockSparseTensors
from flash_sparse_attn.ops.cute.tile_scheduler import SingleTileScheduler, SingleTileVarlenScheduler, TileSchedulerArguments


class FlashAttentionForwardBase:
Expand Down Expand Up @@ -1190,6 +1190,6 @@ def load_K_next():
# SM90 forward pass moved to flash_fwd_sm90.py; re-export for backward compatibility
def __getattr__(name: str):
    """Lazily resolve ``FlashAttentionForwardSm90`` for backward compatibility.

    Python calls a module-level ``__getattr__`` only when normal attribute
    lookup on the module fails (PEP 562), so this runs solely for names the
    module does not define itself.

    Args:
        name: The attribute name that was requested on this module.

    Returns:
        The ``FlashAttentionForwardSm90`` class, imported on demand from
        ``flash_sparse_attn.ops.cute.flash_fwd_sm90``.

    Raises:
        AttributeError: If ``name`` is anything other than
            ``"FlashAttentionForwardSm90"``.
    """
    if name == "FlashAttentionForwardSm90":
        # Deferred import: flash_fwd_sm90 imports FlashAttentionForwardBase
        # from this module, so importing it at module top would be circular.
        # Only the in-package path is used; the pre-rename
        # ``flash_attn.cute`` path must not be imported here.
        from flash_sparse_attn.ops.cute.flash_fwd_sm90 import FlashAttentionForwardSm90

        return FlashAttentionForwardSm90
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
6 changes: 3 additions & 3 deletions flash_sparse_attn/ops/cute/flash_fwd_combine.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
from cutlass.cute.nvgpu import cpasync
from cutlass import Float32, Int32, Boolean, const_expr

from flash_attn.cute import utils
from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned
from flash_attn.cute.seqlen_info import SeqlenInfo
from flash_sparse_attn.ops.cute import utils
from flash_sparse_attn.ops.cute.cute_dsl_utils import assume_tensor_aligned
from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfo
from cutlass.cute import FastDivmodDivisor


Expand Down
30 changes: 15 additions & 15 deletions flash_sparse_attn/ops/cute/flash_fwd_sm100.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,28 +32,28 @@

from quack import copy_utils, layout_utils

from flash_attn.cute.paged_kv import PagedKVManager
from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned
from flash_attn.cute import utils
import flash_attn.cute.pipeline as pipeline_custom
from flash_attn.cute.mask import AttentionMask
from flash_attn.cute.softmax import SoftmaxSm100, apply_score_mod_inner
from flash_attn.cute.seqlen_info import SeqlenInfoQK
from flash_attn.cute.block_info import BlockInfo
from flash_attn.cute.block_sparsity import BlockSparseTensors
from flash_attn.cute.block_sparse_utils import (
from flash_sparse_attn.ops.cute.paged_kv import PagedKVManager
from flash_sparse_attn.ops.cute.cute_dsl_utils import assume_tensor_aligned
from flash_sparse_attn.ops.cute import utils
import flash_sparse_attn.ops.cute.pipeline as pipeline_custom
from flash_sparse_attn.ops.cute.mask import AttentionMask
from flash_sparse_attn.ops.cute.softmax import SoftmaxSm100, apply_score_mod_inner
from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK
from flash_sparse_attn.ops.cute.block_info import BlockInfo
from flash_sparse_attn.ops.cute.block_sparsity import BlockSparseTensors
from flash_sparse_attn.ops.cute.block_sparse_utils import (
get_total_block_count,
produce_block_sparse_loads_sm100,
softmax_block_sparse_sm100,
handle_block_sparse_empty_tile_correction_sm100,
)
from flash_attn.cute.pack_gqa import PackGQA, pack_gqa_layout
from flash_attn.cute import mma_sm100_desc as sm100_desc
from flash_attn.cute import blackwell_helpers as sm100_utils
from flash_attn.cute.named_barrier import NamedBarrierFwdSm100
from flash_sparse_attn.ops.cute.pack_gqa import PackGQA, pack_gqa_layout
from flash_sparse_attn.ops.cute import mma_sm100_desc as sm100_desc
from flash_sparse_attn.ops.cute import blackwell_helpers as sm100_utils
from flash_sparse_attn.ops.cute.named_barrier import NamedBarrierFwdSm100
from cutlass.cute import FastDivmodDivisor
from quack.cute_dsl_utils import ParamsBase
from flash_attn.cute.tile_scheduler import (
from flash_sparse_attn.ops.cute.tile_scheduler import (
TileSchedulerArguments,
SingleTileScheduler,
StaticPersistentTileScheduler,
Expand Down
2 changes: 1 addition & 1 deletion flash_sparse_attn/ops/cute/flash_fwd_sm120.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import cutlass
import cutlass.utils as utils_basic

from flash_attn.cute.flash_fwd import FlashAttentionForwardSm80
from flash_sparse_attn.ops.cute.flash_fwd import FlashAttentionForwardSm80


class FlashAttentionForwardSm120(FlashAttentionForwardSm80):
Expand Down
28 changes: 14 additions & 14 deletions flash_sparse_attn/ops/cute/flash_fwd_sm90.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,31 +21,31 @@
from quack import layout_utils
from quack import sm90_utils

from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned
from flash_attn.cute import utils
from flash_attn.cute.mask import AttentionMask
from flash_attn.cute.softmax import Softmax, apply_score_mod_inner
from flash_attn.cute.seqlen_info import SeqlenInfoQK
from flash_attn.cute.block_info import BlockInfo
from flash_attn.cute.block_sparsity import BlockSparseTensors
from flash_attn.cute.block_sparse_utils import (
from flash_sparse_attn.ops.cute.cute_dsl_utils import assume_tensor_aligned
from flash_sparse_attn.ops.cute import utils
from flash_sparse_attn.ops.cute.mask import AttentionMask
from flash_sparse_attn.ops.cute.softmax import Softmax, apply_score_mod_inner
from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK
from flash_sparse_attn.ops.cute.block_info import BlockInfo
from flash_sparse_attn.ops.cute.block_sparsity import BlockSparseTensors
from flash_sparse_attn.ops.cute.block_sparse_utils import (
produce_block_sparse_loads,
consume_block_sparse_loads,
)
from flash_attn.cute import pipeline as pipeline_custom
from flash_attn.cute.pack_gqa import PackGQA, pack_gqa_layout, make_packgqa_tiled_tma_atom
from flash_attn.cute.paged_kv import PagedKVManager
from flash_attn.cute.named_barrier import NamedBarrierFwd
from flash_sparse_attn.ops.cute import pipeline as pipeline_custom
from flash_sparse_attn.ops.cute.pack_gqa import PackGQA, pack_gqa_layout, make_packgqa_tiled_tma_atom
from flash_sparse_attn.ops.cute.paged_kv import PagedKVManager
from flash_sparse_attn.ops.cute.named_barrier import NamedBarrierFwd
from quack.cute_dsl_utils import ParamsBase
from flash_attn.cute.tile_scheduler import (
from flash_sparse_attn.ops.cute.tile_scheduler import (
TileSchedulerArguments,
SingleTileScheduler,
SingleTileLPTScheduler,
SingleTileVarlenScheduler,
)
from cutlass.cute import FastDivmodDivisor

from flash_attn.cute.flash_fwd import FlashAttentionForwardBase
from flash_sparse_attn.ops.cute.flash_fwd import FlashAttentionForwardBase


class FlashAttentionForwardSm90(FlashAttentionForwardBase):
Expand Down
Loading
Loading