@@ -138,15 +138,19 @@ def forward(
 # and https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_rope_utils.py#L242.
 # Current only support non-long rope.
 def hf_precompute_freqs_cis(
-    dim: int, end: int, theta: float, partial_rotary_factor: float = 1.0
+    dim: int,
+    end: int,
+    theta: float,
+    partial_rotary_factor: float = 1.0,
+    device: Union[str, torch.device] = "cpu",
 ):
     # Partial rotary embeddings.
     dim = int(dim * partial_rotary_factor)

     # Short factor scaling.
     freqs = 1.0 / (
         theta
-        ** (torch.arange(0, dim, 2, device="cpu", dtype=torch.int64).float() / dim)
+        ** (torch.arange(0, dim, 2, device=device, dtype=torch.int64).float() / dim)
     )
     # TODO: support long factor scaling.

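The hunk above threads a device argument through to torch.arange instead of hard-coding "cpu", so the inverse-frequency table is created directly on the target device. A minimal self-contained sketch of that computation; the helper name precompute_inv_freq and the default theta are illustrative, not part of the patch:

from typing import Union

import torch

def precompute_inv_freq(
    dim: int,
    theta: float = 10000.0,
    partial_rotary_factor: float = 1.0,
    device: Union[str, torch.device] = "cpu",
) -> torch.Tensor:
    # Partial rotary embeddings: only a fraction of the head dim rotates.
    dim = int(dim * partial_rotary_factor)
    # The arange lands on the requested device, which is the behavioral
    # change this hunk makes (previously it was always built on CPU).
    return 1.0 / (
        theta
        ** (torch.arange(0, dim, 2, device=device, dtype=torch.int64).float() / dim)
    )

inv_freq = precompute_inv_freq(dim=128, device="cpu")  # shape: (64,)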
@@ -236,6 +240,7 @@ def __init__(self, params: ModelArgs):
             self.precompute_freqs_cis = partial(
                 hf_precompute_freqs_cis,
                 partial_rotary_factor=self.params.partial_rotary_factor,
+                device=self.params.device,
             )
             self.apply_rotary_emb = hf_apply_rotary_emb
         else:
@@ -244,6 +249,7 @@ def __init__(self, params: ModelArgs):
                 use_scaled=self.params.use_scaled_rope,
                 scale_factor=self.params.rope_scale_factor,
                 high_freq_factor=self.params.high_freq_factor,
+                device=self.params.device,
             )
             self.apply_rotary_emb = RotaryEmbedding()

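Both constructor hunks bind the new keyword once via functools.partial, so every later call to self.precompute_freqs_cis inherits the model's device without the call sites changing. A hedged sketch of that binding pattern, reusing the illustrative precompute_inv_freq helper from above in place of the real precompute functions:

from functools import partial

# Bind model-specific keywords once; callers then pass only the remaining
# positional arguments. device="cpu" stands in for the diff's
# self.params.device.
precompute = partial(
    precompute_inv_freq,
    partial_rotary_factor=1.0,
    device="cpu",
)
inv_freq = precompute(128)  # same as precompute_inv_freq(128, device="cpu")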