Skip to content

Commit de97799

Browse files
authored
feat: Add support for YARN in NemotronNAS models (#4906)
Signed-off-by: Amir Klein <203507526+amirkl94@users.noreply.github.com>
1 parent a985c0b commit de97799

File tree

3 files changed

+14
-6
lines changed

3 files changed

+14
-6
lines changed

tensorrt_llm/_torch/attention_backend/interface.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -391,8 +391,7 @@ def create_rope_const_params(self, interleave: bool = True):
391391
)
392392

393393
if self.scale_type == RotaryScalingType.yarn:
394-
rope_inv_freq = None
395-
_, rope_cos_sin = RopeEmbeddingUtils.create_sinusoidal_positions_yarn(
394+
rope_inv_freq, rope_cos_sin = RopeEmbeddingUtils.create_sinusoidal_positions_yarn(
396395
self.max_positions,
397396
self.dim,
398397
self.theta,

tensorrt_llm/_torch/attention_backend/trtllm.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -110,7 +110,7 @@ def __init__(
110110
self.qk_rope_head_dim = None
111111
self.v_head_dim = None
112112

113-
self.rotary_inv_freq, self.rotary_cos_sin = rope_params.create_rope_const_params(
113+
self.rotary_inv_freq, self.rotary_cos_sin = self.rope_params.create_rope_const_params(
114114
)
115115

116116
self.num_heads = num_heads

tensorrt_llm/_torch/models/modeling_nemotron_nas.py

Lines changed: 12 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,7 @@
44
from torch import nn
55
from transformers import PretrainedConfig
66

7-
from tensorrt_llm.functional import PositionEmbeddingType
7+
from tensorrt_llm.functional import PositionEmbeddingType, RotaryScalingType
88
from tensorrt_llm.lora_manager import HfLoraLoader
99
from tensorrt_llm.models.convert_utils import split_matrix_tp
1010

@@ -48,19 +48,28 @@ def _create_linear_from_configs(model_config: ModelConfig[PretrainedConfig],
4848

4949

5050
class NemotronNASAttention(Attention):
51+
NON_NEOX_TYPES = ("mistral_yarn", "rope_llama4")
5152

5253
def __init__(self, model_config: ModelConfig[PretrainedConfig],
5354
layer_idx: int):
5455
config = model_config.pretrained_config
56+
is_neox = getattr(model_config.pretrained_config,
57+
"position_embedding_type",
58+
None) not in self.NON_NEOX_TYPES
59+
rope = RopeParams.from_config(config)
60+
if rope.scale_type == RotaryScalingType.yarn:
61+
rope.mscale_all_dim = 0.0
62+
5563
super().__init__(
5664
hidden_size=config.hidden_size,
5765
num_attention_heads=config.num_attention_heads,
5866
num_key_value_heads=config.num_key_value_heads[layer_idx],
5967
max_position_embeddings=config.max_position_embeddings,
6068
bias=False,
6169
pos_embd_params=PositionalEmbeddingParams(
62-
type=PositionEmbeddingType.rope_gpt_neox,
63-
rope=RopeParams.from_config(config),
70+
type=PositionEmbeddingType.rope_gpt_neox
71+
if is_neox else PositionEmbeddingType.rope_gptj,
72+
rope=rope,
6473
),
6574
layer_idx=layer_idx,
6675
dtype=config.torch_dtype,

0 commit comments

Comments
 (0)