
Commit ebfbac6

Add sliding window and optimize redundant code (#2655)
1 parent 6a08720

13 files changed: +1138 additions, -453 deletions


paddleformers/transformers/configuration_utils.py

Lines changed: 17 additions & 0 deletions
```diff
@@ -1269,3 +1269,20 @@ def get_configuration_file(configuration_files: List[str]) -> str:
                 break
 
     return configuration_file
+
+
+ALLOWED_LAYER_TYPES = (
+    "full_attention",
+    "sliding_attention",
+)
+
+
+def layer_type_validation(layer_types: List[str], num_hidden_layers: Optional[int] = None):
+    """Check that `layer_types` is correctly defined."""
+    if not all(layer_type in ALLOWED_LAYER_TYPES for layer_type in layer_types):
+        raise ValueError(f"The `layer_types` entries must be in {ALLOWED_LAYER_TYPES}")
+    if num_hidden_layers is not None and num_hidden_layers != len(layer_types):
+        raise ValueError(
+            f"`num_hidden_layers` ({num_hidden_layers}) must be equal to the number of layer types "
+            f"({len(layer_types)})"
+        )
```
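For readers skimming the diff, here is a minimal usage sketch of the new validator. The import path follows the file path shown above; the failing calls are wrapped in try/except so the sketch runs end to end.

```python
from paddleformers.transformers.configuration_utils import layer_type_validation

# Valid: every entry is an allowed layer type and the count matches.
layer_type_validation(["full_attention", "sliding_attention"], num_hidden_layers=2)

# Invalid entry: "cross_attention" is not in ALLOWED_LAYER_TYPES.
try:
    layer_type_validation(["cross_attention"])
except ValueError as err:
    print(err)  # The `layer_types` entries must be in (...)

# Length mismatch: 3 hidden layers declared but only 1 layer type given.
try:
    layer_type_validation(["full_attention"], num_hidden_layers=3)
except ValueError as err:
    print(err)  # `num_hidden_layers` (3) must be equal to the number of layer types (1)
```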

paddleformers/transformers/qwen2/configuration.py

Lines changed: 11 additions & 1 deletion
```diff
@@ -14,7 +14,7 @@
 # limitations under the License.
 """Qwen2 model configuration"""
 
-from ..configuration_utils import PretrainedConfig
+from ..configuration_utils import PretrainedConfig, layer_type_validation
 
 __all__ = [
     "Qwen2Config",
@@ -129,6 +129,7 @@ def __init__(
         attention_dropout=0.0,
         rope_scaling_factor=1.0,
         rope_scaling_type=None,
+        layer_types=None,
         pp_seg_method="layer:Qwen2DecoderLayer",
         **kwargs,
     ):
@@ -167,6 +168,14 @@ def __init__(
 
         self.pp_seg_method = pp_seg_method
 
+        self.layer_types = layer_types
+        if self.layer_types is None:
+            self.layer_types = [
+                "sliding_attention" if self.use_sliding_window and i >= self.max_window_layers else "full_attention"
+                for i in range(self.num_hidden_layers)
+            ]
+        layer_type_validation(self.layer_types, self.num_hidden_layers)
+
         super().__init__(
             pad_token_id=pad_token_id,
             bos_token_id=bos_token_id,
@@ -190,5 +199,6 @@ def __init__(
                 "pp_seg_method",
                 "dpo_config",
                 "kto_config",
+                "layer_types",
             ]
         )
```
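To make the default concrete: when `layer_types` is not passed, the comprehension above assigns sliding attention only to layers at index `max_window_layers` and beyond, and only when `use_sliding_window` is enabled. A standalone sketch of that derivation, with illustrative values that are assumptions rather than values from the commit:

```python
# Standalone reproduction of the default derivation in Qwen2Config.__init__.
# The three values below are illustrative, not taken from the diff.
use_sliding_window = True
max_window_layers = 2
num_hidden_layers = 4

layer_types = [
    "sliding_attention" if use_sliding_window and i >= max_window_layers else "full_attention"
    for i in range(num_hidden_layers)
]
print(layer_types)
# ['full_attention', 'full_attention', 'sliding_attention', 'sliding_attention']
```

Note that with `use_sliding_window=False`, or with `max_window_layers >= num_hidden_layers`, every entry stays `"full_attention"`, so configs that never opted into sliding windows are unaffected by the new default.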
