@@ -3,7 +3,7 @@
 from copy import deepcopy
 from dataclasses import dataclass, field
 from pathlib import Path
-from typing import Any, Literal, Optional, Type, Union
+from typing import Any, List, Literal, Optional, Type, Union
 
 import torch
 import yaml
@@ -58,9 +58,7 @@ class Config:
     attn_bias: bool = False
     attention_scores_scalar: Optional[int] = None
     sliding_window_size: Optional[int] = None
-    sliding_window_layer_placing: Optional[Literal["all", "interleaved"]] = None
-    sliding_window_layer_stride: Optional[int] = None
-    sliding_window_offset: int = 0
+    sliding_window_indices: Optional[List] = None
     # if `attention_logit_softcapping` is used, cannot use optimized
     # `torch.nn.functional.scaled_dot_product_attention` (which implements
     # Flash attention), may result in higher memory and runtime footprint.
@@ -114,14 +112,11 @@ def __post_init__(self):
 
         self.rope_n_elem = int(self.rotary_percentage * self.head_size)
 
-        if self.sliding_window_size is not None:
-            self.sliding_window_layer_stride = (
-                (1 if (self.sliding_window_layer_placing is None or self.sliding_window_layer_placing == "all") else 2)
-                if self.sliding_window_layer_stride is None
-                else self.sliding_window_layer_stride
-            )
+        if self.sliding_window_size is not None and self.sliding_window_indices is None:
+            self.sliding_window_indices = [1] * self.n_layer
 
-            self.sliding_window_block_idx_map_fn = lambda x: x + self.sliding_window_offset
+        if self.rope_local_base_freq is not None and self.rope_indices is None:
+            self.rope_indices = [1] * self.n_layer
 
     @classmethod
     def from_name(cls, name: str, **kwargs: Any) -> Optional[Self]:
@@ -974,7 +969,7 @@ def norm_class(self) -> Type:
         block_size=8192,
         sliding_window_size=4096,
         # only layer with idx 0, 2, 4, ... have sliding window attention
-        sliding_window_layer_placing="interleaved",
+        sliding_window_indices=[1 if i % 2 == 0 else 0 for i in range(26)],
         intermediate_size=9216,
         n_embd=2304,
         n_layer=26,
@@ -1002,7 +997,7 @@ def norm_class(self) -> Type:
         block_size=8192,
         sliding_window_size=4096,
         # only layer with idx 0, 2, 4, ... have sliding window attention
-        sliding_window_layer_placing="interleaved",
+        sliding_window_indices=[1 if i % 2 == 0 else 0 for i in range(42)],
         intermediate_size=14336,
         n_embd=3584,
         n_layer=42,
@@ -1032,7 +1027,7 @@ def norm_class(self) -> Type:
         block_size=8192,
         sliding_window_size=4096,
         # only layer with idx 0, 2, 4, ... have sliding window attention
-        sliding_window_layer_placing="interleaved",
+        sliding_window_indices=[1 if i % 2 == 0 else 0 for i in range(46)],
         intermediate_size=36864,
         n_embd=4608,
         n_layer=46,
@@ -1549,7 +1544,6 @@ def norm_class(self) -> Type:
         mlp_class_name="LLaMAMLP",
         parallel_residual=False,
         sliding_window_size=2048,
-        sliding_window_layer_placing="all",
     ),
     # https://huggingface.co/microsoft/Phi-3-mini-128k-instruct/blob/main/config.json
     dict(
@@ -1567,7 +1561,6 @@ def norm_class(self) -> Type:
         mlp_class_name="LLaMAMLP",
         parallel_residual=False,
         sliding_window_size=262145,
-        sliding_window_layer_placing="all",
     ),
     # https://huggingface.co/microsoft/Phi-3.5-mini-instruct/blob/main/config.json
     dict(
@@ -1622,7 +1615,6 @@ def norm_class(self) -> Type:
         mlp_class_name="LLaMAMLP",
         parallel_residual=False,
         sliding_window_size=262145,
-        sliding_window_layer_placing="all",
     ),
 ]
 configs.extend(phi)
@@ -1649,7 +1641,6 @@ def norm_class(self) -> Type:
         mlp_class_name="LLaMAMLP",
         intermediate_size=14336,
         sliding_window_size=4096,
-        sliding_window_layer_placing="all",
     )
 )
 
@@ -1670,7 +1661,6 @@ def norm_class(self) -> Type:
         mlp_class_name="LLaMAMLP",
         intermediate_size=14336,
         sliding_window_size=4096,
-        sliding_window_layer_placing="all",
     ),
     # https://huggingface.co/mistralai/Mixtral-8x7B-v0.1/blob/main/config.json
     dict(
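For context, a minimal sketch (not part of the diff) of how the removed `sliding_window_layer_placing` options map onto the new per-layer `sliding_window_indices` list; `n_layer = 26` mirrors the 26-layer Gemma-2 config above, and the variable names here are illustrative only.

# Illustrative only: translating the old placement options into the new
# per-layer list that this change introduces.
n_layer = 26  # matches the 26-layer Gemma-2 config shown in the diff

# old "all" (and the new __post_init__ default): every layer uses sliding window attention
sliding_window_all = [1] * n_layer

# old "interleaved": only layers 0, 2, 4, ... use sliding window attention
sliding_window_interleaved = [1 if i % 2 == 0 else 0 for i in range(n_layer)]

assert sliding_window_interleaved[:4] == [1, 0, 1, 0]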