Skip to content

Commit d834161

Browse files
committed
Add mamba configs (#22)
1 parent c7c0183 commit d834161

File tree

7 files changed

+178
-2
lines changed

7 files changed

+178
-2
lines changed

3rdparty/flash-linear-attention

configs/mamba2_1B.json

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
{
2+
"bos_token_id": 1,
3+
"chunk_size": 256,
4+
"conv_kernel": 4,
5+
"eos_token_id": 2,
6+
"expand": 2,
7+
"fuse_cross_entropy": true,
8+
"fuse_norm": true,
9+
"head_dim": 64,
10+
"hidden_act": "silu",
11+
"hidden_size": 2048,
12+
"initializer_range": 0.02,
13+
"norm_eps": 1e-05,
14+
"model_type": "mamba2",
15+
"n_groups": 1,
16+
"num_hidden_layers": 48,
17+
"pad_token_id": 0,
18+
"rescale_prenorm_residual": true,
19+
"residual_in_fp32": true,
20+
"rms_norm": true,
21+
"state_size": 128,
22+
"tie_word_embeddings": false,
23+
"time_step_floor": 0.0001,
24+
"time_step_max": 0.1,
25+
"time_step_min": 0.001,
26+
"time_step_rank": 128,
27+
"transformers_version": "4.50.1",
28+
"use_bias": false,
29+
"use_cache": true,
30+
"use_conv_bias": true,
31+
"vocab_size": 32000
32+
}

configs/mamba2_340M.json

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
{
2+
"bos_token_id": 1,
3+
"chunk_size": 256,
4+
"conv_kernel": 4,
5+
"eos_token_id": 2,
6+
"expand": 2,
7+
"fuse_cross_entropy": true,
8+
"fuse_norm": true,
9+
"head_dim": 64,
10+
"hidden_act": "silu",
11+
"hidden_size": 1024,
12+
"initializer_range": 0.02,
13+
"norm_eps": 1e-05,
14+
"model_type": "mamba2",
15+
"n_groups": 1,
16+
"num_hidden_layers": 48,
17+
"pad_token_id": 0,
18+
"rescale_prenorm_residual": true,
19+
"residual_in_fp32": true,
20+
"rms_norm": true,
21+
"state_size": 128,
22+
"tie_word_embeddings": false,
23+
"time_step_floor": 0.0001,
24+
"time_step_max": 0.1,
25+
"time_step_min": 0.001,
26+
"time_step_rank": 128,
27+
"transformers_version": "4.50.1",
28+
"use_bias": false,
29+
"use_cache": true,
30+
"use_conv_bias": true,
31+
"vocab_size": 32000
32+
}

configs/mamba_1B.json

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
{
2+
"bos_token_id": 1,
3+
"conv_kernel": 4,
4+
"eos_token_id": 2,
5+
"expand": 2,
6+
"fuse_cross_entropy": true,
7+
"fuse_norm": true,
8+
"hidden_act": "silu",
9+
"hidden_size": 2048,
10+
"initializer_range": 0.02,
11+
"model_type": "mamba",
12+
"norm_eps": 1e-05,
13+
"num_hidden_layers": 48,
14+
"pad_token_id": 0,
15+
"rescale_prenorm_residual": false,
16+
"residual_in_fp32": false,
17+
"state_size": 16,
18+
"tie_word_embeddings": false,
19+
"time_step_floor": 0.0001,
20+
"time_step_init_scheme": "random",
21+
"time_step_max": 0.1,
22+
"time_step_min": 0.001,
23+
"time_step_rank": 128,
24+
"time_step_scale": 1.0,
25+
"transformers_version": "4.50.1",
26+
"use_bias": false,
27+
"use_cache": true,
28+
"use_conv_bias": true,
29+
"vocab_size": 32000
30+
}

configs/mamba_340M.json

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
{
2+
"bos_token_id": 1,
3+
"conv_kernel": 4,
4+
"eos_token_id": 2,
5+
"expand": 2,
6+
"fuse_cross_entropy": true,
7+
"fuse_norm": true,
8+
"hidden_act": "silu",
9+
"hidden_size": 1024,
10+
"initializer_range": 0.02,
11+
"model_type": "mamba",
12+
"norm_eps": 1e-05,
13+
"num_hidden_layers": 48,
14+
"pad_token_id": 0,
15+
"rescale_prenorm_residual": false,
16+
"residual_in_fp32": false,
17+
"state_size": 16,
18+
"tie_word_embeddings": false,
19+
"time_step_floor": 0.0001,
20+
"time_step_init_scheme": "random",
21+
"time_step_max": 0.1,
22+
"time_step_min": 0.001,
23+
"time_step_rank": 128,
24+
"time_step_scale": 1.0,
25+
"transformers_version": "4.50.1",
26+
"use_bias": false,
27+
"use_cache": true,
28+
"use_conv_bias": true,
29+
"vocab_size": 32000
30+
}

configs/samba_1B.json

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
{
2+
"attn": {
3+
"layers": [
4+
1,
5+
3,
6+
5,
7+
7,
8+
9,
9+
11,
10+
13,
11+
15,
12+
17
13+
],
14+
"num_heads": 18,
15+
"num_kv_heads": 18,
16+
"qkv_bias": false,
17+
"rope_theta": 10000.0,
18+
"window_size": 2048
19+
},
20+
"bos_token_id": 1,
21+
"conv_kernel": 4,
22+
"eos_token_id": 2,
23+
"expand": 2,
24+
"fuse_cross_entropy": true,
25+
"fuse_norm": true,
26+
"fuse_swiglu": true,
27+
"hidden_act": "swish",
28+
"hidden_ratio": 4,
29+
"hidden_size": 2304,
30+
"initializer_range": 0.02,
31+
"intermediate_size": 4608,
32+
"max_position_embeddings": 2048,
33+
"model_type": "samba",
34+
"norm_eps": 1e-05,
35+
"num_hidden_layers": 18,
36+
"pad_token_id": 0,
37+
"rescale_prenorm_residual": false,
38+
"residual_in_fp32": false,
39+
"state_size": 16,
40+
"tie_word_embeddings": false,
41+
"time_step_floor": 0.0001,
42+
"time_step_init_scheme": "random",
43+
"time_step_max": 0.1,
44+
"time_step_min": 0.001,
45+
"time_step_rank": 144,
46+
"time_step_scale": 1.0,
47+
"transformers_version": "4.50.1",
48+
"use_bias": false,
49+
"use_cache": true,
50+
"use_conv_bias": true,
51+
"vocab_size": 32000
52+
}

flame/train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
import fla # noqa
1818
from fla.modules.fused_linear_cross_entropy import FusedLinearCrossEntropyLoss
19-
from fla.ops.common.utils import prepare_position_ids
19+
from fla.ops.utils import prepare_position_ids
2020
from flame.components.checkpoint import TrainState
2121
from flame.config_manager import JobConfig
2222
from flame.data import build_dataloader, shuffle

0 commit comments

Comments (0)