Commit bec61f0

Add model and config definitions
1 parent 42158dc

File tree

2 files changed: +593 −0 lines

Lines changed: 73 additions & 0 deletions
@@ -0,0 +1,73 @@
# -*- coding: utf-8 -*-

from typing import Optional

from transformers.configuration_utils import PretrainedConfig


class NSAConfig(PretrainedConfig):
    """Configuration class for NSA models."""

    model_type = 'nsa'
    keys_to_ignore_at_inference = ['past_key_values']

    def __init__(
        self,
        hidden_size: int = 2048,
        num_hidden_layers: int = 24,
        num_heads: int = 64,
        num_kv_heads: int = 4,
        qkv_bias: bool = False,
        block_size: int = 64,
        block_counts: Optional[int] = 16,
        window_size: Optional[int] = 512,
        rope_theta: Optional[float] = 10000.,
        max_position_embeddings: int = 2048,
        hidden_ratio: Optional[int] = 4,
        intermediate_size: Optional[int] = None,
        hidden_act: str = "swish",
        initializer_range: float = 0.006,
        elementwise_affine: Optional[bool] = True,
        norm_eps: float = 1e-6,
        use_cache: bool = True,
        pad_token_id: Optional[int] = None,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        tie_word_embeddings: bool = False,
        fuse_norm: bool = True,
        fuse_swiglu: bool = True,
        fuse_cross_entropy: bool = True,
        vocab_size: int = 32000,
        **kwargs,
    ):
        # attention geometry
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.qkv_bias = qkv_bias
        # sparse attention: selected blocks plus a sliding window
        self.block_size = block_size
        self.block_counts = block_counts
        self.window_size = window_size
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        # MLP
        self.hidden_ratio = hidden_ratio
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act

        # initialization and normalization
        self.initializer_range = initializer_range
        self.elementwise_affine = elementwise_affine
        self.norm_eps = norm_eps
        self.use_cache = use_cache

        # fused kernel switches
        self.fuse_norm = fuse_norm
        self.fuse_swiglu = fuse_swiglu
        self.fuse_cross_entropy = fuse_cross_entropy
        self.vocab_size = vocab_size

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
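
A quick usage sketch (not part of the commit; the directory name nsa-config and the variable names are illustrative). Everything below relies only on the PretrainedConfig API inherited above, plus AutoConfig.register, which transformers provides for custom model types. Note what the defaults imply: a grouped-query layout with num_heads = 64 query heads sharing num_kv_heads = 4 KV heads (16 queries per KV head), and a per-query sparse budget of block_counts × block_size = 16 × 64 = 1024 selected tokens plus a window_size = 512 local window.

from transformers import AutoConfig

# Register so AutoConfig.from_pretrained can resolve model_type 'nsa'.
AutoConfig.register('nsa', NSAConfig)

cfg = NSAConfig(hidden_size=1024, num_hidden_layers=12)
assert cfg.model_type == 'nsa'
# Sparse-attention budget implied by the defaults.
assert cfg.block_counts * cfg.block_size == 1024
assert cfg.window_size == 512

# Standard serialization round-trip inherited from PretrainedConfig.
cfg.save_pretrained('nsa-config')  # writes nsa-config/config.json
reloaded = NSAConfig.from_pretrained('nsa-config')
assert reloaded.hidden_size == 1024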
