|
4 | 4 | from torch import nn |
5 | 5 | from transformers import PretrainedConfig |
6 | 6 |
|
7 | | -from tensorrt_llm.functional import PositionEmbeddingType |
| 7 | +from tensorrt_llm.functional import PositionEmbeddingType, RotaryScalingType |
8 | 8 | from tensorrt_llm.lora_manager import HfLoraLoader |
9 | 9 | from tensorrt_llm.models.convert_utils import split_matrix_tp |
10 | 10 |
|
@@ -48,19 +48,28 @@ def _create_linear_from_configs(model_config: ModelConfig[PretrainedConfig], |
48 | 48 |
|
49 | 49 |
|
50 | 50 | class NemotronNASAttention(Attention): |
| 51 | + NON_NEOX_TYPES = ("mistral_yarn", "rope_llama4") |
51 | 52 |
|
52 | 53 | def __init__(self, model_config: ModelConfig[PretrainedConfig], |
53 | 54 | layer_idx: int): |
54 | 55 | config = model_config.pretrained_config |
| 56 | + is_neox = getattr(model_config.pretrained_config, |
| 57 | + "position_embedding_type", |
| 58 | + None) not in self.NON_NEOX_TYPES |
| 59 | + rope = RopeParams.from_config(config) |
| 60 | + if rope.scale_type == RotaryScalingType.yarn: |
| 61 | + rope.mscale_all_dim = 0.0 |
| 62 | + |
55 | 63 | super().__init__( |
56 | 64 | hidden_size=config.hidden_size, |
57 | 65 | num_attention_heads=config.num_attention_heads, |
58 | 66 | num_key_value_heads=config.num_key_value_heads[layer_idx], |
59 | 67 | max_position_embeddings=config.max_position_embeddings, |
60 | 68 | bias=False, |
61 | 69 | pos_embd_params=PositionalEmbeddingParams( |
62 | | - type=PositionEmbeddingType.rope_gpt_neox, |
63 | | - rope=RopeParams.from_config(config), |
| 70 | + type=PositionEmbeddingType.rope_gpt_neox |
| 71 | + if is_neox else PositionEmbeddingType.rope_gptj, |
| 72 | + rope=rope, |
64 | 73 | ), |
65 | 74 | layer_idx=layer_idx, |
66 | 75 | dtype=config.torch_dtype, |
|
0 commit comments