We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 2707ce3 commit ec56c95Copy full SHA for ec56c95
examples/models/llama/rope.py
@@ -9,7 +9,7 @@
9
10
import math
11
from functools import partial
12
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union
13
14
import torch
15
from executorch.examples.models.llama.model_args import ModelArgs
@@ -47,9 +47,10 @@ def precompute_freqs_cis(
47
use_scaled: bool = False,
48
scale_factor: Optional[int] = None,
49
high_freq_factor: int = 4,
50
+ device: Union[str, torch.device] = "cpu",
51
):
52
freqs = 1.0 / (
- theta ** (torch.arange(0, dim, 2, device="cpu")[: (dim // 2)].float() / dim)
53
+ theta ** (torch.arange(0, dim, 2, device=device)[: (dim // 2)].float() / dim)
54
)
55
t = torch.arange(end, device=freqs.device) # pyre-ignore
56
if use_scaled:
0 commit comments