
Commit 9c5ccc2

navsud authored and facebook-github-bot committed
Remove explicit device arguments
Summary:
As part of enabling QAT for the HTP model, we need to run QAT on the same model that we use during export. With the device type hardcoded to "cpu", moving the model to "cuda" requires many model changes. The simpler solution is to remove the explicit device arguments and let the device be auto-inferred during export. For training, we already build the model inside a `with torch.device("cuda"):` context, which takes care of placement.

Update: This was failing multiple export tests, because Llama2Model (in llama/model.py) was instantiating the transformer on the "meta" device, which required the rope params to be explicitly instantiated on the "cpu" device. Changed "meta" to "cpu" to fix this.

Differential Revision: D82239525
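For context, here is a minimal sketch (not part of the commit) of the mechanism the summary relies on: PyTorch factory functions that omit `device=` allocate on the ambient default device, and `torch.device(...)` works as a context manager that sets it. The `make_causal_mask` helper is hypothetical, standing in for the buffer construction in attention.py.

```python
import torch

def make_causal_mask(max_context_len: int) -> torch.Tensor:
    # Hypothetical stand-in for the attention.py buffer setup. No device=
    # argument: the mask is created on whatever the ambient device is.
    return torch.tril(
        torch.ones(max_context_len, max_context_len, dtype=torch.bool)
    )

mask = make_causal_mask(8)
print(mask.device)  # cpu -- the default device

if torch.cuda.is_available():
    # The torch.device context manager redirects factory calls, so the same
    # unmodified code builds a CUDA model for QAT.
    with torch.device("cuda"):
        cuda_mask = make_causal_mask(8)
    print(cuda_mask.device)  # cuda:0
```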
Parent: 2283294

3 files changed (+3, -5 lines)


examples/models/llama/attention.py

Lines changed: 0 additions & 1 deletion

@@ -429,7 +429,6 @@ def __init__(
                 self.max_context_len,
                 self.max_context_len,
                 dtype=torch.bool,
-                device="cpu",
             )
         )
         self.register_buffer("mask", causal_mask, persistent=False)

examples/models/llama/model.py

Lines changed: 1 addition & 1 deletion

@@ -195,7 +195,7 @@ def __init__(self, llm_config: Optional[LlmConfig] = None):
 
         # Within the device="meta" context, tensors that are created do not carry data.
         # They possess all other metadata a tensor carries such as size, stride, requires_grad.
-        with torch.device("meta"):
+        with torch.device("cpu"):
             # Model itself is loaded in default dtype, fp32.
             self.model_ = construct_transformer(model_args)
             # Get checkpoint dtype.
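Why the original "meta" context broke the export tests can be seen in a short sketch (plain PyTorch; the numbers are arbitrary): tensors created under the meta device carry shape and dtype but no storage, so rope frequencies computed at construction time hold no usable values once the explicit `device="cpu"` argument is gone.

```python
import torch

# Meta tensors have metadata (size, stride, dtype) but no storage, so any
# values computed at construction time are unusable until materialized.
with torch.device("meta"):
    freqs = 1.0 / (10000 ** (torch.arange(0, 8, 2).float() / 8))
print(freqs.device)       # meta
# freqs[0].item()         # would raise: meta tensors carry no data

# Building under "cpu" instead yields real, export-ready values.
with torch.device("cpu"):
    freqs = 1.0 / (10000 ** (torch.arange(0, 8, 2).float() / 8))
print(freqs[:2])          # tensor([1.0000, 0.1000])
```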

examples/models/llama/rope.py

Lines changed: 2 additions & 3 deletions

@@ -47,12 +47,11 @@ def precompute_freqs_cis(
     use_scaled: bool = False,
     scale_factor: Optional[int] = None,
     high_freq_factor: int = 4,
-    device: Union[str, torch.device] = "cpu",
 ):
     freqs = 1.0 / (
-        theta ** (torch.arange(0, dim, 2, device=device)[: (dim // 2)].float() / dim)
+        theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
     )
-    t = torch.arange(end, device=freqs.device)  # pyre-ignore
+    t = torch.arange(end)
     if use_scaled:
         assert scale_factor is not None
         freqs = apply_scaling(freqs, scale_factor, high_freq_factor)  # pyre-ignore
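After this hunk, the function no longer takes a device and simply inherits the caller's context. A trimmed-down sketch of the idea (the scaling branch is omitted, and the `torch.outer` tail is a hypothetical simplification, not the file's actual return value):

```python
import torch

def precompute_freqs_cis_sketch(dim: int, end: int, theta: float = 10000.0):
    # Mirrors the updated hunk: no device= anywhere, so both tables follow
    # the ambient default device set by the caller.
    freqs = 1.0 / (
        theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
    )
    t = torch.arange(end)
    # Hypothetical tail: an [end, dim // 2] angle table via torch.outer.
    return torch.outer(t, freqs)

angles = precompute_freqs_cis_sketch(dim=64, end=16)
print(angles.shape, angles.device)  # torch.Size([16, 32]) cpu
```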
