diff --git a/examples/models/llama/rope.py b/examples/models/llama/rope.py index f788b8f5032..8c0d5db6a80 100644 --- a/examples/models/llama/rope.py +++ b/examples/models/llama/rope.py @@ -9,7 +9,7 @@ import math from functools import partial -from typing import Optional, Tuple +from typing import Optional, Tuple, Union import torch from executorch.examples.models.llama.model_args import ModelArgs @@ -47,9 +47,10 @@ def precompute_freqs_cis( use_scaled: bool = False, scale_factor: Optional[int] = None, high_freq_factor: int = 4, + device: Union[str, torch.device] = "cpu", ): freqs = 1.0 / ( - theta ** (torch.arange(0, dim, 2, device="cpu")[: (dim // 2)].float() / dim) + theta ** (torch.arange(0, dim, 2, device=device)[: (dim // 2)].float() / dim) ) t = torch.arange(end, device=freqs.device) # pyre-ignore if use_scaled: