Skip to content

Commit d36bf8c

Browse files
authored
Run ET-eager on message recall
Differential Revision: D83990682 Pull Request resolved: #14822
1 parent c609f63 commit d36bf8c

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

examples/models/llama/rope.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,7 @@ def __init__(self, params: ModelArgs):
240240
self.precompute_freqs_cis = partial(
241241
hf_precompute_freqs_cis,
242242
partial_rotary_factor=self.params.partial_rotary_factor,
243-
device=self.params.device,
243+
device=getattr(self.params, "device", "cpu"),
244244
)
245245
self.apply_rotary_emb = hf_apply_rotary_emb
246246
else:
@@ -249,7 +249,7 @@ def __init__(self, params: ModelArgs):
249249
use_scaled=self.params.use_scaled_rope,
250250
scale_factor=self.params.rope_scale_factor,
251251
high_freq_factor=self.params.high_freq_factor,
252-
device=self.params.device,
252+
device=getattr(self.params, "device", "cpu"),
253253
)
254254
self.apply_rotary_emb = RotaryEmbedding()
255255

0 commit comments

Comments
 (0)