
Commit 7789e82

fix: replace sliding window configuration parameters to sliding windows indices (#1995)
1 parent 333bb23 commit 7789e82

3 files changed: +12 -24 lines changed


litgpt/config.py

Lines changed: 9 additions & 19 deletions
@@ -3,7 +3,7 @@
 from copy import deepcopy
 from dataclasses import dataclass, field
 from pathlib import Path
-from typing import Any, Literal, Optional, Type, Union
+from typing import Any, List, Literal, Optional, Type, Union

 import torch
 import yaml
@@ -58,9 +58,7 @@ class Config:
     attn_bias: bool = False
     attention_scores_scalar: Optional[int] = None
     sliding_window_size: Optional[int] = None
-    sliding_window_layer_placing: Optional[Literal["all", "interleaved"]] = None
-    sliding_window_layer_stride: Optional[int] = None
-    sliding_window_offset: int = 0
+    sliding_window_indices: Optional[List] = None
     # if `attention_logit_softcapping` is used, cannot use optimized
     # `torch.nn.functional.scaled_dot_product_attention` (which implements
     # Flash attention), may result in higher memory and runtime footprint.
@@ -114,14 +112,11 @@ def __post_init__(self):

         self.rope_n_elem = int(self.rotary_percentage * self.head_size)

-        if self.sliding_window_size is not None:
-            self.sliding_window_layer_stride = (
-                (1 if (self.sliding_window_layer_placing is None or self.sliding_window_layer_placing == "all") else 2)
-                if self.sliding_window_layer_stride is None
-                else self.sliding_window_layer_stride
-            )
+        if self.sliding_window_size is not None and self.sliding_window_indices is None:
+            self.sliding_window_indices = [1] * self.n_layer

-            self.sliding_window_block_idx_map_fn = lambda x: x + self.sliding_window_offset
+        if self.rope_local_base_freq is not None and self.rope_indices is None:
+            self.rope_indices = [1] * self.n_layer

     @classmethod
     def from_name(cls, name: str, **kwargs: Any) -> Optional[Self]:
@@ -974,7 +969,7 @@ def norm_class(self) -> Type:
         block_size=8192,
         sliding_window_size=4096,
         # only layer with idx 0, 2, 4, ... have sliding window attention
-        sliding_window_layer_placing="interleaved",
+        sliding_window_indices=[1 if i % 2 == 0 else 0 for i in range(26)],
         intermediate_size=9216,
         n_embd=2304,
         n_layer=26,
@@ -1002,7 +997,7 @@ def norm_class(self) -> Type:
         block_size=8192,
         sliding_window_size=4096,
         # only layer with idx 0, 2, 4, ... have sliding window attention
-        sliding_window_layer_placing="interleaved",
+        sliding_window_indices=[1 if i % 2 == 0 else 0 for i in range(42)],
         intermediate_size=14336,
         n_embd=3584,
         n_layer=42,
@@ -1032,7 +1027,7 @@ def norm_class(self) -> Type:
         block_size=8192,
         sliding_window_size=4096,
         # only layer with idx 0, 2, 4, ... have sliding window attention
-        sliding_window_layer_placing="interleaved",
+        sliding_window_indices=[1 if i % 2 == 0 else 0 for i in range(46)],
         intermediate_size=36864,
         n_embd=4608,
         n_layer=46,
@@ -1549,7 +1544,6 @@ def norm_class(self) -> Type:
         mlp_class_name="LLaMAMLP",
         parallel_residual=False,
         sliding_window_size=2048,
-        sliding_window_layer_placing="all",
     ),
     # https://huggingface.co/microsoft/Phi-3-mini-128k-instruct/blob/main/config.json
     dict(
@@ -1567,7 +1561,6 @@ def norm_class(self) -> Type:
         mlp_class_name="LLaMAMLP",
         parallel_residual=False,
         sliding_window_size=262145,
-        sliding_window_layer_placing="all",
     ),
     # https://huggingface.co/microsoft/Phi-3.5-mini-instruct/blob/main/config.json
     dict(
@@ -1622,7 +1615,6 @@ def norm_class(self) -> Type:
         mlp_class_name="LLaMAMLP",
         parallel_residual=False,
         sliding_window_size=262145,
-        sliding_window_layer_placing="all",
     ),
 ]
 configs.extend(phi)
@@ -1649,7 +1641,6 @@ def norm_class(self) -> Type:
         mlp_class_name="LLaMAMLP",
         intermediate_size=14336,
         sliding_window_size=4096,
-        sliding_window_layer_placing="all",
     )
 )

@@ -1670,7 +1661,6 @@ def norm_class(self) -> Type:
         mlp_class_name="LLaMAMLP",
         intermediate_size=14336,
         sliding_window_size=4096,
-        sliding_window_layer_placing="all",
     ),
     # https://huggingface.co/mistralai/Mixtral-8x7B-v0.1/blob/main/config.json
     dict(
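Taken together, the config.py changes fold the old placing/stride/offset knobs into a single per-layer list, with __post_init__ defaulting it to [1] * n_layer when only sliding_window_size is set. A minimal migration sketch, assuming the Config dataclass can be constructed directly with these keyword arguments and defaults for the rest (the interleaved helper is illustrative only, not part of litgpt):

# Sketch: expressing the old "interleaved" / "all" placement with the new per-layer list.
from litgpt.config import Config

def interleaved(n_layer: int) -> list:
    # 1 = this layer uses sliding window attention, 0 = full attention,
    # mirroring the "idx 0, 2, 4, ..." pattern in the entries above.
    return [1 if i % 2 == 0 else 0 for i in range(n_layer)]

# Before this commit: Config(..., sliding_window_size=4096, sliding_window_layer_placing="interleaved")
# After this commit:
cfg = Config(
    block_size=8192,
    n_layer=26,
    sliding_window_size=4096,
    sliding_window_indices=interleaved(26),
)

# Leaving sliding_window_indices=None while sliding_window_size is set now gives
# [1] * n_layer in __post_init__, i.e. every layer slides, which is what the old
# "all" placement selected.
print(Config(sliding_window_size=2048, n_layer=4).sliding_window_indices)  # [1, 1, 1, 1]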

litgpt/model.py

Lines changed: 3 additions & 4 deletions
@@ -324,10 +324,9 @@ def __init__(self, config: Config, block_idx: int) -> None:
         self.proj = nn.Linear(config.head_size * config.n_head, config.n_embd, bias=config.bias)
         # disabled by default
         self.kv_cache: Optional[KVCache] = None
-        self.apply_sliding_window_attention = (
-            config.sliding_window_size is not None
-            and config.sliding_window_block_idx_map_fn(block_idx) % config.sliding_window_layer_stride == 0
-        )
+        self.apply_sliding_window_attention = False
+        if config.sliding_window_size is not None and config.sliding_window_indices is not None:
+            self.apply_sliding_window_attention = config.sliding_window_indices[block_idx]

         if config.norm_qk:
             self.norm_q = config.norm_class(config.head_size * config.n_head, eps=config.norm_eps)
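With this change the model.py side becomes a plain table lookup: each block reads its own entry from sliding_window_indices. A standalone sketch of that decision, using a stand-in namespace rather than litgpt's Config:

# Standalone sketch of the per-block decision added above; `cfg` is a stand-in object.
from types import SimpleNamespace

cfg = SimpleNamespace(
    n_layer=4,
    sliding_window_size=4096,
    sliding_window_indices=[1, 0, 1, 0],  # even block_idx slides, odd uses full attention
)

for block_idx in range(cfg.n_layer):
    apply_sliding_window_attention = False
    if cfg.sliding_window_size is not None and cfg.sliding_window_indices is not None:
        apply_sliding_window_attention = cfg.sliding_window_indices[block_idx]
    print(block_idx, bool(apply_sliding_window_attention))
# Prints: 0 True, 1 False, 2 True, 3 False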

tests/test_model.py

Lines changed: 0 additions & 1 deletion
@@ -420,7 +420,6 @@ def test_against_mistral_hf_models(device, dtype, model_name):
         padded_vocab_size=10000,
         block_size=T,
         sliding_window_size=T // 2,
-        sliding_window_layer_placing="all",
         n_layer=2,
         n_embd=32,
         n_head=8,
