Skip to content

Commit 52ddda3

Browse files
evian
authored and committed
[Minor][Models] Pass partial_rotary_factor parameter to rope (vllm-project#17266)
Signed-off-by: evian <[email protected]> Co-authored-by: evian <[email protected]> Signed-off-by: Agata Dobrzyniewicz <[email protected]>
1 parent 3dcb749 commit 52ddda3

File tree

3 files changed

+10
-8
lines changed

3 files changed

+10
-8
lines changed

vllm/model_executor/models/llama.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -130,8 +130,8 @@ def __init__(self,
130130
self.head_dim = getattr(config, "head_dim",
131131
self.hidden_size // self.total_num_heads)
132132
# Phi models introduced a partial_rotary_factor parameter in the config
133-
partial_rotary_factor = getattr(config, "partial_rotary_factor", 1)
134-
self.rotary_dim = int(partial_rotary_factor * self.head_dim)
133+
self.partial_rotary_factor = getattr(config, "partial_rotary_factor",
134+
1)
135135
self.q_size = self.num_heads * self.head_dim
136136
self.kv_size = self.num_kv_heads * self.head_dim
137137
self.scaling = self.head_dim**-0.5
@@ -163,11 +163,12 @@ def __init__(self,
163163

164164
self.rotary_emb = get_rope(
165165
self.head_dim,
166-
rotary_dim=self.rotary_dim,
166+
rotary_dim=self.head_dim,
167167
max_position=max_position_embeddings,
168168
base=rope_theta,
169169
rope_scaling=rope_scaling,
170170
is_neox_style=is_neox_style,
171+
partial_rotary_factor=self.partial_rotary_factor,
171172
)
172173

173174
if hasattr(config, "interleaved_sliding_window"):

vllm/model_executor/models/persimmon.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,9 +115,10 @@ def __init__(self,
115115

116116
self.rotary_emb = get_rope(
117117
self.head_dim,
118-
rotary_dim=int(self.partial_rotary_factor * self.head_dim),
118+
rotary_dim=self.head_dim,
119119
max_position=self.max_position_embeddings,
120120
base=self.rope_theta,
121+
partial_rotary_factor=self.partial_rotary_factor,
121122
)
122123
self.scaling = self.head_dim**-0.5
123124
self.attn = Attention(self.num_heads,

vllm/model_executor/models/stablelm.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -104,9 +104,8 @@ def __init__(self,
104104
1, self.total_num_key_value_heads // tp_size)
105105
self.head_dim = self.hidden_size // self.total_num_heads
106106
self.max_position_embeddings = config.max_position_embeddings
107-
rope_pct = getattr(config, "rope_pct",
108-
getattr(config, "partial_rotary_factor", 1))
109-
self.rotary_ndims = int(self.head_dim * rope_pct)
107+
self.partial_rotary_factor = getattr(
108+
config, "rope_pct", getattr(config, "partial_rotary_factor", 1))
110109
self.scaling = self.head_dim**-0.5
111110
self.q_size = self.num_heads * self.head_dim
112111
self.kv_size = self.num_key_value_heads * self.head_dim
@@ -130,9 +129,10 @@ def __init__(self,
130129
prefix=f"{prefix}.o_proj")
131130
self.rotary_emb = get_rope(
132131
self.head_dim,
133-
rotary_dim=self.rotary_ndims,
132+
rotary_dim=self.head_dim,
134133
max_position=self.config.max_position_embeddings,
135134
base=self.config.rope_theta,
135+
partial_rotary_factor=self.partial_rotary_factor,
136136
)
137137
self.attn = Attention(self.num_heads,
138138
self.head_dim,

0 commit comments

Comments (0)