We read every piece of feedback and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent c609f63 · commit d36bf8c — Copy full SHA for d36bf8c
examples/models/llama/rope.py
@@ -240,7 +240,7 @@ def __init__(self, params: ModelArgs):
240
self.precompute_freqs_cis = partial(
241
hf_precompute_freqs_cis,
242
partial_rotary_factor=self.params.partial_rotary_factor,
243
- device=self.params.device,
+ device=getattr(self.params, "device", "cpu"),
244
)
245
self.apply_rotary_emb = hf_apply_rotary_emb
246
else:
@@ -249,7 +249,7 @@ def __init__(self, params: ModelArgs):
249
use_scaled=self.params.use_scaled_rope,
250
scale_factor=self.params.rope_scale_factor,
251
high_freq_factor=self.params.high_freq_factor,
252
253
254
self.apply_rotary_emb = RotaryEmbedding()
255
0 commit comments