diff --git a/examples/models/llama/model_args.py b/examples/models/llama/model_args.py
index 04d29f91ac6..3f9d3d8f2af 100644
--- a/examples/models/llama/model_args.py
+++ b/examples/models/llama/model_args.py
@@ -63,6 +63,9 @@ class ModelArgs:
     use_sdpa_with_kv_cache_op: bool = (
         False  # Use custom sdpa op that updates kv cache in-place
     )
+    # Device to use for the model: "cpu" or "cuda" (needed for QAT)
+    # Only used for creating Rope parameters
+    device: str = "cpu"
     # Generate logits for all inputs. When it's True, it would take big memory usage
     # at runtime. Enable it only necessary (e.g., use perplexity tools that requires
     # logits for all input tokens.)
diff --git a/examples/models/llama/rope.py b/examples/models/llama/rope.py
index 8c0d5db6a80..0d1dd306091 100644
--- a/examples/models/llama/rope.py
+++ b/examples/models/llama/rope.py
@@ -138,7 +138,11 @@ def forward(
 # and https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_rope_utils.py#L242.
 # Current only support non-long rope.
 def hf_precompute_freqs_cis(
-    dim: int, end: int, theta: float, partial_rotary_factor: float = 1.0
+    dim: int,
+    end: int,
+    theta: float,
+    partial_rotary_factor: float = 1.0,
+    device: Union[str, torch.device] = "cpu",
 ):
     # Partial rotary embeddings.
     dim = int(dim * partial_rotary_factor)
@@ -146,7 +150,7 @@ def hf_precompute_freqs_cis(
     # Short factor scaling.
     freqs = 1.0 / (
         theta
-        ** (torch.arange(0, dim, 2, device="cpu", dtype=torch.int64).float() / dim)
+        ** (torch.arange(0, dim, 2, device=device, dtype=torch.int64).float() / dim)
     )
     # TODO: support long factor scaling.
 
@@ -236,6 +240,7 @@ def __init__(self, params: ModelArgs):
             self.precompute_freqs_cis = partial(
                 hf_precompute_freqs_cis,
                 partial_rotary_factor=self.params.partial_rotary_factor,
+                device=self.params.device,
             )
             self.apply_rotary_emb = hf_apply_rotary_emb
         else:
@@ -244,6 +249,7 @@ def __init__(self, params: ModelArgs):
                 use_scaled=self.params.use_scaled_rope,
                 scale_factor=self.params.rope_scale_factor,
                 high_freq_factor=self.params.high_freq_factor,
+                device=self.params.device,
             )
             self.apply_rotary_emb = RotaryEmbedding()
 
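
Usage sketch (for reviewers): the snippet below shows how the new `device` argument is expected to flow from `ModelArgs` into `hf_precompute_freqs_cis` during QAT, so the RoPE frequency tables are created on the accelerator instead of being hard-coded to CPU. The import paths, the illustrative `head_dim`/`max_seq_len`/`theta` values, and constructing `ModelArgs` from its remaining defaults are assumptions for illustration only, not taken from this diff.

```python
import torch

# Import paths assume the ExecuTorch repo layout; adjust if vendored differently.
from executorch.examples.models.llama.model_args import ModelArgs
from executorch.examples.models.llama.rope import hf_precompute_freqs_cis

# Illustrative values only; real ones come from the checkpoint config.
head_dim, max_seq_len, rope_theta = 128, 2048, 10000.0

# QAT flow: pick the device once via ModelArgs (assumes the other dataclass
# fields keep their defaults) instead of relying on the hard-coded "cpu".
args = ModelArgs(device="cuda" if torch.cuda.is_available() else "cpu")

# With this patch, the RoPE frequency tables are created on args.device.
out = hf_precompute_freqs_cis(head_dim, max_seq_len, rope_theta, device=args.device)

# Sanity-check the intended effect of the patch: the returned tensor(s)
# live on the requested device. Only placement is checked here.
for t in out if isinstance(out, tuple) else (out,):
    assert t.device.type == args.device
```

Since the default stays `"cpu"`, existing export flows are unaffected; only callers that set `ModelArgs.device` (e.g., the QAT path) see the tables created on GPU.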