Commit c997fe4

Remove explicit device arguments
Differential Revision: D82239525
Pull Request resolved: #14619
1 parent 53ccfd0 commit c997fe4

File tree

examples/models/llama/model_args.py
examples/models/llama/rope.py

2 files changed: +11, −2 lines


examples/models/llama/model_args.py

Lines changed: 3 additions & 0 deletions
@@ -63,6 +63,9 @@ class ModelArgs:
     use_sdpa_with_kv_cache_op: bool = (
         False  # Use custom sdpa op that updates kv cache in-place
     )
+    # Device to use for the model: "cpu" or "cuda" (needed for QAT)
+    # Only used for creating Rope parameters
+    device: str = "cpu"
     # Generate logits for all inputs. When it's True, it would take big memory usage
     # at runtime. Enable it only necessary (e.g., use perplexity tools that requires
     # logits for all input tokens.)
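For illustration, a minimal sketch of how the new field is meant to be used. The dataclass below is a hypothetical stand-in reduced to the added field, not the real `ModelArgs`:

```python
from dataclasses import dataclass

import torch


@dataclass
class ModelArgs:
    # Stand-in for the field added above: device to use for the model,
    # "cpu" or "cuda" (needed for QAT). Only used for creating Rope
    # parameters.
    device: str = "cpu"


# For QAT on a GPU host, RoPE parameters can now be created on CUDA
# instead of being hardcoded to CPU.
args = ModelArgs(device="cuda" if torch.cuda.is_available() else "cpu")
print(args.device)
```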

examples/models/llama/rope.py

Lines changed: 8 additions & 2 deletions
@@ -138,15 +138,19 @@ def forward(
 # and https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_rope_utils.py#L242.
 # Current only support non-long rope.
 def hf_precompute_freqs_cis(
-    dim: int, end: int, theta: float, partial_rotary_factor: float = 1.0
+    dim: int,
+    end: int,
+    theta: float,
+    partial_rotary_factor: float = 1.0,
+    device: Union[str, torch.device] = "cpu",
 ):
     # Partial rotary embeddings.
     dim = int(dim * partial_rotary_factor)

     # Short factor scaling.
     freqs = 1.0 / (
         theta
-        ** (torch.arange(0, dim, 2, device="cpu", dtype=torch.int64).float() / dim)
+        ** (torch.arange(0, dim, 2, device=device, dtype=torch.int64).float() / dim)
     )
     # TODO: support long factor scaling.
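To see what the parameterized device changes, here is the frequency computation from the hunk above as a standalone sketch. The helper name `rope_inv_freqs` and the default theta of 10000.0 are assumptions for illustration:

```python
from typing import Union

import torch


def rope_inv_freqs(
    dim: int,
    theta: float = 10000.0,
    device: Union[str, torch.device] = "cpu",
) -> torch.Tensor:
    # Same expression as in the diff: RoPE inverse frequencies, now
    # allocated on the caller's device instead of a hardcoded "cpu".
    return 1.0 / (
        theta
        ** (torch.arange(0, dim, 2, device=device, dtype=torch.int64).float() / dim)
    )


# The values are identical across devices; only the allocation moves,
# which is what QAT on a CUDA host needs.
freqs = rope_inv_freqs(64, device="cuda" if torch.cuda.is_available() else "cpu")
assert freqs.shape == (32,)
```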
@@ -236,6 +240,7 @@ def __init__(self, params: ModelArgs):
             self.precompute_freqs_cis = partial(
                 hf_precompute_freqs_cis,
                 partial_rotary_factor=self.params.partial_rotary_factor,
+                device=self.params.device,
             )
             self.apply_rotary_emb = hf_apply_rotary_emb
         else:
@@ -244,6 +249,7 @@ def __init__(self, params: ModelArgs):
                 use_scaled=self.params.use_scaled_rope,
                 scale_factor=self.params.rope_scale_factor,
                 high_freq_factor=self.params.high_freq_factor,
+                device=self.params.device,
             )
             self.apply_rotary_emb = RotaryEmbedding()
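Both hunks above use the same pattern: bind `params.device` into the precompute function once, at construction time, so callers never pass it explicitly. A minimal sketch of that pattern follows; the `Rope` module and `precompute_freqs_cis` here are simplified stand-ins, not the real executorch implementations:

```python
from functools import partial
from typing import Union

import torch


def precompute_freqs_cis(
    dim: int,
    end: int,
    theta: float = 10000.0,
    device: Union[str, torch.device] = "cpu",
) -> torch.Tensor:
    # Stand-in helper: build the complex RoPE table directly on `device`.
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, device=device).float() / dim))
    angles = torch.outer(torch.arange(end, device=device).float(), freqs)
    return torch.polar(torch.ones_like(angles), angles)


class Rope(torch.nn.Module):
    def __init__(self, device: str = "cpu"):
        super().__init__()
        # Bind the device once; every later call inherits it, mirroring
        # how the diff binds params.device into the partial.
        self.precompute_freqs_cis = partial(precompute_freqs_cis, device=device)


rope = Rope(device="cpu")
table = rope.precompute_freqs_cis(64, end=128)  # complex table on the bound device
```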