We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 41fdf13 commit c6f06abCopy full SHA for c6f06ab
examples/models/llama/rope.py
@@ -47,9 +47,10 @@ def precompute_freqs_cis(
47
use_scaled: bool = False,
48
scale_factor: Optional[int] = None,
49
high_freq_factor: int = 4,
50
+ device: torch.device = torch.device("cpu"),
51
):
52
freqs = 1.0 / (
- theta ** (torch.arange(0, dim, 2, device="cpu")[: (dim // 2)].float() / dim)
53
+ theta ** (torch.arange(0, dim, 2, device=device)[: (dim // 2)].float() / dim)
54
)
55
t = torch.arange(end, device=freqs.device) # pyre-ignore
56
if use_scaled:
0 commit comments