From 047fc9945ef8afd4f23ecbd8da831b27a4f36026 Mon Sep 17 00:00:00 2001 From: "q.yao" Date: Fri, 19 Sep 2025 16:46:10 +0800 Subject: [PATCH 1/4] fope --- .../backends/default/rotary_embedding.py | 71 ++++++++++- lmdeploy/pytorch/backends/rotary_embedding.py | 15 ++- lmdeploy/pytorch/nn/rotary_embedding.py | 115 +++++++++++++++++- 3 files changed, 194 insertions(+), 7 deletions(-) diff --git a/lmdeploy/pytorch/backends/default/rotary_embedding.py b/lmdeploy/pytorch/backends/default/rotary_embedding.py index 5e295f0923..bbccdc1061 100644 --- a/lmdeploy/pytorch/backends/default/rotary_embedding.py +++ b/lmdeploy/pytorch/backends/default/rotary_embedding.py @@ -3,10 +3,11 @@ import math import torch +import torch.nn.functional as F from torch import nn -from ..rotary_embedding import (Llama3Parameters, LongRoPEScalingParameters, RopeType, RotaryEmbeddingBuilder, - RotaryEmbeddingImpl, YarnParameters) +from ..rotary_embedding import (FopeParameters, Llama3Parameters, LongRoPEScalingParameters, RopeType, + RotaryEmbeddingBuilder, RotaryEmbeddingImpl, YarnParameters) def _rotary_embedding_fwd(position_ids: torch.Tensor, @@ -270,6 +271,64 @@ def forward(self, x: torch.Tensor, position_ids: torch.Tensor): device_type=device) +class FopeRotaryEmbeddingImpl(RotaryEmbeddingImpl): + + def __init__(self, + dim: int, + max_position_embeddings: int = 4096, + scaling_factor: float = 1.0, + params: FopeParameters = None): + super().__init__(dim, scaling_factor=scaling_factor) + self.head_dim = dim + self.max_position_embeddings = max_position_embeddings + self.attention_scaling = scaling_factor + self.params = params + + inv_freq = self.params.inv_freq + inv_freq_idx_selected = inv_freq > 2 * torch.pi / self.max_position_embeddings + if self.params.num_inv_freq is not None and inv_freq_idx_selected.sum() > (inv_freq.shape[-1] - + self.params.num_inv_freq): + inv_freq_idx_selected[-self.params.num_inv_freq:] = False + self.inv_freq = inv_freq[inv_freq_idx_selected] + self.register_buffer('inv_freq', self.inv_freq, persistent=False) + + def forward(self, x: torch.Tensor, position_ids: torch.Tensor, sin_coef: torch.Tensor, cos_coef: torch.Tensor): + """forward.""" + if self.inv_freq.device != x.device: + self.inv_freq = self.inv_freq.to(x.device) + + inv_freq = self.inv_freq + inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + + batch_size, seq_len, _ = x.shape + if self.params.fope_sep_head: + pos_cos = freqs.cos().unsqueeze(1).expand(batch_size, self.params.num_key_value_heads, seq_len, -1) + pos_sin = freqs.sin().unsqueeze(1).expand(batch_size, self.params.num_key_value_heads, seq_len, -1) + else: + pos_cos = freqs.cos() + pos_sin = freqs.sin() + + if self.params.fope_sep_head: + sin = torch.einsum('bhtD, hDd -> bthd', pos_sin, sin_coef.float()) + cos = torch.einsum('bhtD, hDd -> bthd', pos_cos, cos_coef.float()) + else: + sin = torch.einsum('btD, Dd -> btd', pos_sin, sin_coef.float()) + cos = torch.einsum('btD, Dd -> btd', pos_cos, cos_coef.float()) + + sin = F.pad(input=sin, pad=(0, self.head_dim // 2 - sin.size(-1)), mode='constant', value=1) + cos = F.pad(input=cos, pad=(0, self.head_dim // 2 - cos.size(-1)), mode='constant', value=1) + + sin = torch.cat((sin, sin), dim=-1) + cos = torch.cat((cos, cos), dim=-1) + + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), 
sin.to(dtype=x.dtype) + + class DefaultRotaryEmbeddingBuilder(RotaryEmbeddingBuilder): """Rotary embedding builder.""" @@ -282,6 +341,7 @@ def build( yarn_params: YarnParameters = None, longrope_params: LongRoPEScalingParameters = None, llama3_params: Llama3Parameters = None, + fope_params: FopeParameters = None, emb_type: RopeType = RopeType.Default, ): """build.""" @@ -302,5 +362,12 @@ def build( max_position_embeddings=max_position_embeddings, longrope_params=longrope_params, ) + elif emb_type == RopeType.Fope: + return FopeRotaryEmbeddingImpl( + dim, + max_position_embeddings=max_position_embeddings, + scaling_factor=scaling_factor, + params=fope_params, + ) else: raise NotImplementedError(f'Unsupported embedding type: {emb_type}') diff --git a/lmdeploy/pytorch/backends/rotary_embedding.py b/lmdeploy/pytorch/backends/rotary_embedding.py index 7b6f9f7eab..7fb3b39c7b 100644 --- a/lmdeploy/pytorch/backends/rotary_embedding.py +++ b/lmdeploy/pytorch/backends/rotary_embedding.py @@ -4,6 +4,8 @@ from enum import Enum, auto from typing import List +import torch + class RopeType(Enum): """Rotary embedding type.""" @@ -13,6 +15,7 @@ class RopeType(Enum): Llama3 = auto() Yarn = auto() LongRoPEScaling = auto() + Fope = auto() @dataclass @@ -43,11 +46,20 @@ class Llama3Parameters: original_max_position_embeddings: int = 8192 +@dataclass +class FopeParameters: + """Fope parameters.""" + num_inv_freq: int = None + num_key_value_heads: int = 1 + fope_sep_head: bool = False + inv_freq: torch.Tensor = None + + class RotaryEmbeddingImpl(ABC): """Rotary embedding implementation api.""" @abstractmethod - def forward(self, x, position_ids): + def forward(self, x, position_ids, **kwargs): """forward.""" raise NotImplementedError @@ -65,6 +77,7 @@ def build( yarn_params: YarnParameters = None, longrope_params: LongRoPEScalingParameters = None, llama3_params: Llama3Parameters = None, + fope_params: FopeParameters = None, emb_type: RopeType = RopeType.Default, ): """build.""" diff --git a/lmdeploy/pytorch/nn/rotary_embedding.py b/lmdeploy/pytorch/nn/rotary_embedding.py index a756f7c16d..90d71eed19 100644 --- a/lmdeploy/pytorch/nn/rotary_embedding.py +++ b/lmdeploy/pytorch/nn/rotary_embedding.py @@ -2,11 +2,13 @@ import math +import torch from torch import Tensor, nn from transformers import PretrainedConfig from ..backends import OpType, get_backend -from ..backends.rotary_embedding import Llama3Parameters, LongRoPEScalingParameters, RopeType, YarnParameters +from ..backends.rotary_embedding import (FopeParameters, Llama3Parameters, LongRoPEScalingParameters, RopeType, + YarnParameters) def _get_default_rope_parameters(config: PretrainedConfig): @@ -92,6 +94,15 @@ def _get_llama3_parameters(config: PretrainedConfig): return dict(emb_type=RopeType.Llama3, scaling_factor=scaling_factor, llama3_params=params) +def _get_fope_parameters(config: PretrainedConfig): + """Get fope parameters.""" + params = FopeParameters() + params.num_inv_freq = config.num_inv_freq + params.num_key_value_heads = config.num_key_value_heads + params.fope_sep_head = config.fope_sep_head + return dict(use_fope=True, fope_params=params) + + def build_rotary_params(config: PretrainedConfig): """Get scaling_factor rotary params, and emb_type.""" params = dict(emb_type=RopeType.Default) @@ -114,6 +125,9 @@ def build_rotary_params(config: PretrainedConfig): if partial_rotary_factor is not None: params['partial_rotary_factor'] = partial_rotary_factor + if getattr(config, 'use_fope', False): + params.update(_get_fope_parameters(config)) + return 
params @@ -124,8 +138,10 @@ def build_rotary_embedding(dim: int, yarn_params: YarnParameters = None, longrope_params: LongRoPEScalingParameters = None, llama3_params: Llama3Parameters = None, + fope_params: FopeParameters = None, emb_type: RopeType = RopeType.Default, - partial_rotary_factor: float = None) -> nn.Module: + partial_rotary_factor: float = None, + use_fope: bool = False) -> nn.Module: """Build rotary embedding op.""" backend = get_backend() @@ -134,7 +150,7 @@ def build_rotary_embedding(dim: int, # update rope_dim if partial_rotary_factor is not None: dim = int(dim * partial_rotary_factor) - return builder.build(dim, + impl = builder.build(dim, max_position_embeddings, base, scaling_factor, @@ -143,6 +159,15 @@ def build_rotary_embedding(dim: int, llama3_params=llama3_params, emb_type=emb_type) + if use_fope: + assert fope_params is not None, 'fope_params should not be None when use_fope is True.' + inv_freq = impl.inv_freq + fope_params.inv_freq = inv_freq + fope = FopeRotaryEmbedding(dim, max_position_embeddings, scaling_factor, fope_params) + return fope + + return impl + def build_rotary_embedding_from_config(config: PretrainedConfig) -> nn.Module: """Build rotary embedding op from config.""" @@ -169,4 +194,86 @@ def __init__(self): def forward(self, query: Tensor, key: Tensor, cos: Tensor, sin: Tensor, inplace: bool = True): """forward.""" - return self.impl.forward(query, key, cos, sin, inplace) + + assert query.dim() == key.dim() == 3, 'Expected query key (seq_len, heads, head_dim)' + assert cos.dim() <= 3 and sin.dim() <= 3 + + need_reshape = False + if cos.dim() == 3: + # for fope + need_reshape = True + query_shape = query.shape + key_shape = key.shape + cos = cos.flatten(0, 1) + sin = sin.flatten(0, 1) + seq_len = cos.size(0) + query = query.view(seq_len, -1, query.size(-1)) + key = key.view(seq_len, -1, key.size(-1)) + + query, key = self.impl.forward(query, key, cos, sin, inplace) + + if need_reshape: + query = query.view(query_shape) + key = key.view(key_shape) + return query, key + + +class FopeRotaryEmbedding(nn.Module): + """Fope rotary embedding.""" + + def __init__(self, dim: int, max_position_embeddings: int, attention_scaling: float, params: FopeParameters): + super().__init__() + + num_key_value_heads, tp = self.update_num_kv_heads(params.num_key_value_heads) + self.tp = tp + params.num_key_value_heads = num_key_value_heads + + # build impl + backend = get_backend() + builder = backend.get_layer_impl_builder(OpType.RotaryEmbedding) + self.impl = builder.build(dim, + max_position_embeddings=max_position_embeddings, + scaling_factor=attention_scaling, + fope_params=params, + emb_type=RopeType.Fope) + + # setup params + inv_freq = self.impl.inv_freq + self.input_dim = inv_freq.shape[-1] + self.output_dim = inv_freq.shape[-1] + self.cos_coef = nn.Parameter(torch.empty(num_key_value_heads, self.input_dim, self.output_dim), + requires_grad=False) + self.sin_coef = nn.Parameter(torch.empty(num_key_value_heads, self.input_dim, self.output_dim), + requires_grad=False) + if self.tp: + self.cos_coef.weight_loader = self.weight_loader + self.sin_coef.weight_loader = self.weight_loader + + @staticmethod + def update_num_kv_heads(num_key_value_heads: int): + """Update num_key_value_heads.""" + from lmdeploy.pytorch.distributed import get_dist_manager + dist_mgr = get_dist_manager() + dist_ctx = dist_mgr.current_context() + tp = dist_ctx.dist_config.attn_config.tp + if tp > 1: + num_key_value_heads = max(1, num_key_value_heads // tp) + return num_key_value_heads, tp + + 
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor): + """Weight loader.""" + from lmdeploy.pytorch.distributed import get_tp_world_rank + world_size, rank = get_tp_world_rank() + num_key_value_heads = loaded_weight.size(0) + + if num_key_value_heads < world_size: + n_replicate = world_size // num_key_value_heads + world_size = num_key_value_heads + rank = rank // n_replicate + + loaded_weight = loaded_weight.chunk(world_size, dim=0)[rank] + param.copy_(loaded_weight) + + def forward(self, x: Tensor, position_ids: Tensor): + """forward.""" + return self.impl.forward(x, position_ids, sin_coef=self.sin_coef, cos_coef=self.cos_coef) From e9bd8346387e54767585de134a229427278e605a Mon Sep 17 00:00:00 2001 From: grimoire Date: Thu, 13 Nov 2025 18:03:39 +0800 Subject: [PATCH 2/4] update config format --- lmdeploy/pytorch/nn/rotary_embedding.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/lmdeploy/pytorch/nn/rotary_embedding.py b/lmdeploy/pytorch/nn/rotary_embedding.py index 90d71eed19..02d4336770 100644 --- a/lmdeploy/pytorch/nn/rotary_embedding.py +++ b/lmdeploy/pytorch/nn/rotary_embedding.py @@ -97,10 +97,11 @@ def _get_llama3_parameters(config: PretrainedConfig): def _get_fope_parameters(config: PretrainedConfig): """Get fope parameters.""" params = FopeParameters() - params.num_inv_freq = config.num_inv_freq + rope_scaling = config.rope_scaling + params.num_inv_freq = rope_scaling['num_inv_freq'] params.num_key_value_heads = config.num_key_value_heads - params.fope_sep_head = config.fope_sep_head - return dict(use_fope=True, fope_params=params) + params.fope_sep_head = rope_scaling['fope_sep_head'] + return dict(fope_params=params) def build_rotary_params(config: PretrainedConfig): @@ -111,6 +112,9 @@ def build_rotary_params(config: PretrainedConfig): if rope_scaling is not None: # BC: "rope_type" was originally "type" rope_type_str = config.rope_scaling.get('rope_type', config.rope_scaling.get('type', 'default')) + if rope_type_str.startswith('fope'): + params.update(_get_fope_parameters(config)) + rope_type_str = 'default' if rope_type_str == 'fope' else rope_type_str[5:] build_funcs = dict(default=_get_default_rope_parameters, linear=_get_linear_scaling_rope_parameters, dynamic=_get_dynamic_ntk_parameters, @@ -125,9 +129,6 @@ def build_rotary_params(config: PretrainedConfig): if partial_rotary_factor is not None: params['partial_rotary_factor'] = partial_rotary_factor - if getattr(config, 'use_fope', False): - params.update(_get_fope_parameters(config)) - return params @@ -140,8 +141,7 @@ def build_rotary_embedding(dim: int, llama3_params: Llama3Parameters = None, fope_params: FopeParameters = None, emb_type: RopeType = RopeType.Default, - partial_rotary_factor: float = None, - use_fope: bool = False) -> nn.Module: + partial_rotary_factor: float = None) -> nn.Module: """Build rotary embedding op.""" backend = get_backend() @@ -159,8 +159,7 @@ def build_rotary_embedding(dim: int, llama3_params=llama3_params, emb_type=emb_type) - if use_fope: - assert fope_params is not None, 'fope_params should not be None when use_fope is True.' 
+ if fope_params is not None: inv_freq = impl.inv_freq fope_params.inv_freq = inv_freq fope = FopeRotaryEmbedding(dim, max_position_embeddings, scaling_factor, fope_params) From 807e733e21f09baa97971df86d475bcc4893bd7e Mon Sep 17 00:00:00 2001 From: grimoire Date: Tue, 18 Nov 2025 12:58:26 +0800 Subject: [PATCH 3/4] update fope params --- lmdeploy/pytorch/nn/rotary_embedding.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/lmdeploy/pytorch/nn/rotary_embedding.py b/lmdeploy/pytorch/nn/rotary_embedding.py index 02d4336770..19a3ce0834 100644 --- a/lmdeploy/pytorch/nn/rotary_embedding.py +++ b/lmdeploy/pytorch/nn/rotary_embedding.py @@ -96,9 +96,16 @@ def _get_llama3_parameters(config: PretrainedConfig): def _get_fope_parameters(config: PretrainedConfig): """Get fope parameters.""" + # check if fope is used + rope_scaling = getattr(config, 'rope_scaling', dict()) + fope_keys = ['fope_sep_head', 'fope_num_inv_freq'] + is_fope = any(key in rope_scaling for key in fope_keys) + if not is_fope: + return dict() + params = FopeParameters() rope_scaling = config.rope_scaling - params.num_inv_freq = rope_scaling['num_inv_freq'] + params.num_inv_freq = rope_scaling.get('fope_num_inv_freq', rope_scaling.get('num_inv_freq', params.num_inv_freq)) params.num_key_value_heads = config.num_key_value_heads params.fope_sep_head = rope_scaling['fope_sep_head'] return dict(fope_params=params) @@ -112,9 +119,8 @@ def build_rotary_params(config: PretrainedConfig): if rope_scaling is not None: # BC: "rope_type" was originally "type" rope_type_str = config.rope_scaling.get('rope_type', config.rope_scaling.get('type', 'default')) - if rope_type_str.startswith('fope'): - params.update(_get_fope_parameters(config)) - rope_type_str = 'default' if rope_type_str == 'fope' else rope_type_str[5:] + if rope_type_str == 'fope': + rope_type_str = 'default' build_funcs = dict(default=_get_default_rope_parameters, linear=_get_linear_scaling_rope_parameters, dynamic=_get_dynamic_ntk_parameters, @@ -123,6 +129,7 @@ def build_rotary_params(config: PretrainedConfig): su=_get_longrope_parameters, llama3=_get_llama3_parameters) params.update(build_funcs[rope_type_str](config)) + params.update(_get_fope_parameters(config)) # update partial_rotary_factor partial_rotary_factor = config.partial_rotary_factor if hasattr(config, 'partial_rotary_factor') else None From 438a22ffc705bf3c714cb060e84554e6a7a60c84 Mon Sep 17 00:00:00 2001 From: "q.yao" Date: Thu, 20 Nov 2025 11:09:36 +0800 Subject: [PATCH 4/4] merge main --- lmdeploy/pytorch/nn/rotary_embedding.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lmdeploy/pytorch/nn/rotary_embedding.py b/lmdeploy/pytorch/nn/rotary_embedding.py index 19a3ce0834..9ff2fcba73 100644 --- a/lmdeploy/pytorch/nn/rotary_embedding.py +++ b/lmdeploy/pytorch/nn/rotary_embedding.py @@ -261,7 +261,8 @@ def update_num_kv_heads(num_key_value_heads: int): from lmdeploy.pytorch.distributed import get_dist_manager dist_mgr = get_dist_manager() dist_ctx = dist_mgr.current_context() - tp = dist_ctx.dist_config.attn_config.tp + tp = dist_ctx.dist_config.attn_tp + # tp = dist_ctx.dist_config.attn_config.tp if tp > 1: num_key_value_heads = max(1, num_key_value_heads // tp) return num_key_value_heads, tp
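
Note on the FoPE cos/sin construction (editorial sketch, not part of the patches): the snippet below re-derives what FopeRotaryEmbeddingImpl.forward computes in the non-fope_sep_head branch, with identity matrices standing in for the learned sin_coef/cos_coef weights and made-up shapes (head_dim=64, max_position_embeddings=4096, attention_scaling=1.0); it ignores the optional fope_num_inv_freq cap. Frequencies at or below 2*pi/max_position_embeddings are dropped, the remaining ones are mixed through the coefficient matrices, padded back to head_dim // 2 with ones, then duplicated into the usual (cos, cos) / (sin, sin) layout and scaled.

    # Standalone sketch only; uses plain torch, no lmdeploy runtime required.
    import torch
    import torch.nn.functional as F

    head_dim, base, max_pos, seq_len = 64, 10000.0, 4096, 8
    attention_scaling = 1.0

    # standard RoPE inverse frequencies, then drop the near-zero ones
    inv_freq = 1.0 / base ** (torch.arange(0, head_dim, 2).float() / head_dim)
    inv_freq = inv_freq[inv_freq > 2 * torch.pi / max_pos]
    D = inv_freq.numel()

    position_ids = torch.arange(seq_len)[None].float()                            # (1, T)
    freqs = (inv_freq[None, :, None] @ position_ids[:, None, :]).transpose(1, 2)  # (1, T, D)

    # stand-ins for the learned coefficient matrices (loaded weights in the patch)
    sin_coef = torch.eye(D)
    cos_coef = torch.eye(D)

    sin = torch.einsum('btD, Dd -> btd', freqs.sin(), sin_coef)
    cos = torch.einsum('btD, Dd -> btd', freqs.cos(), cos_coef)

    # pad the dropped frequencies back with ones, then (x, x) concat as in RoPE
    sin = F.pad(sin, (0, head_dim // 2 - D), value=1.0)
    cos = F.pad(cos, (0, head_dim // 2 - D), value=1.0)
    sin = torch.cat((sin, sin), dim=-1) * attention_scaling
    cos = torch.cat((cos, cos), dim=-1) * attention_scaling
    print(cos.shape, sin.shape)  # both (1, 8, 64)

The fope_sep_head=True branch is the same idea with an extra key-value-head axis: pos_sin/pos_cos are expanded to (batch, num_key_value_heads, seq_len, D) and mixed with per-head coefficients via einsum('bhtD, hDd -> bthd', ...), which is why FopeRotaryEmbedding shards sin_coef/cos_coef along the head dimension under tensor parallelism.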