
Commit 37ee087

navsud authored and facebook-github-bot committed
Update rope to support QAT on GPU (#14619)
Summary: As part of enabling QAT for the HTP model, we need to run QAT on the same model that we use during export. Currently, Rope is explicitly hardcoded to "cpu"; this change creates the rope params on "cuda" when the model is run on a GPU machine.

Differential Revision: D82239525
1 parent d4f208d

File tree: 1 file changed (+2, -2)


examples/models/llama/rope.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -9,7 +9,7 @@
 
 import math
 from functools import partial
-from typing import Optional, Tuple, Union
+from typing import Optional, Tuple
 
 import torch
 from executorch.examples.models.llama.model_args import ModelArgs
@@ -47,8 +47,8 @@ def precompute_freqs_cis(
     use_scaled: bool = False,
     scale_factor: Optional[int] = None,
     high_freq_factor: int = 4,
-    device: Union[str, torch.device] = "cpu",
 ):
+    device = "cuda" if torch.cuda.is_available() else "cpu"
     freqs = 1.0 / (
         theta ** (torch.arange(0, dim, 2, device=device)[: (dim // 2)].float() / dim)
     )
```
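To make the effect of the change concrete, here is a minimal, self-contained sketch of the pattern after the patch. The signature is simplified for illustration: the real `precompute_freqs_cis` also takes `use_scaled`, `scale_factor`, and `high_freq_factor`, and the rest of the body shown is an assumption about the standard rope precompute, not a copy of the file. The key point is that the device is now resolved at call time from `torch.cuda.is_available()` rather than passed in through a hardcoded `device="cpu"` parameter:

```python
import torch


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    # Resolve the device at call time instead of accepting a hardcoded
    # device="cpu" parameter, so the same export path creates the rope
    # params on "cuda" when run on a GPU machine (e.g. for QAT).
    device = "cuda" if torch.cuda.is_available() else "cpu"
    freqs = 1.0 / (
        theta ** (torch.arange(0, dim, 2, device=device)[: (dim // 2)].float() / dim)
    )
    t = torch.arange(end, device=freqs.device)  # positions 0 .. end-1
    freqs = torch.outer(t, freqs).float()       # (end, dim // 2) angle table
    return torch.cos(freqs), torch.sin(freqs)   # cos/sin rope tables


# Hypothetical usage: the tensors land on GPU automatically when available.
freqs_cos, freqs_sin = precompute_freqs_cis(dim=128, end=2048)
print(freqs_cos.device)  # cuda:0 on a GPU machine, otherwise cpu
```

One consequence of removing the `device` parameter is that callers can no longer pin the rope params to a specific device; placement is decided solely by CUDA availability at call time.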
