
Commit 4f3d12e

Add no_rope_layer_interval to config
1 parent d38eb08 commit 4f3d12e

File tree

4 files changed: +6 −9 lines

examples/models/llama/model_args.py

Lines changed: 3 additions & 0 deletions
@@ -78,6 +78,9 @@ class ModelArgs:
     use_qk_norm: bool = False  # apply normalization to q and k in the attention
     qk_norm_before_rope: bool = False  # when to apply qk norm
     use_hf_rope: bool = False  # Use HuggingFace's RoPE implementation
+    no_rope_layer_interval: Optional[int] = (
+        None  # Interval at which to skip RoPE. From Rope to Nope and Back Again: A New Hybrid Attention Strategy (https://huggingface.co/papers/2501.18795).
+    )
     partial_rotary_factor: float = 1.0
     rope_theta: Optional[float] = (
         None  # The official name to override self.rope_freq_base.

examples/models/smollm3/3b_config.json

Lines changed: 1 addition & 0 deletions
@@ -10,5 +10,6 @@
   "use_scaled_rope": false,
   "vocab_size": 128256,
   "use_hf_rope": false,
+  "no_rope_layer_interval": 4,
   "attention_qkv_bias": false
 }

examples/qualcomm/oss_scripts/llama/llama.py

Lines changed: 0 additions & 7 deletions
@@ -445,13 +445,6 @@ def compile(
     else:
         kv_config.enable_masked_softmax = True
 
-    if args.decoder_model == "smollm3-3b":
-        from transformers import AutoConfig
-
-        kv_config.apply_rope_layers = AutoConfig.from_pretrained(
-            decoder_model_config.repo_id
-        ).no_rope_layers
-
     prefill_config = copy.copy(kv_config)
     prefill_config.use_kv_cache = (
         False if args.max_seq_len == args.prefill_ar_len else True
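
The block removed above pulled the per-layer no_rope_layers mask from the Hugging Face SmolLM3 config inside compile(); with no_rope_layer_interval available in ModelArgs, the same mask can be derived locally and the transformers dependency dropped. A minimal sketch of that derivation (assuming the HF list uses 1 for "apply RoPE" and 0 for "skip", which is what the old apply_rope_layers[layer_idx] truthiness check implied; rope_mask is a hypothetical helper, not part of this commit):

def rope_mask(n_layers: int, interval: int | None) -> list[int]:
    # 1 = apply RoPE in this layer, 0 = skip it (NoPE layer).
    if not interval:
        return [1] * n_layers  # no interval configured: RoPE in every layer
    return [int((i + 1) % interval != 0) for i in range(n_layers)]

# With the SmolLM3-3B value of 4, every fourth layer becomes a NoPE layer.
print(rope_mask(8, 4))  # [1, 1, 1, 0, 1, 1, 1, 0]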

examples/qualcomm/oss_scripts/llama/model/static_llama.py

Lines changed: 2 additions & 2 deletions
@@ -75,9 +75,9 @@ def __init__(self, layer_idx: int, config: ModelArgs, output_new_cache_only=False
         self.enable_masked_softmax = getattr(config, "enable_masked_softmax", False)
         self.use_qk_norm = config.use_qk_norm
         self.qk_norm_before_rope = config.qk_norm_before_rope
-        apply_rope_layers = getattr(config, "apply_rope_layers", None)
         self.use_rope = (
-            apply_rope_layers[layer_idx] if apply_rope_layers is not None else True
+            config.no_rope_layer_interval
+            and (layer_idx + 1) % config.no_rope_layer_interval
         )
 
         if self.use_qk_norm:
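
Note that the rewritten expression stores the raw result of the and/modulo chain, so self.use_rope is an int when an interval is configured and None when no_rope_layer_interval is left at its default; None is falsy, so models that previously relied on the removed `else True` fallback need the interval set (or a guard) to keep RoPE on every layer. A bool-normalized sketch with that default made explicit (illustrative only, not the committed code):

def layer_uses_rope(layer_idx: int, no_rope_layer_interval: int | None) -> bool:
    # Explicit fallback: with no interval configured, keep RoPE in every layer,
    # matching the removed `else True` branch.
    if not no_rope_layer_interval:
        return True
    # 1-indexed layer positions that are multiples of the interval skip RoPE.
    return (layer_idx + 1) % no_rope_layer_interval != 0

# interval = 4: the 4th, 8th, ... layers (indices 3, 7, ...) skip RoPE.
assert [layer_uses_rope(i, 4) for i in range(8)] == [True, True, True, False] * 2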
