
Commit 33381ea

navsud authored and facebook-github-bot committed
Update rope to support QAT on GPU (pytorch#14619)
Summary: As part of enabling QAT for the HTP model, we need to run QAT on the same model that we use during export. Currently, Rope is explicitly hardcoded to "cpu"; this change lets us switch between "cpu" and "cuda" based on the use case.

Reviewed By: billmguo

Differential Revision: D82239525
1 parent 53ccfd0 commit 33381ea

File tree

2 files changed: +11 lines, −2 lines


examples/models/llama/model_args.py

Lines changed: 3 additions & 0 deletions
@@ -63,6 +63,9 @@ class ModelArgs:
     use_sdpa_with_kv_cache_op: bool = (
         False  # Use custom sdpa op that updates kv cache in-place
     )
+    # Device to use for the model: "cpu" or "cuda" (needed for QAT)
+    # Only used for creating Rope parameters
+    device: str = "cpu"
     # Generate logits for all inputs. When it's True, it would take big memory usage
     # at runtime. Enable it only necessary (e.g., use perplexity tools that requires
     # logits for all input tokens.)
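To illustrate how the new field might be used, here is a minimal sketch of selecting the Rope device when constructing ModelArgs for QAT on GPU; the import path and the torch.cuda.is_available() check are illustrative assumptions, not part of this commit.

# Hypothetical usage of the new ModelArgs.device field (illustrative only).
import torch

from examples.models.llama.model_args import ModelArgs  # assumed import path

# Build Rope parameters on the GPU when running QAT, otherwise stay on CPU.
rope_device = "cuda" if torch.cuda.is_available() else "cpu"
args = ModelArgs(device=rope_device)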

examples/models/llama/rope.py

Lines changed: 8 additions & 2 deletions
@@ -138,15 +138,19 @@ def forward(
 # and https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_rope_utils.py#L242.
 # Current only support non-long rope.
 def hf_precompute_freqs_cis(
-    dim: int, end: int, theta: float, partial_rotary_factor: float = 1.0
+    dim: int,
+    end: int,
+    theta: float,
+    partial_rotary_factor: float = 1.0,
+    device: Union[str, torch.device] = "cpu",
 ):
     # Partial rotary embeddings.
     dim = int(dim * partial_rotary_factor)
 
     # Short factor scaling.
     freqs = 1.0 / (
         theta
-        ** (torch.arange(0, dim, 2, device="cpu", dtype=torch.int64).float() / dim)
+        ** (torch.arange(0, dim, 2, device=device, dtype=torch.int64).float() / dim)
     )
     # TODO: support long factor scaling.
 
@@ -236,6 +240,7 @@ def __init__(self, params: ModelArgs):
             self.precompute_freqs_cis = partial(
                 hf_precompute_freqs_cis,
                 partial_rotary_factor=self.params.partial_rotary_factor,
+                device=self.params.device,
             )
             self.apply_rotary_emb = hf_apply_rotary_emb
         else:
@@ -244,6 +249,7 @@ def __init__(self, params: ModelArgs):
                 use_scaled=self.params.use_scaled_rope,
                 scale_factor=self.params.rope_scale_factor,
                 high_freq_factor=self.params.high_freq_factor,
+                device=self.params.device,
             )
             self.apply_rotary_emb = RotaryEmbedding()
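For context, here is a minimal sketch of calling the updated hf_precompute_freqs_cis with an explicit device; the import path and the head_dim/max_seq_len/rope_theta values are illustrative assumptions, not taken from this diff.

# Hypothetical call into the updated function (illustrative only).
from examples.models.llama.rope import hf_precompute_freqs_cis  # assumed import path

head_dim, max_seq_len, rope_theta = 64, 2048, 10000.0  # illustrative values

# With this change the rotary tables can be built directly on "cuda" for QAT,
# instead of always landing on the previously hardcoded "cpu".
freqs = hf_precompute_freqs_cis(
    head_dim,
    max_seq_len,
    rope_theta,
    partial_rotary_factor=1.0,
    device="cuda",
)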

0 commit comments

Comments (0)