Commit 2681b26

[TRTLLM-2795] feat: Add yarn support for other models in trt-flow (#3840)
Add YaRN support for general models (e.g. Llama, Qwen) other than DeepSeek in trt-flow.

Signed-off-by: Zeyu Wang <[email protected]>
Parent: f9adac3

File tree (4 files changed, +42 -11 lines):

- tensorrt_llm/_torch/attention_backend/interface.py
- tensorrt_llm/functional.py
- tensorrt_llm/layers/attention.py
- tests/unittest/_torch/test_attention_mla.py

tensorrt_llm/_torch/attention_backend/interface.py

Lines changed: 1 addition & 1 deletion
@@ -393,7 +393,7 @@ def create_rope_const_params(self, interleave: bool = True):
 
         if self.scale_type == RotaryScalingType.yarn:
             rope_inv_freq = None
-            rope_cos_sin = RopeEmbeddingUtils.create_sinusoidal_positions_yarn(
+            _, rope_cos_sin = RopeEmbeddingUtils.create_sinusoidal_positions_yarn(
                 self.max_positions,
                 self.dim,
                 self.theta,

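With this change the helper returns an (inv_freq, cos_sin_cache) tuple, and the torch attention backend keeps only the cache. Below is a minimal usage sketch of the new call pattern, assuming the positional argument order used by the call sites in this commit; the sizes are illustrative, not taken from any real model config.

```python
from tensorrt_llm.functional import RopeEmbeddingUtils

# Illustrative values only: 8192 scaled positions, 128-dim rotary embedding,
# base 10000, YaRN factor 4.0 over an original 2048-position window.
inv_freq, rope_cos_sin = RopeEmbeddingUtils.create_sinusoidal_positions_yarn(
    8192, 128, 10000.0, 4.0, 2048)

print(inv_freq.shape)      # (64,): one inverse frequency per rotary pair
print(rope_cos_sin.shape)  # (1, 8192 * 128 * 2) with the default duplicate_data=True
```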
tensorrt_llm/functional.py

Lines changed: 11 additions & 8 deletions
@@ -4805,6 +4805,7 @@ def create_sinusoidal_positions_yarn(
             beta_slow: int = 1,
             mscale: float = 1.0,
             mscale_all_dim: float = 1.0,
+            duplicate_data: bool = True,
             dtype=np.float32):
 
         # Copy from https://huggingface.co/deepseek-ai/DeepSeek-V2/blob/main/modeling_deepseek.py
@@ -4862,23 +4863,25 @@ def yarn_linear_ramp_mask(min, max, dim):
         inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high,
                                                     dim // 2).astype(dtype)
         inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask
-        t = np.arange(num_pos, dtype=dtype)
-
-        freqs = np.outer(t, inv_freq)
+        sinusoid_inp = np.expand_dims(np.einsum("i , j -> i j",
+                                                np.arange(num_pos, dtype=dtype),
+                                                inv_freq,
+                                                dtype=dtype),
+                                      axis=-1)
 
         _mscale = float(
             yarn_get_mscale(scaling_factor, mscale) /
             yarn_get_mscale(scaling_factor, mscale_all_dim))
 
-        emb = np.concatenate((freqs, freqs), axis=-1)
+        if duplicate_data:
+            emb = np.concatenate((sinusoid_inp, sinusoid_inp), axis=-2)
+        else:
+            emb = sinusoid_inp
 
         concat = np.concatenate((np.cos(emb) * _mscale, np.sin(emb) * _mscale),
                                 axis=-1)
 
-        concat = concat.reshape((num_pos, 2, dim))
-        concat = np.transpose(concat, (0, 2, 1))
-
-        return concat.reshape((1, -1)).astype(dtype)
+        return inv_freq, concat.reshape((1, -1)).astype(dtype)
 
     @staticmethod
     def rotate_every_two(tensor: Tensor) -> Tensor:

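The new duplicate_data flag controls the layout of the returned table: True keeps the previous behaviour of duplicating the sinusoid across both rotary halves, while False (used by the new yarn branch in layers/attention.py) emits one cos/sin pair per frequency. The following is a small standalone numpy sketch of just this layout logic, with the YaRN frequency correction and mscale factors omitted for brevity; the shapes match those implied by the diff above.

```python
import numpy as np

num_pos, dim = 8, 4                               # toy sizes
inv_freq = 1.0 / (10000.0**(np.arange(0, dim, 2, dtype=np.float32) / dim))
t = np.arange(num_pos, dtype=np.float32)
sinusoid_inp = np.expand_dims(np.einsum("i , j -> i j", t, inv_freq), axis=-1)

for duplicate_data in (True, False):
    if duplicate_data:
        emb = np.concatenate((sinusoid_inp, sinusoid_inp), axis=-2)
    else:
        emb = sinusoid_inp
    concat = np.concatenate((np.cos(emb), np.sin(emb)), axis=-1)
    print(duplicate_data, concat.shape, concat.reshape((1, -1)).shape)
# True  (8, 4, 2) (1, 64)  -> duplicated halves, the pre-existing layout
# False (8, 2, 2) (1, 32)  -> one cos/sin pair per frequency
```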
tensorrt_llm/layers/attention.py

Lines changed: 29 additions & 1 deletion
@@ -672,6 +672,34 @@ def create_attention_const_params(model_cls, config):
                           is_buffer=True))
             model_cls.short_mscale = short_mscale
             model_cls.long_mscale = long_mscale
+        elif rotary_embedding_scale_type == RotaryScalingType.yarn:
+            beta_fast = rotary_embedding_scaling.get("beta_fast", 32.0)
+            beta_slow = rotary_embedding_scaling.get("beta_slow", 1.0)
+            mscale = rotary_embedding_scaling.get("mscale", 1.0)
+            mscale_all_dim = rotary_embedding_scaling.get("mscale_all_dim", 0.0)
+            original_max_position_embeddings = rotary_embedding_scaling.get(
+                "original_max_position_embeddings", 4096)
+            rotary_inv_freq, embed_positions_for_gpt_attention = RopeEmbeddingUtils.create_sinusoidal_positions_yarn(
+                max_position_embeddings, rotary_embedding_dim,
+                rotary_embedding_base, rotary_embedding_scale,
+                original_max_position_embeddings, beta_fast, beta_slow, mscale,
+                mscale_all_dim, False)
+
+            embed_positions = RopeEmbeddingUtils.create_sinusoidal_positions(
+                max_position_embeddings,
+                rotary_embedding_dim,
+            )
+            model_cls.register_parameter(
+                'embed_positions',
+                Parameter(embed_positions, dtype='float32', is_buffer=True))
+            model_cls.register_parameter(
+                'rotary_inv_freq',
+                Parameter(rotary_inv_freq, dtype='float32', is_buffer=True))
+            model_cls.register_parameter(
+                'embed_positions_for_gpt_attention',
+                Parameter(embed_positions_for_gpt_attention,
+                          dtype='float32',
+                          is_buffer=True))
         else:
 
             def register_rope_params(rotary_base, names_to_register):
@@ -2052,7 +2080,7 @@ def yarn_get_mscale(scale=1, mscale=1):
             mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
             self.q_scaling = 1.0 / (mscale * mscale)
 
-            embed_positions_for_gpt_attention = RopeEmbeddingUtils.create_sinusoidal_positions_yarn(
+            _, embed_positions_for_gpt_attention = RopeEmbeddingUtils.create_sinusoidal_positions_yarn(
                 self.max_position_embeddings, self.qk_rope_head_dim,
                 self.rotary_embedding_base, self.rotary_scaling["factor"],
                 rotary_embedding_origin_max_position, rotary_embedding_beta_fast,

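The new yarn branch in create_attention_const_params reads its hyperparameters from the model's rotary scaling dict, falling back to the defaults visible in the .get() calls, and passes duplicate_data=False so the cos/sin cache is emitted without the duplicated halves. Below is a hypothetical HF-style rope_scaling entry that would exercise this branch; the field names mirror the .get() keys above, and the values are made up for illustration.

```python
# Hypothetical example config; every field with a trailing "default" comment
# has a fallback in the new branch and may be omitted.
rotary_embedding_scaling = {
    "type": "yarn",
    "factor": 4.0,
    "original_max_position_embeddings": 32768,   # default 4096
    "beta_fast": 32.0,                           # default 32.0
    "beta_slow": 1.0,                            # default 1.0
    "mscale": 1.0,                               # default 1.0
    "mscale_all_dim": 0.0,                       # default 0.0
}
```

The branch registers three buffers on the model: a plain sinusoidal embed_positions table, the rotary_inv_freq vector, and the YaRN cos/sin cache stored as embed_positions_for_gpt_attention.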
tests/unittest/_torch/test_attention_mla.py

Lines changed: 1 addition & 1 deletion
@@ -492,7 +492,7 @@ def _run_test_for_backend(backend_name, num_heads, num_kv_heads, num_layers,
             rope_config.rope_scaling['beta_slow'],
             rope_config.rope_scaling['mscale'],
             rope_config.rope_scaling['mscale_all_dim'],
-        ),
+        )[1],
         dtype=torch.float32,
         device=device,
     ).reshape(rope_config.max_position_embeddings, -1, 2).transpose(-2, -1)

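The test now keeps only the second element of the returned tuple (the flattened cos/sin cache) and splits it back into cos and sin planes via reshape/transpose. Here is a small numpy sketch of that round trip, assuming the default duplicate_data=True layout built in functional.py, with toy sizes and random angles standing in for the real table.

```python
import numpy as np

num_pos, dim = 4, 6                                         # toy sizes
angles = np.random.rand(num_pos, dim).astype(np.float32)    # stand-in for the YaRN angles
cache = np.concatenate((np.cos(angles)[..., None],
                        np.sin(angles)[..., None]),
                       axis=-1).reshape(1, -1)               # (1, num_pos * dim * 2)
planes = cache.reshape(num_pos, -1, 2).transpose(0, 2, 1)    # (num_pos, 2, dim)
assert np.allclose(planes[:, 0, :], np.cos(angles))          # cos plane
assert np.allclose(planes[:, 1, :], np.sin(angles))          # sin plane
```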