
Commit b04778f

navsud authored and facebook-github-bot committed
Update rope to support QAT on GPU (#14619)
Summary: Pull Request resolved: #14619

As part of enabling QAT for the HTP model, we need to run QAT on the same model that we use during export. Currently, RoPE parameter creation is explicitly hardcoded to "cpu". This change creates the RoPE params on "cuda" when running on a GPU machine.

Differential Revision: D82239525
1 parent d4f208d commit b04778f

File tree

1 file changed: +1 −1 lines changed


examples/models/llama/rope.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -47,8 +47,8 @@ def precompute_freqs_cis(
     use_scaled: bool = False,
     scale_factor: Optional[int] = None,
     high_freq_factor: int = 4,
-    device: Union[str, torch.device] = "cpu",
 ):
+    device = "cuda" if torch.cuda.is_available() else "cpu"
     freqs = 1.0 / (
         theta ** (torch.arange(0, dim, 2, device=device)[: (dim // 2)].float() / dim)
     )
```
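The effect of the change can be sketched with a minimal, self-contained version of the frequency precompute. Note this is an illustrative reduction, not the full `precompute_freqs_cis` from `rope.py` (the real signature carries the extra scaling parameters shown in the diff); the point is only that the device is now chosen at call time instead of being pinned to "cpu":

```python
import torch

def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    # Pick the device at runtime, mirroring the commit: "cuda" when a GPU
    # is available, otherwise fall back to "cpu".
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Per-dimension rotation frequencies, created directly on the chosen device.
    freqs = 1.0 / (
        theta ** (torch.arange(0, dim, 2, device=device)[: (dim // 2)].float() / dim)
    )
    t = torch.arange(end, device=device)        # position indices 0..end-1
    freqs = torch.outer(t, freqs).float()       # shape (end, dim // 2)
    # Complex rotations e^{i * theta} for each (position, dim-pair).
    return torch.polar(torch.ones_like(freqs), freqs)

freqs_cis = precompute_freqs_cis(dim=64, end=128)
print(freqs_cis.shape)  # torch.Size([128, 32])
```

Because all tensors in the function inherit `device`, downstream QAT observers and fake-quant modules can run on the GPU without an explicit `.to("cuda")` copy of the rope params.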
