
Commit 322bd20
feat: replace sliding window type with offset (#1989)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 8ed90f9 · commit 322bd20

File tree

1 file changed (+3, -6 lines)

litgpt/config.py

Lines changed: 3 additions & 6 deletions
@@ -3,7 +3,7 @@
 from copy import deepcopy
 from dataclasses import dataclass, field
 from pathlib import Path
-from typing import Any, Callable, Literal, Optional, Type, Union
+from typing import Any, Literal, Optional, Type, Union
 
 import torch
 import yaml
@@ -60,7 +60,7 @@ class Config:
     sliding_window_size: Optional[int] = None
     sliding_window_layer_placing: Optional[Literal["all", "interleaved"]] = None
     sliding_window_layer_stride: Optional[int] = None
-    sliding_window_type: Optional[Literal["gemma3"]] = None
+    sliding_window_offset: int = 0
     # if `attention_logit_softcapping` is used, cannot use optimized
     # `torch.nn.functional.scaled_dot_product_attention` (which implements
     # Flash attention), may result in higher memory and runtime footprint.
@@ -118,10 +118,7 @@ def __post_init__(self):
             else self.sliding_window_layer_stride
         )
 
-        SLIDING_WINDOW_TYPE_TO_MAP_FN: dict[Literal["gemma3"], Callable[[int], int]] = {"gemma3": lambda x: x + 1}
-        self.sliding_window_block_idx_map_fn = (
-            lambda x: x if self.sliding_window_type is None else SLIDING_WINDOW_TYPE_TO_MAP_FN[self.sliding_window_type]
-        )
+        self.sliding_window_block_idx_map_fn = lambda x: x + self.sliding_window_offset
 
     @classmethod
     def from_name(cls, name: str, **kwargs: Any) -> Optional[Self]:
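
The net effect: the Gemma-3-specific lookup table becomes a plain integer shift. Setting `sliding_window_offset=1` reproduces the old `sliding_window_type="gemma3"` mapping (`lambda x: x + 1`), and the default `0` keeps identity behavior. Below is a minimal sketch of how the mapped index could drive a per-layer sliding-window decision; `uses_sliding_window` and its stride-modulo check are illustrative assumptions, not litgpt's actual model code:

from typing import Callable

def make_block_idx_map_fn(offset: int) -> Callable[[int], int]:
    # Same shape as the new Config.__post_init__ line: shift every
    # block index by a fixed integer offset.
    return lambda x: x + offset

def uses_sliding_window(block_idx: int, stride: int, map_fn: Callable[[int], int]) -> bool:
    # Hypothetical helper for illustration only: give a layer
    # sliding-window attention unless its mapped index lands on a
    # stride boundary (Gemma-3-style interleaving, with one global
    # layer every `stride` layers).
    return map_fn(block_idx) % stride != 0

# Old: sliding_window_type="gemma3"  ->  mapping x -> x + 1
# New: sliding_window_offset=1       ->  the same mapping
gemma3_map = make_block_idx_map_fn(offset=1)
print([uses_sliding_window(i, stride=6, map_fn=gemma3_map) for i in range(12)])
# [True, True, True, True, True, False, True, True, True, True, True, False]

The single expression also avoids an operator-precedence subtlety in the removed lines: the conditional there sits inside the lambda body, so when `sliding_window_type` was set, calling the map function returned the dict entry (itself a function) rather than `x + 1`.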

0 commit comments