
Commit 6bae1ae

navsud authored and facebook-github-bot committed
Enable QAT for static llama definition (pytorch#13285)
Summary: The model needed small modifications to be able to run QAT on GPUs.

Reviewed By: YIWENX14

Differential Revision: D79841467
1 parent 41fdf13 commit 6bae1ae

File tree

1 file changed: +2 −1 lines


examples/models/llama/rope.py

Lines changed: 2 additions & 1 deletion
@@ -47,9 +47,10 @@ def precompute_freqs_cis(
     use_scaled: bool = False,
     scale_factor: Optional[int] = None,
     high_freq_factor: int = 4,
+    device: torch.device = torch.device("cpu"),
 ):
     freqs = 1.0 / (
-        theta ** (torch.arange(0, dim, 2, device="cpu")[: (dim // 2)].float() / dim)
+        theta ** (torch.arange(0, dim, 2, device=device)[: (dim // 2)].float() / dim)
     )
     t = torch.arange(end, device=freqs.device)  # pyre-ignore
     if use_scaled:
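For context, a minimal sketch of how the new device argument might be exercised when building the RoPE tables on an accelerator for QAT. The dim, end, and theta values below are illustrative assumptions, not values from the commit, and the import assumes the repo root is on the Python path:

# Sketch only: precompute the RoPE frequency table directly on the GPU so
# QAT training does not mix CPU-resident and GPU-resident tensors.
import torch
from examples.models.llama.rope import precompute_freqs_cis

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

rope_tables = precompute_freqs_cis(  # return structure is not shown in the diff
    dim=128,        # head dimension (assumed)
    end=2048,       # maximum sequence length (assumed)
    theta=10000.0,  # RoPE base (assumed)
    device=device,  # new keyword added by this commit; defaults to CPU
)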
