
Commit 4fe828c

[Model] Update default params
1 parent 7af5591 commit 4fe828c

File tree

2 files changed: 9 additions & 8 deletions


native_sparse_attention/configuration_nsa.py

Lines changed: 2 additions & 0 deletions

@@ -16,6 +16,7 @@ def __init__(
         num_hidden_layers: int = 24,
         num_heads: int = 64,
         num_kv_heads: int = 4,
+        head_dim: int = 64,
         qkv_bias: bool = False,
         block_size: int = 64,
         block_counts: Optional[int] = 16,
@@ -43,6 +44,7 @@ def __init__(
         self.num_hidden_layers = num_hidden_layers
         self.num_heads = num_heads
         self.num_kv_heads = num_kv_heads
+        self.head_dim = head_dim
         self.qkv_bias = qkv_bias
         self.block_size = block_size
         self.block_counts = block_counts
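
Note: head_dim becomes an explicit config field instead of being derived as hidden_size // num_heads in the attention module (see the modeling diff below). A plain-Python sketch, not the repo's code, of the dimension arithmetic the new defaults imply:

# Dimension arithmetic under the new defaults; variable names mirror the
# config fields, but this sketch is not the repo's code.
hidden_size = 2048
num_heads = 64
num_kv_heads = 4
head_dim = 64                                 # now explicit, no longer derived

old_head_dim = hidden_size // num_heads       # former derivation: 32
q_dim = num_heads * head_dim                  # q_proj out_features: 4096
kv_dim = num_kv_heads * head_dim              # k/v_proj out_features: 256
num_kv_groups = num_heads // num_kv_heads     # query heads per KV head: 16
print(old_head_dim, q_dim, kv_dim, num_kv_groups)  # 32 4096 256 16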

native_sparse_attention/modeling_nsa.py

Lines changed: 7 additions & 8 deletions

@@ -35,8 +35,9 @@ class NativeSparseAttention(nn.Module):
     def __init__(
         self,
         hidden_size: int = 2048,
-        num_heads: int = 32,
-        num_kv_heads: Optional[int] = None,
+        num_heads: int = 64,
+        num_kv_heads: Optional[int] = 4,
+        head_dim: int = 64,
         qkv_bias: bool = False,
         block_size: Optional[int] = 64,
         block_counts: Optional[Union[torch.LongTensor, int]] = 16,
@@ -54,8 +55,7 @@ def __init__(
         else:
             self.num_kv_heads = num_kv_heads
         self.num_kv_groups = num_heads // self.num_kv_heads
-        self.head_dim = self.hidden_size // self.num_heads
-        self.kv_dim = self.num_kv_heads * self.head_dim
+        self.head_dim = head_dim
         self.kv_dim = self.num_kv_heads * self.head_dim
         self.qkv_bias = qkv_bias

@@ -66,11 +66,11 @@ def __init__(
         self.max_position_embeddings = max_position_embeddings
         self.layer_idx = layer_idx

-        self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=self.qkv_bias)
+        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=self.qkv_bias)
         self.k_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=self.qkv_bias)
         self.v_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=self.qkv_bias)
         self.g_proj = nn.Linear(self.hidden_size, self.num_heads * 3, bias=False)
-        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
+        self.o_proj = nn.Linear(self.kv_dim, self.hidden_size, bias=False)

         self.rotary = RotaryEmbedding(dim=self.head_dim, base=self.rope_theta)
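
Shape check for the new projections: before this commit, q_proj mapped hidden_size -> hidden_size, which is only correct when num_heads * head_dim == hidden_size; with an explicit head_dim, the output is sized num_heads * head_dim. A standalone sketch with random inputs, layers built directly from nn.Linear with the same in/out features as the diff (not the repo's module):

import torch
import torch.nn as nn
from einops import rearrange

hidden_size, num_heads, num_kv_heads, head_dim = 2048, 64, 4, 64
kv_dim = num_kv_heads * head_dim

q_proj = nn.Linear(hidden_size, num_heads * head_dim, bias=False)  # 2048 -> 4096
k_proj = nn.Linear(hidden_size, kv_dim, bias=False)                # 2048 -> 256
v_proj = nn.Linear(hidden_size, kv_dim, bias=False)                # 2048 -> 256

x = torch.randn(2, 128, hidden_size)
# The same rearrange pattern as forward(): split the flat projection into heads.
q = rearrange(q_proj(x), '... (h d) -> ... h d', d=head_dim)  # [2, 128, 64, 64]
k = rearrange(k_proj(x), '... (h d) -> ... h d', d=head_dim)  # [2, 128, 4, 64]
v = rearrange(v_proj(x), '... (h d) -> ... h d', d=head_dim)  # [2, 128, 4, 64]
print(q.shape, k.shape, v.shape)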

@@ -128,7 +128,7 @@ def forward(
         k = rearrange(k, '... (h d) -> ... h d', d=self.head_dim)
         v = rearrange(v, '... (h d) -> ... h d', d=self.head_dim)

-        o = parallel_nsa_with_compression(
+        o, _ = parallel_nsa_with_compression(
             q=q,
             k=k,
             v=v,
@@ -138,7 +138,6 @@ def forward(
             block_size=self.block_size,
             block_counts=self.block_counts,
             window_size=self.window_size,
-            cu_seqlens=cu_seqlens,
             head_first=False
         )
         o = o.reshape(batch_size, seq_len, -1)
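
On the forward-path change: parallel_nsa_with_compression now returns a tuple, of which only the first element (the attention output) is kept here, and cu_seqlens is no longer forwarded. A minimal sketch of the unpack-and-reshape pattern; fake_kernel is a hypothetical stand-in that mimics only the return arity, not the real kernel's computation or signature:

import torch

def fake_kernel(q, **kwargs):
    # Hypothetical stand-in: echoes a query-shaped tensor plus a second
    # value, matching the (output, extra) arity unpacked at the call site.
    return torch.zeros_like(q), None

batch_size, seq_len, num_heads, head_dim = 2, 128, 64, 64
q = torch.randn(batch_size, seq_len, num_heads, head_dim)

o, _ = fake_kernel(q, head_first=False)  # keep the output, drop the second value
o = o.reshape(batch_size, seq_len, -1)   # flatten heads: [2, 128, 4096]
print(o.shape)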
