 from copy import deepcopy
 from dataclasses import dataclass, field
 from pathlib import Path
-from typing import Any, Callable, Literal, Optional, Type, Union
+from typing import Any, Literal, Optional, Type, Union
 
 import torch
 import yaml
@@ -60,7 +60,7 @@ class Config:
     sliding_window_size: Optional[int] = None
     sliding_window_layer_placing: Optional[Literal["all", "interleaved"]] = None
     sliding_window_layer_stride: Optional[int] = None
-    sliding_window_type: Optional[Literal["gemma3"]] = None
+    sliding_window_offset: int = 0
     # if `attention_logit_softcapping` is used, cannot use optimized
     # `torch.nn.functional.scaled_dot_product_attention` (which implements
     # Flash attention), may result in higher memory and runtime footprint.
@@ -118,10 +118,7 @@ def __post_init__(self):
             else self.sliding_window_layer_stride
         )
 
-        SLIDING_WINDOW_TYPE_TO_MAP_FN: dict[Literal["gemma3"], Callable[[int], int]] = {"gemma3": lambda x: x + 1}
-        self.sliding_window_block_idx_map_fn = (
-            lambda x: x if self.sliding_window_type is None else SLIDING_WINDOW_TYPE_TO_MAP_FN[self.sliding_window_type]
-        )
+        self.sliding_window_block_idx_map_fn = lambda x: x + self.sliding_window_offset
 
     @classmethod
     def from_name(cls, name: str, **kwargs: Any) -> Optional[Self]:
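For context on the behavioral change: the removed `gemma3` branch hard-coded `lambda x: x + 1`, while the new `sliding_window_offset` field generalizes this to an arbitrary integer shift, so `sliding_window_offset=1` reproduces the old Gemma 3 mapping and the default `0` keeps the previous identity behavior. Below is a minimal sketch of that equivalence; the `uses_sliding_window` helper and its stride rule are illustrative assumptions about how the map function might be consumed, not code from this diff.

# Illustrative sketch only: `uses_sliding_window` and its stride rule are
# assumptions for demonstration, not the library's actual consumer of the map fn.
from typing import Callable


def make_map_fn(sliding_window_offset: int = 0) -> Callable[[int], int]:
    # After this change: a plain index shift instead of a lookup keyed by
    # `sliding_window_type` ("gemma3" previously meant `lambda x: x + 1`).
    return lambda block_idx: block_idx + sliding_window_offset


def uses_sliding_window(block_idx: int, stride: int, map_fn: Callable[[int], int]) -> bool:
    # Assumed consumer: every `stride`-th mapped index gets full attention,
    # all other blocks use sliding-window attention.
    return map_fn(block_idx) % stride != 0


gemma3_map = make_map_fn(sliding_window_offset=1)   # old sliding_window_type="gemma3"
default_map = make_map_fn()                         # old sliding_window_type=None

print([uses_sliding_window(i, stride=6, map_fn=gemma3_map) for i in range(12)])
# [True, True, True, True, True, False, True, True, True, True, True, False]
print([uses_sliding_window(i, stride=6, map_fn=default_map) for i in range(12)])
# [False, True, True, True, True, True, False, True, True, True, True, True]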