@@ -35,8 +35,9 @@ class NativeSparseAttention(nn.Module):
     def __init__(
         self,
         hidden_size: int = 2048,
-        num_heads: int = 32,
-        num_kv_heads: Optional[int] = None,
+        num_heads: int = 64,
+        num_kv_heads: Optional[int] = 4,
+        head_dim: int = 64,
         qkv_bias: bool = False,
         block_size: Optional[int] = 64,
         block_counts: Optional[Union[torch.LongTensor, int]] = 16,
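
With these defaults the attention width is no longer tied to the model width. A minimal sketch of the dimension arithmetic the new signature implies (values taken from the defaults above; none of this code is in the patch):

# Dimension arithmetic implied by the new defaults (sketch, not from the patch).
hidden_size, num_heads, num_kv_heads, head_dim = 2048, 64, 4, 64
assert num_heads % num_kv_heads == 0
num_kv_groups = num_heads // num_kv_heads  # 16 query heads share each KV head (GQA)
q_width = num_heads * head_dim             # 4096: may exceed hidden_size now
kv_dim = num_kv_heads * head_dim           # 256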
@@ -54,8 +55,7 @@ def __init__(
         else:
             self.num_kv_heads = num_kv_heads
         self.num_kv_groups = num_heads // self.num_kv_heads
-        self.head_dim = self.hidden_size // self.num_heads
-        self.kv_dim = self.num_kv_heads * self.head_dim
+        self.head_dim = head_dim
         self.kv_dim = self.num_kv_heads * self.head_dim
         self.qkv_bias = qkv_bias

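Making `head_dim` an explicit argument avoids a silent change: under the new `num_heads` default, the old derivation would have produced the wrong value. A one-line sketch of the before/after (illustrative values only):

# Before: head_dim was derived, so bumping num_heads silently shrank it.
derived_head_dim = 2048 // 64  # 32 under the new num_heads default
# After: head_dim = 64 is passed explicitly; hidden_size drops out entirely.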
@@ -66,11 +66,11 @@ def __init__(
         self.max_position_embeddings = max_position_embeddings
         self.layer_idx = layer_idx

-        self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=self.qkv_bias)
+        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=self.qkv_bias)
         self.k_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=self.qkv_bias)
         self.v_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=self.qkv_bias)
         self.g_proj = nn.Linear(self.hidden_size, self.num_heads * 3, bias=False)
-        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
+        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

         self.rotary = RotaryEmbedding(dim=self.head_dim, base=self.rope_theta)

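A self-contained shape check for the four projections as now defined (a sketch built only from the defaults above). Note that `o_proj` must accept the reshaped attention output, `num_heads * head_dim` (4096), which is exactly what the forward path below produces; `g_proj` emits three gate logits per head, presumably one per NSA branch:

# Shape check for the projections as defined above (sketch, CPU-only).
import torch
import torch.nn as nn

hidden_size, num_heads, num_kv_heads, head_dim = 2048, 64, 4, 64
kv_dim = num_kv_heads * head_dim

q_proj = nn.Linear(hidden_size, num_heads * head_dim, bias=False)
k_proj = nn.Linear(hidden_size, kv_dim, bias=False)
v_proj = nn.Linear(hidden_size, kv_dim, bias=False)
g_proj = nn.Linear(hidden_size, num_heads * 3, bias=False)  # 3 gate logits per head
o_proj = nn.Linear(num_heads * head_dim, hidden_size, bias=False)

x = torch.randn(2, 16, hidden_size)
assert q_proj(x).shape == (2, 16, 4096)
assert k_proj(x).shape == (2, 16, 256)
assert g_proj(x).shape == (2, 16, 192)
assert o_proj(q_proj(x)).shape == (2, 16, hidden_size)  # input width is num_heads * head_dim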
@@ -128,7 +128,7 @@ def forward(
         k = rearrange(k, '... (h d) -> ... h d', d=self.head_dim)
         v = rearrange(v, '... (h d) -> ... h d', d=self.head_dim)

-        o = parallel_nsa_with_compression(
+        o, _ = parallel_nsa_with_compression(
             q=q,
             k=k,
             v=v,
@@ -138,7 +138,6 @@ def forward(
             block_size=self.block_size,
             block_counts=self.block_counts,
             window_size=self.window_size,
-            cu_seqlens=cu_seqlens,
             head_first=False
         )
         o = o.reshape(batch_size, seq_len, -1)
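
The call-site change suggests `parallel_nsa_with_compression` now returns a pair, of which the layer keeps only the attention output, and that `cu_seqlens` is no longer threaded through this path. For completeness, a hypothetical end-to-end call; the import path and forward signature are assumed from the surrounding repo layout and are not part of this diff (the kernel also requires a CUDA device):

# Hypothetical usage (assumed import path and forward signature; requires CUDA).
import torch
from fla.layers.nsa import NativeSparseAttention  # path assumed, not shown in the diff

layer = NativeSparseAttention().cuda().to(torch.bfloat16)  # new defaults: 64 Q heads, 4 KV heads, head_dim 64
x = torch.randn(1, 512, 2048, device='cuda', dtype=torch.bfloat16)
out = layer(x)                      # the kernel's tuple unpacking stays internal to the layer
o = out[0] if isinstance(out, tuple) else out
assert o.shape == (1, 512, 2048)    # o_proj maps 4096 back to hidden_size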