Commit ec56c95

Enable QAT for static llama definition
Differential Revision: D79841467
Pull Request resolved: #13285

1 parent: 2707ce3


examples/models/llama/rope.py

Lines changed: 3 additions & 2 deletions

@@ -9,7 +9,7 @@
 
 import math
 from functools import partial
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union
 
 import torch
 from executorch.examples.models.llama.model_args import ModelArgs
@@ -47,9 +47,10 @@ def precompute_freqs_cis(
     use_scaled: bool = False,
     scale_factor: Optional[int] = None,
     high_freq_factor: int = 4,
+    device: Union[str, torch.device] = "cpu",
 ):
     freqs = 1.0 / (
-        theta ** (torch.arange(0, dim, 2, device="cpu")[: (dim // 2)].float() / dim)
+        theta ** (torch.arange(0, dim, 2, device=device)[: (dim // 2)].float() / dim)
     )
     t = torch.arange(end, device=freqs.device)  # pyre-ignore
     if use_scaled:
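
The practical effect of the change is that callers can now materialize the RoPE frequency table on whatever device they are preparing the model on, rather than always on CPU. Below is a minimal, self-contained sketch of that pattern. The leading parameters `dim`, `end`, and `theta` are inferred from the function body shown in the diff; the outer-product/cos/sin tail is the standard RoPE precompute, not a verbatim copy of this file's implementation, and the return shape is an assumption.

from typing import Tuple, Union

import torch


def precompute_freqs_cis_sketch(
    dim: int,
    end: int,
    theta: float = 10000.0,
    device: Union[str, torch.device] = "cpu",
) -> Tuple[torch.Tensor, torch.Tensor]:
    # The commit's change: inverse frequencies are created on the
    # caller-supplied device instead of being hardcoded to CPU.
    freqs = 1.0 / (
        theta ** (torch.arange(0, dim, 2, device=device)[: (dim // 2)].float() / dim)
    )
    # Positions inherit the device of `freqs`, exactly as in the diff.
    t = torch.arange(end, device=freqs.device)
    # Standard RoPE precompute tail (an assumption; the real function's
    # return values may differ): angle table -> cos/sin lookup tables.
    angles = torch.outer(t, freqs).float()
    return torch.cos(angles), torch.sin(angles)


# Hypothetical usage: build the tables on the same device the QAT-prepared
# model lives on, so no CPU-pinned buffers leak into the captured graph.
cos_table, sin_table = precompute_freqs_cis_sketch(dim=128, end=2048, device="cpu")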
