@@ -138,15 +138,19 @@ def forward(
 # and https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_rope_utils.py#L242.
 # Current only support non-long rope.
 def hf_precompute_freqs_cis(
-    dim: int, end: int, theta: float, partial_rotary_factor: float = 1.0
+    dim: int,
+    end: int,
+    theta: float,
+    partial_rotary_factor: float = 1.0,
+    device: Union[str, torch.device] = "cpu",
 ):
     # Partial rotary embeddings.
     dim = int(dim * partial_rotary_factor)

     # Short factor scaling.
     freqs = 1.0 / (
         theta
-        ** (torch.arange(0, dim, 2, device="cpu", dtype=torch.int64).float() / dim)
+        ** (torch.arange(0, dim, 2, device=device, dtype=torch.int64).float() / dim)
     )
     # TODO: support long factor scaling.

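The effect of this hunk is that the inverse-frequency table is allocated directly on the caller's device rather than always on CPU, avoiding a host-to-device copy later. A minimal standalone sketch of just the changed portion, assuming only what the hunk shows (the function and demo names below are illustrative, not the repo's API):

from typing import Union

import torch


def precompute_inv_freq_sketch(
    dim: int,
    theta: float,
    partial_rotary_factor: float = 1.0,
    device: Union[str, torch.device] = "cpu",
) -> torch.Tensor:
    # Partial rotary embeddings: only the first dim * partial_rotary_factor
    # channels participate in the rotation.
    dim = int(dim * partial_rotary_factor)
    # Frequencies are created on `device` (previously hard-coded to "cpu").
    return 1.0 / (
        theta
        ** (torch.arange(0, dim, 2, device=device, dtype=torch.int64).float() / dim)
    )


# Usage: pick the accelerator when present, otherwise stay on CPU.
dev = "cuda" if torch.cuda.is_available() else "cpu"
inv_freq = precompute_inv_freq_sketch(dim=128, theta=10000.0, device=dev)
print(inv_freq.device, inv_freq.shape)  # e.g. cpu torch.Size([64])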
@@ -236,6 +240,7 @@ def __init__(self, params: ModelArgs):
             self.precompute_freqs_cis = partial(
                 hf_precompute_freqs_cis,
                 partial_rotary_factor=self.params.partial_rotary_factor,
+                device=self.params.device,
             )
             self.apply_rotary_emb = hf_apply_rotary_emb
         else:
@@ -244,6 +249,7 @@ def __init__(self, params: ModelArgs):
                 use_scaled=self.params.use_scaled_rope,
                 scale_factor=self.params.rope_scale_factor,
                 high_freq_factor=self.params.high_freq_factor,
+                device=self.params.device,
             )
             self.apply_rotary_emb = RotaryEmbedding()

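Both branches forward the new keyword through functools.partial, so existing call sites of self.precompute_freqs_cis need no change. A small self-contained sketch of that binding pattern, assuming a ModelArgs-like object with a `device` field (the stand-in names below are hypothetical):

from dataclasses import dataclass
from functools import partial
from typing import Union

import torch


@dataclass
class FakeModelArgs:
    # Hypothetical stand-in for ModelArgs; only the field exercised here.
    device: str = "cpu"


def fake_precompute(dim: int, device: Union[str, torch.device] = "cpu") -> torch.Tensor:
    # Stand-in for hf_precompute_freqs_cis / precompute_freqs_cis.
    return 1.0 / (10000.0 ** (torch.arange(0, dim, 2, device=device).float() / dim))


params = FakeModelArgs(device="cpu")
# partial() pre-binds device once at construction time; callers keep
# calling precompute(dim) exactly as before the change.
precompute = partial(fake_precompute, device=params.device)
print(precompute(8))  # tensor of 4 inverse frequencies on params.device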