diff --git a/paddleformers/transformers/__init__.py b/paddleformers/transformers/__init__.py index 7ead28cf9d1..f4eea62b6e9 100644 --- a/paddleformers/transformers/__init__.py +++ b/paddleformers/transformers/__init__.py @@ -400,6 +400,7 @@ from .qwen2_moe import * from .qwen3 import * from .qwen3_moe import * + from .gpt_oss import * else: sys.modules[__name__] = _LazyModule( __name__, diff --git a/paddleformers/transformers/auto/configuration.py b/paddleformers/transformers/auto/configuration.py index 66294d4694d..d55c96e3026 100644 --- a/paddleformers/transformers/auto/configuration.py +++ b/paddleformers/transformers/auto/configuration.py @@ -44,6 +44,7 @@ ("qwen2_moe", "Qwen2MoeConfig"), ("qwen3", "Qwen3Config"), ("qwen3_moe", "Qwen3MoeConfig"), + ("gpt_oss", "GptOssConfig"), ] ) diff --git a/paddleformers/transformers/auto/modeling.py b/paddleformers/transformers/auto/modeling.py index 83a23de0a20..f3112dc9b9b 100644 --- a/paddleformers/transformers/auto/modeling.py +++ b/paddleformers/transformers/auto/modeling.py @@ -62,6 +62,7 @@ ("Qwen3", "qwen3"), ("Qwen2Moe", "qwen2_moe"), ("Qwen3Moe", "qwen3_moe"), + ("GptOss", "gpt_oss"), ] ) diff --git a/paddleformers/transformers/gpt_oss/__init__.py b/paddleformers/transformers/gpt_oss/__init__.py new file mode 100644 index 00000000000..6709d0167aa --- /dev/null +++ b/paddleformers/transformers/gpt_oss/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .configuration import * +from .modeling import * diff --git a/paddleformers/transformers/gpt_oss/configuration.py b/paddleformers/transformers/gpt_oss/configuration.py new file mode 100644 index 00000000000..02489689edb --- /dev/null +++ b/paddleformers/transformers/gpt_oss/configuration.py @@ -0,0 +1,122 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# from ..configuration_utils import PretrainedConfig, layer_type_validation +from ..configuration_utils import PretrainedConfig + +# from ...modeling_rope_utils import rope_config_validation + + +class GptOssConfig(PretrainedConfig): + r""" + This will yield a configuration to that of the BERT + [google-bert/bert-base-uncased](https://huggingface.co/google-bert/bert-base-uncased) architecture. 
+ """ + + model_type = "gpt_oss" + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.self_attn.sinks": "local_rowwise", + "layers.*.mlp.experts": "gather", + "layers.*.mlp.router": "ep_router", + "layers.*.mlp.experts.gate_up_proj": "grouped_gemm", + "layers.*.mlp.experts.gate_up_proj_bias": "grouped_gemm", + "layers.*.mlp.experts.down_proj": "grouped_gemm", + "layers.*.mlp.experts.down_proj_bias": "grouped_gemm", + } + + def __init__( + self, + num_hidden_layers: int = 24, + num_local_experts: int = 128, + vocab_size: int = 201088, + hidden_size: int = 2880, + intermediate_size: int = 2880, + head_dim: int = 64, + num_attention_heads: int = 64, + num_key_value_heads: int = 8, + sliding_window: int = 128, + rope_theta: float = 150000.0, + tie_word_embeddings=False, + hidden_act: str = "silu", + initializer_range: float = 0.02, + max_position_embeddings=131072, + rms_norm_eps: float = 1e-5, + rope_scaling={"rope_type": "yarn", "factor": 32.0, "beta_fast": 32.0, "beta_slow": 1.0, "truncate": False}, + attention_dropout: float = 0.0, + num_experts_per_tok=4, + router_aux_loss_coef: float = 0.9, + output_router_logits=False, + use_cache=True, + layer_types=None, + **kwargs, + ): + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_experts = num_local_experts + self.sliding_window = sliding_window + self.num_experts_per_tok = num_experts_per_tok + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_dropout = attention_dropout + self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads + self.layer_types = layer_types + if self.layer_types is None: + self.layer_types = [ + "sliding_attention" if bool((i + 1) % 2) else "full_attention" for i in range(self.num_hidden_layers) + ] + # layer_type_validation(self.layer_types) + + # Validate the correctness of rotary position embeddings parameters + # BC: if there is a 'type' field, copy it it to 'rope_type'. 
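+        # e.g. a legacy rope_scaling={"type": "yarn", ...} is normalized below so that "rope_type" is always present.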
+ if self.rope_scaling is not None and "type" in self.rope_scaling: + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + # rope_config_validation(self) + + self.attention_bias = True + self.max_position_embeddings = max_position_embeddings + self.router_aux_loss_coef = router_aux_loss_coef + self.output_router_logits = output_router_logits + self.use_cache = use_cache + self.fuse_rope = False + self.fuse_linear = False + self.use_bias = False + self.compression_ratio = 1 + self.cachekv_quant = False + super().__init__( + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +__all__ = ["GptOssConfig"] diff --git a/paddleformers/transformers/gpt_oss/modeling.py b/paddleformers/transformers/gpt_oss/modeling.py new file mode 100644 index 00000000000..db6efc69cc1 --- /dev/null +++ b/paddleformers/transformers/gpt_oss/modeling.py @@ -0,0 +1,1388 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from functools import partial +from typing import List, Optional, Tuple, Union + +import paddle +from paddle import Tensor, nn +from paddle.distributed.fleet.recompute.recompute import recompute +from paddle.distributed.fleet.utils.sequence_parallel_utils import ScatterOp +from paddle.nn import functional as F + +from ...nn.attention.utils import repeat_kv +from ...nn.criterion.interface import CriterionLayer +from ...nn.embedding import Embedding as GeneralEmbedding +from ...nn.linear import Linear as GeneralLinear +from ...nn.lm_head import LMHead as GeneralLMHead +from ...utils.log import logger +from ...utils.tools import get_env_device +from ..conversion_utils import StateDictNameMapping, init_name_mappings +from ..llama.modeling import get_use_casual_mask +from ..model_outputs import MoECausalLMOutputWithPast, MoEModelOutputWithPast +from ..model_utils import PretrainedModel, register_base_model +from ..qwen2.modeling import _expand_2d_mask, _make_causal_mask, is_casual_mask +from .configuration import GptOssConfig + + +class GptOssRMSNorm(nn.Layer): + def __init__(self, hidden_size, eps=1e-6): + """ + GptOssRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = self.create_parameter( + shape=[hidden_size], default_initializer=paddle.nn.initializer.Constant(1.0) + ) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(paddle.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * paddle.rsqrt(variance + self.variance_epsilon) + return (self.weight * hidden_states).to(input_dtype) # main diff with Llama + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class GptOssExperts(nn.Layer): + def __init__(self, config): + super().__init__() + self.intermediate_size = config.intermediate_size + self.num_experts = config.num_experts + self.hidden_size = config.hidden_size + self.expert_dim = self.intermediate_size + 
self.gate_up_proj = paddle.create_parameter(
+            shape=[self.num_experts, self.hidden_size, 2 * self.expert_dim],
+            dtype="bfloat16",
+            default_initializer=paddle.nn.initializer.Uniform(),
+        )
+        self.gate_up_proj_bias = paddle.create_parameter(
+            shape=[self.num_experts, 2 * self.expert_dim],
+            dtype="bfloat16",
+            default_initializer=paddle.nn.initializer.Uniform(),
+        )
+        self.down_proj = paddle.create_parameter(
+            shape=[self.num_experts, self.expert_dim, self.hidden_size],
+            dtype="bfloat16",
+            default_initializer=paddle.nn.initializer.Uniform(),
+        )
+        self.down_proj_bias = paddle.create_parameter(
+            shape=[self.num_experts, self.hidden_size],
+            dtype="bfloat16",
+            default_initializer=paddle.nn.initializer.Uniform(),
+        )
+        self.alpha = 1.702
+        self.limit = 7.0
+
+    def forward(self, hidden_states: paddle.Tensor, router_indices=None, routing_weights=None) -> paddle.Tensor:
+        """
+        During training it is more efficient to just loop over the experts and compute the output for each expert,
+        as otherwise the memory would explode.
+        For inference we can sacrifice some memory and compute the output for all experts at once by repeating the inputs.
+        Args:
+            hidden_states (paddle.Tensor): (batch_size, seq_len, hidden_size)
+            router_indices (paddle.Tensor): (batch_size * token_num, top_k)
+            routing_weights (paddle.Tensor): (batch_size * token_num, num_experts)
+        Returns:
+            paddle.Tensor
+        """
+        batch_size = hidden_states.shape[0]
+        hidden_states = hidden_states.reshape([-1, self.hidden_size])  # (num_tokens, hidden_size)
+        num_experts = routing_weights.shape[1]
+        if self.training:
+            next_states = paddle.zeros_like(hidden_states, dtype=hidden_states.dtype)
+            with paddle.no_grad():
+                expert_mask = F.one_hot(router_indices, num_classes=num_experts)
+                expert_mask = expert_mask.transpose(perm=[2, 1, 0])
+                # we sum over the top_k and the sequence length to get which experts
+                # are hit this time around
+                expert_hitted = paddle.nonzero(
+                    paddle.greater_than(expert_mask.sum(axis=(-1, -2)), paddle.to_tensor(0, dtype=expert_mask.dtype))
+                )
+            for expert_idx in expert_hitted[:]:
+                with paddle.no_grad():
+                    _, token_idx = paddle.where(expert_mask[expert_idx[0]])
+                current_state = hidden_states[token_idx]
+                gate_up = current_state @ self.gate_up_proj[expert_idx] + self.gate_up_proj_bias[expert_idx]
+                gate, up = gate_up[..., ::2], gate_up[..., 1::2]
+                gate = paddle.clip(gate, min=None, max=self.limit)
+                up = paddle.clip(up, min=-self.limit, max=self.limit)
+                glu = gate * F.sigmoid(gate * self.alpha)
+                gated_output = (up + 1) * glu
+                out = gated_output @ self.down_proj[expert_idx] + self.down_proj_bias[expert_idx]
+                weighted_output = out[0] * routing_weights[token_idx, expert_idx, None]
+                next_states = paddle.index_add(
+                    next_states,
+                    token_idx,
+                    0,
+                    weighted_output.astype(hidden_states.dtype),
+                )
+            next_states = next_states.reshape([batch_size, -1, self.hidden_size])
+        else:
+            hidden_states = paddle.tile(hidden_states, repeat_times=[num_experts, 1])
+            hidden_states = hidden_states.reshape((num_experts, -1, self.hidden_size))
+            gate_up = paddle.bmm(hidden_states, self.gate_up_proj) + self.gate_up_proj_bias[..., None, :]
+            gate, up = gate_up[..., ::2], gate_up[..., 1::2]
+            gate = paddle.clip(gate, min=None, max=self.limit)
+            up = paddle.clip(up, min=-self.limit, max=self.limit)
+            glu = gate * F.sigmoid(gate * self.alpha)
+            next_states = paddle.bmm(((up + 1) * glu), self.down_proj)
+            next_states = next_states + self.down_proj_bias[..., None, :]
+            next_states = next_states.reshape((num_experts, batch_size, -1,
self.hidden_size)) + next_states = ( + next_states * routing_weights.transpose([0, 1]).reshape((num_experts, batch_size, -1))[..., None] + ) + next_states = next_states.sum(axis=0) + return next_states + + +class GptOssTopKRouter(nn.Layer): + def __init__(self, config): + super().__init__() + self.top_k = config.num_experts_per_tok + self.num_experts = config.num_experts + self.hidden_dim = config.hidden_size + self.weight = paddle.create_parameter( + shape=[self.num_experts, self.hidden_dim], + dtype="bfloat16", + default_initializer=paddle.nn.initializer.Uniform(), + ) + self.bias = paddle.create_parameter( + shape=[self.num_experts], dtype="bfloat16", default_initializer=paddle.nn.initializer.Uniform() + ) + + def forward(self, hidden_states): + hidden_states = hidden_states.reshape([-1, self.hidden_dim]) + router_logits = F.linear(hidden_states, self.weight.t(), self.bias) # (seq_len, num_experts) + router_top_value, router_indices = paddle.topk(router_logits, self.top_k, axis=-1) # (seq_len, top_k) + router_top_value = F.softmax(router_top_value, axis=1, dtype=router_top_value.dtype) + router_scores = paddle.zeros_like(router_logits) + router_scores = paddle.put_along_axis(router_scores, router_indices, router_top_value, axis=1) + return router_scores, router_indices + + +class GptOssMLP(nn.Layer): + def __init__(self, config): + super().__init__() + self.router = GptOssTopKRouter(config) + self.experts = GptOssExperts(config) + + def forward(self, hidden_states): + router_scores, router_indices = self.router(hidden_states) # (num_experts, seq_len) + routed_out = self.experts(hidden_states, router_indices=router_indices, routing_weights=router_scores) + return routed_out, router_scores + + +def _compute_yarn_parameters(config, device: paddle.device, seq_len: Optional[int] = None) -> tuple[Tensor, float]: + """ + Computes the inverse frequencies with NTK scaling. Please refer to the + [original paper](https://huggingface.co/papers/2309.00071) + Args: + config: The model configuration. + device: The device to use for initialization of the inverse frequencies. + seq_len: The current sequence length. Unused for this type of RoPE. + Returns: + Tuple of (Tensor, float), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin. 
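+
+    Note: frequencies that complete more than `beta_fast` rotations over the original context window keep their
+    original values (extrapolation), those completing fewer than `beta_slow` rotations are divided by `factor`
+    (interpolation), and the band in between is blended linearly; the returned scaling factor is multiplied into
+    the cos/sin produced by the rotary embedding.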
+    """
+
+    base = config.rope_theta
+    partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
+    head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+    dim = int(head_dim * partial_rotary_factor)
+    factor = config.rope_scaling["factor"]
+    attention_factor = config.rope_scaling.get("attention_factor")
+    mscale = config.rope_scaling.get("mscale")
+    mscale_all_dim = config.rope_scaling.get("mscale_all_dim")
+
+    # Handle the case where original_max_position_embeddings is provided
+    if "original_max_position_embeddings" in config.rope_scaling:
+        original_max_position_embeddings = config.rope_scaling["original_max_position_embeddings"]
+        factor = config.max_position_embeddings / original_max_position_embeddings
+    else:
+        original_max_position_embeddings = config.max_position_embeddings
+
+    def get_mscale(scale, mscale=1):
+        if scale <= 1:
+            return 1.0
+        return 0.1 * mscale * math.log(scale) + 1.0
+
+    # Set the attention scaling factor
+    if attention_factor is None:
+        if mscale and mscale_all_dim:
+            attention_factor = float(get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dim))
+        else:
+            attention_factor = get_mscale(factor)
+
+    # Optional configuration parameters
+    beta_fast = config.rope_scaling.get("beta_fast") or 32
+    beta_slow = config.rope_scaling.get("beta_slow") or 1
+
+    # Helper functions for computing the inverse frequencies
+    def find_correction_dim(num_rotations, dim, base, max_position_embeddings):
+        """Inverse dimension formula to find the dimension based on the number of rotations"""
+        return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))
+
+    def find_correction_range(low_rot, high_rot, dim, base, max_position_embeddings, truncate):
+        """Find the dimension range bounds based on rotations"""
+        low = find_correction_dim(low_rot, dim, base, max_position_embeddings)
+        high = find_correction_dim(high_rot, dim, base, max_position_embeddings)
+        if truncate:
+            low = math.floor(low)
+            high = math.ceil(high)
+        return max(low, 0), min(high, dim - 1)
+
+    def linear_ramp_factor(min_val, max_val, dim):
+        if min_val == max_val:
+            max_val += 0.001  # Prevent singularity
+
+        # Create the linear function with arange
+        linear_func = (paddle.arange(dim, dtype=paddle.float32) - min_val) / (max_val - min_val)
+        ramp_func = paddle.clip(linear_func, 0, 1)
+        return ramp_func
+
+    # Compute the position frequencies
+    # Device and dtype are specified explicitly for Paddle
+    pos_freqs = base ** (paddle.arange(0, dim, 2, dtype=paddle.float32) / dim)
+    inv_freq_extrapolation = 1.0 / pos_freqs
+    inv_freq_interpolation = 1.0 / (factor * pos_freqs)
+
+    truncate = config.rope_scaling.get("truncate", True)
+    low, high = find_correction_range(beta_fast, beta_slow, dim, base, original_max_position_embeddings, truncate)
+
+    # Get the n-dimensional rotational scaling correction used for extrapolation
+    inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).to(device=device, dtype=paddle.float32)
+    inv_freq = (
+        inv_freq_interpolation * (1 - inv_freq_extrapolation_factor)
+        + inv_freq_extrapolation * inv_freq_extrapolation_factor
+    )
+    return inv_freq, attention_factor
+
+
+class GptOssRotaryEmbedding(nn.Layer):
+    def __init__(self, config: GptOssConfig, device=None):
+        super().__init__()
+        # BC: "rope_type" was originally "type"
+        if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
+            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+        else:
+            self.rope_type = "default"
+        self.max_seq_len_cached = config.max_position_embeddings
+        self.original_max_seq_len = config.max_position_embeddings
+
+        self.config = config
+        self.rope_init_fn = _compute_yarn_parameters
+
+        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+        # TODO: confirm whether inv_freq can change at runtime
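+        # The inverse frequencies are stored as a non-trainable parameter (stop_gradient is set right below),
+        # so they are carried in the layer's state dict but never updated by the optimizer.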
+ self.inv_freq = self.create_parameter( + shape=inv_freq.shape, + dtype=inv_freq.dtype, + default_initializer=paddle.nn.initializer.Assign(inv_freq), + ) + self.inv_freq.stop_gradient = True + self.original_inv_freq = self.inv_freq + + @paddle.no_grad() + def forward(self, x, position_ids): + inv_freq_expanded = ( + self.inv_freq.unsqueeze(0) + .unsqueeze(-1) + .cast(paddle.float32) + .expand([position_ids.shape[0], -1, 1]) + .to(x.place) + ) + position_ids_expanded = position_ids.unsqueeze(1).cast(paddle.float32) + + freqs = paddle.matmul(inv_freq_expanded, position_ids_expanded).transpose([0, 2, 1]) + + emb = freqs + cos = paddle.cos(emb) * self.attention_scaling + sin = paddle.sin(emb) * self.attention_scaling + return cos.cast(x.dtype), sin.cast(x.dtype) + + +def _apply_rotary_emb( + x: paddle.Tensor, + cos: paddle.Tensor, + sin: paddle.Tensor, +) -> paddle.Tensor: + first_half, second_half = paddle.chunk(x.transpose([0, 2, 1, 3]), 2, axis=-1) + first_ = first_half * cos - second_half * sin + second_ = second_half * cos + first_half * sin + return paddle.concat((first_, second_), axis=-1).transpose([0, 2, 1, 3]) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = _apply_rotary_emb(q, cos, sin) + k_embed = _apply_rotary_emb(k, cos, sin) + return q_embed, k_embed + + +def eager_attention_forward( + module: nn.Layer, + query: paddle.Tensor, + key: paddle.Tensor, + value: paddle.Tensor, + attention_mask: Optional[paddle.Tensor] = None, + dropout: float = 0.0, + scaling: Optional[float] = None, + **kwargs, +): + if hasattr(module, "num_key_value_groups"): + num_key_value_groups = module.num_key_value_groups + + key = repeat_kv(key, num_key_value_groups) + value = repeat_kv(value, num_key_value_groups) + + perm = [0, 2, 1, 3] # b l h d -> b h l d + query = paddle.transpose(x=query, perm=perm) + key = paddle.transpose(x=key, perm=perm) + value = paddle.transpose(x=value, perm=perm) + + attn_weights = paddle.matmul(query, key.transpose([0, 1, 3, 2])) * scaling + + if attention_mask is not None: + + causal_mask = attention_mask[:, :, :, : key.shape[-2]] + attn_weights = attn_weights + causal_mask + + sinks = module.sinks.reshape([1, -1, 1, 1]).expand([query.shape[0], -1, query.shape[-2], -1]) + + combined_logits = paddle.concat([attn_weights, sinks], axis=-1) + + probs = F.softmax(combined_logits, axis=-1, dtype=combined_logits.dtype) + scores = probs[..., :-1] # we drop the sink here + + attn_weights = nn.functional.dropout(scores, p=dropout, training=module.training) + attn_output = paddle.matmul(attn_weights, value) # b h l l @ b h l d -> b h l d + attn_output = attn_output.transpose([0, 2, 1, 3]) # b h l d -> b l h d + attn_output = paddle.reshape(x=attn_output, shape=[0, 0, attn_output.shape[2] * attn_output.shape[3]]) + + return attn_output, attn_weights + + +class GptOssAttention(nn.Layer): + """ + Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer + and "Generating Long Sequences with Sparse Transformers". + """ + + def __init__(self, config, layer_idx=0): + """Initialize the attention layer. + + Args: + config (GptOssConfig): Model configuration. + layer_idx (int, optional): Index in transformer stack. Defaults to 0. 
+ """ + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.num_key_value_heads = config.num_key_value_heads + assert config.num_attention_heads // config.num_key_value_heads + + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.num_attention_heads = config.num_attention_heads + self.head_dim = config.head_dim + self.sequence_parallel = config.sequence_parallel + self.attention_bias = config.attention_bias + + self.sequence_parallel = config.sequence_parallel + self.fuse_attention_qkv = config.fuse_attention_qkv + + self.scaling = self.head_dim**-0.5 + self.attn_implementation = config._attn_implementation + + self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None + + if config.tensor_parallel_degree > 1: + assert ( + self.num_heads % config.tensor_parallel_degree == 0 + ), f"num_heads: {self.num_heads}, tensor_parallel_degree: {config.tensor_parallel_degree}" + self.num_heads = self.num_heads // config.tensor_parallel_degree + assert ( + self.num_key_value_heads % config.tensor_parallel_degree == 0 + ), f"num_key_value_heads: {self.num_key_value_heads}, tensor_parallel_degree: {config.tensor_parallel_degree}" + self.num_key_value_heads = self.num_key_value_heads // config.tensor_parallel_degree + + kv_hidden_size = self.config.num_key_value_heads * self.head_dim + q_hidden_size = self.num_attention_heads * self.head_dim + + self.q_proj = GeneralLinear.create( + self.hidden_size, + q_hidden_size, + has_bias=self.attention_bias, + config=config, + tp_plan="colwise", + ) + self.k_proj = GeneralLinear.create( + self.hidden_size, + kv_hidden_size, + has_bias=self.attention_bias, + config=config, + tp_plan="colwise", + ) + self.v_proj = GeneralLinear.create( + self.hidden_size, + kv_hidden_size, + has_bias=self.attention_bias, + config=config, + tp_plan="colwise", + ) + self.o_proj = GeneralLinear.create( + q_hidden_size, + self.hidden_size, + has_bias=self.attention_bias, + config=config, + tp_plan="rowwise", + ) + + self.sinks = paddle.create_parameter( + shape=[self.num_heads], dtype="bfloat16", default_initializer=paddle.nn.initializer.Uniform() + ) + + def forward( + self, + hidden_states, + past_key_value: Optional[Tuple[paddle.Tensor]] = None, + attention_mask: Optional[paddle.Tensor] = None, + attn_mask_start_row_indices: Optional[paddle.Tensor] = None, + position_ids: Optional[Tuple[paddle.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + position_embedding: Optional[Tuple[paddle.Tensor, paddle.Tensor]] = None, + batch_size: Optional[int] = None, + ) -> Tuple[paddle.Tensor, Optional[paddle.Tensor], Optional[Tuple[paddle.Tensor]]]: + """Compute attention outputs. 
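+        A learned per-head "sink" logit is appended to the attention scores before the softmax and dropped
+        afterwards, so each head can shed probability mass without attending to any real token.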
+ + Args: + hidden_states (paddle.Tensor): Input tensor [bsz, seq_len, hidden_size] + past_key_value (Optional[Tuple[paddle.Tensor, paddle.Tensor]]): Cached key/value states + attention_mask (Optional[paddle.Tensor]): Attention mask tensor + attn_mask_start_row_indices (Optional[paddle.Tensor]): Variable length attention indices + position_ids (Optional[paddle.Tensor]): Position indices for RoPE + output_attentions (bool): Return attention weights if True + use_cache (bool): Cache key/value states if True + + Returns: + Tuple containing: + - attention_output: [bsz, seq_len, hidden_size] + - attention_weights: Optional attention probabilities + - updated_key_value_cache: Optional updated cache + """ + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + if self.sequence_parallel: + target_query_shape = [batch_size, -1, self.num_heads, self.head_dim] + target_key_value_shape = [batch_size, -1, self.num_key_value_heads, self.head_dim] + else: + target_query_shape = [0, 0, self.num_heads, self.head_dim] + target_key_value_shape = [0, 0, self.num_key_value_heads, self.head_dim] + query_states = query_states.reshape(target_query_shape) + key_states = key_states.reshape(target_key_value_shape) + value_states = value_states.reshape(target_key_value_shape) + + attention_interface = eager_attention_forward + cos, sin = position_embedding + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + if past_key_value is not None: + key_states = paddle.concat([past_key_value[0], key_states], axis=1) + value_states = paddle.concat([past_key_value[1], value_states], axis=1) + past_key_value = (key_states, value_states) if use_cache else None + + attn_output, attn_weights = attention_interface( + self, + query=query_states, + key=key_states, + value=value_states, + attention_mask=attention_mask, + attn_mask_start_row_indices=attn_mask_start_row_indices, + dropout=self.config.get("attention_dropout", 0.0) if self.training else 0.0, + scaling=self.scaling, + ) + + # if sequence_parallel is true, out shape are [q_len / n, bs, num_head * head_dim] + # else their shape are [bs, q_len, num_head * head_dim], n is mp parallelism. 
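+        # Collapse batch and sequence into a single token dimension so the output projection sees a 2-D input
+        # under sequence parallelism.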
+ if self.config.sequence_parallel: + attn_output = attn_output.reshape([-1, attn_output.shape[-1]]) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + return attn_output, attn_weights, past_key_value + + +class GptOssDecoderLayer(nn.Layer): + def __init__(self, config: GptOssConfig, layer_idx: int): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.self_attn = GptOssAttention(config, layer_idx) + self.mlp = GptOssMLP(config) + self.input_layernorm = GptOssRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = GptOssRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + if config.sequence_parallel: + self.post_attention_layernorm.enable_sequence_parallel() + if not hasattr(config, "disable_ffn_model_parallel"): + self.input_layernorm.enable_sequence_parallel() + self.attention_type = config.layer_types[layer_idx] + + def forward( + self, + hidden_states: paddle.Tensor, + position_ids: Optional[paddle.Tensor] = None, + attention_mask: Optional[paddle.Tensor] = None, + output_attentions: Optional[bool] = False, + output_router_logits: Optional[bool] = False, + past_key_value: Optional[Tuple[paddle.Tensor]] = None, + use_cache: Optional[bool] = False, + position_embedding: Optional[Tuple[paddle.Tensor, paddle.Tensor]] = None, + attn_mask_start_row_indices: Optional[paddle.Tensor] = None, + **kwargs, + ) -> Tuple[paddle.Tensor, Optional[Tuple[paddle.Tensor, paddle.Tensor]]]: + """ + Args: + hidden_states (`paddle.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`paddle.Tensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). 
+ past_key_value (`Tuple(paddle.Tensor)`, *optional*): cached past key and value projection states + """ + # [bs * seq_len, embed_dim] -> [seq_len * bs / n, embed_dim] (sequence_parallel) + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + past_key_value=past_key_value, + attention_mask=attention_mask, + attn_mask_start_row_indices=attn_mask_start_row_indices, + position_ids=position_ids, + output_attentions=output_attentions, + use_cache=use_cache, + position_embedding=position_embedding, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states, _ = self.mlp(hidden_states) + if isinstance(hidden_states, tuple): + hidden_states, router_logits = hidden_states + else: + router_logits = None + hidden_states = residual + hidden_states + outputs = (hidden_states,) + if output_attentions: + outputs += (self_attn_weights,) + if use_cache: + outputs += (present_key_value,) + if output_router_logits: + outputs += (router_logits,) + if type(outputs) is tuple and len(outputs) == 1: + outputs = outputs[0] + + return outputs + + +class GptOssPreTrainedModel(PretrainedModel): + config: GptOssConfig + config_class = GptOssConfig + base_model_prefix = "model" + keys_to_ignore_on_load_unexpected = [r"self_attn.rotary_emb.inv_freq"] + transpose_weight_keys = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj", "lm_head"] + + @classmethod + def _get_name_mappings(cls, config: GptOssConfig) -> list[StateDictNameMapping]: + mappings: list[StateDictNameMapping] = [] + model_mappings = [ + ["embed_tokens.weight"], + ["norm.weight"], + ] + for layer_index in range(config.num_hidden_layers): + layer_mappings = [ + [f"layers.{layer_index}.self_attn.q_proj.weight", None, "transpose"], + [f"layers.{layer_index}.self_attn.k_proj.weight", None, "transpose"], + [f"layers.{layer_index}.self_attn.v_proj.weight", None, "transpose"], + [f"layers.{layer_index}.self_attn.o_proj.weight", None, "transpose"], + [f"layers.{layer_index}.self_attn.q_proj.bias", None], + [f"layers.{layer_index}.self_attn.k_proj.bias", None], + [f"layers.{layer_index}.self_attn.v_proj.bias", None], + [f"layers.{layer_index}.self_attn.o_proj.bias", None], + [f"layers.{layer_index}.self_attn.rotary_emb.inv_freq"], + [f"layers.{layer_index}.input_layernorm.weight"], + [f"layers.{layer_index}.post_attention_layernorm.weight"], + [f"layers.{layer_index}.self_attn.q_norm.weight"], + [f"layers.{layer_index}.self_attn.k_norm.weight"], + [f"layers.{layer_index}.self_attn.sinks"], + ] + model_mappings.extend(layer_mappings) + + for expert_idx in range(config.num_experts): + expert_mappings = [ + [f"layers.{layer_index}.mlp.experts.{expert_idx}.gate_proj.weight", None, "transpose"], + [f"layers.{layer_index}.mlp.experts.{expert_idx}.down_proj.weight", None, "transpose"], + [f"layers.{layer_index}.mlp.experts.{expert_idx}.up_proj.weight", None, "transpose"], + ] + model_mappings.extend(expert_mappings) + model_mappings.append([f"layers.{layer_index}.mlp.gate.weight", None, "transpose"]) + + init_name_mappings(mappings=model_mappings) + # base-model prefix "GPTOssModel" + if "GptOssModel" not in config.architectures: + for mapping in model_mappings: + mapping[0] = "model." + mapping[0] + mapping[1] = "model." 
+ mapping[1] + model_mappings.append(["lm_head.weight", "lm_head.weight", "transpose"]) + + mappings = [StateDictNameMapping(*mapping, index=index) for index, mapping in enumerate(model_mappings)] + return mappings + + @classmethod + def _get_tensor_parallel_mappings(cls, config: GptOssConfig, is_split=True): + from ..conversion_utils import split_or_merge_func + + fn = split_or_merge_func( + is_split=is_split, + tensor_parallel_degree=config.tensor_parallel_degree, + tensor_parallel_rank=config.tensor_parallel_rank, + num_attention_heads=config.num_attention_heads, + ) + + def get_tensor_parallel_split_mappings(num_layers, num_experts): + final_actions = {} + + base_actions = { + "lm_head.weight": partial(fn, is_column=True), + # Row Linear + "embed_tokens.weight": partial(fn, is_column=False), + "layers.0.self_attn.o_proj.weight": partial(fn, is_column=False), + } + + if not config.vocab_size % config.tensor_parallel_degree == 0: + base_actions.pop("lm_head.weight") + base_actions.pop("embed_tokens.weight") + base_actions["layers.0.self_attn.sinks"] = partial(fn, is_column=False) + # Column Linear + base_actions["layers.0.self_attn.q_proj.weight"] = partial(fn, is_column=True) + base_actions["layers.0.self_attn.q_proj.bias"] = partial(fn, is_column=True) + + # if we have enough num_key_value_heads to split, then split it. + if config.num_key_value_heads % config.tensor_parallel_degree == 0: + base_actions["layers.0.self_attn.k_proj.weight"] = partial(fn, is_column=True) + base_actions["layers.0.self_attn.v_proj.weight"] = partial(fn, is_column=True) + base_actions["layers.0.self_attn.k_proj.bias"] = partial(fn, is_column=True) + base_actions["layers.0.self_attn.v_proj.bias"] = partial(fn, is_column=True) + + for key, action in base_actions.items(): + if "layers.0." in key: + for i in range(num_layers): + final_actions[key.replace("layers.0.", f"layers.{i}.")] = action + final_actions[key] = action + + return final_actions + + mappings = get_tensor_parallel_split_mappings(config.num_hidden_layers, config.num_experts) + + return mappings + + @classmethod + def _get_fuse_or_split_param_mappings(cls, config: GptOssConfig, is_fuse=False): + # return parameter fuse utils + from ..conversion_utils import split_or_fuse_func + + fn = split_or_fuse_func(is_fuse=is_fuse) + + # last key is fused key, other keys are to be fused. 
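+        # e.g. layers.0.self_attn.{q,k,v}_proj.weight map to layers.0.self_attn.qkv_proj.weight when fusing,
+        # and are split back out of it when is_fuse=False.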
+ fuse_qkv_keys = [ + ( + "layers.0.self_attn.q_proj.weight", + "layers.0.self_attn.k_proj.weight", + "layers.0.self_attn.v_proj.weight", + "layers.0.self_attn.qkv_proj.weight", + ), + ] + + fuse_gate_up_keys = ( + "layers.0.mlp.gate_proj.weight", + "layers.0.mlp.up_proj.weight", + "layers.0.mlp.gate_up_fused_proj.weight", + ) + num_heads = config.num_attention_heads + num_key_value_heads = getattr(config, "num_key_value_heads", num_heads) + fuse_attention_qkv = getattr(config, "fuse_attention_qkv", False) + fuse_attention_ffn = getattr(config, "fuse_attention_ffn", False) + + final_actions = {} + if is_fuse: + if fuse_attention_qkv: + for i in range(config.num_hidden_layers): + for fuse_keys in fuse_qkv_keys: + keys = tuple([key.replace("layers.0.", f"layers.{i}.") for key in fuse_keys]) + final_actions[keys] = partial( + fn, is_qkv=True, num_heads=num_heads, num_key_value_heads=num_key_value_heads + ) + if fuse_attention_ffn: + for i in range(config.num_hidden_layers): + keys = tuple([key.replace("layers.0.", f"layers.{i}.") for key in fuse_gate_up_keys]) + final_actions[keys] = fn + else: + if not fuse_attention_qkv: + for i in range(config.num_hidden_layers): + for fuse_keys in fuse_qkv_keys: + keys = tuple([key.replace("layers.0.", f"layers.{i}.") for key in fuse_keys]) + final_actions[keys] = partial( + fn, split_nums=3, is_qkv=True, num_heads=num_heads, num_key_value_heads=num_key_value_heads + ) + if not fuse_attention_ffn: + for i in range(config.num_hidden_layers): + keys = tuple([key.replace("layers.0.", f"layers.{i}.") for key in fuse_gate_up_keys]) + final_actions[keys] = partial(fn, split_nums=2) + return final_actions + + +def _make_sliding_window_mask(input_shape, past_key_values_length=0, window_size=5): + """ + Generate a sliding window mask that restricts each position to only attend to historical positions within the window. + Format: [bsz, 1, tgt_seq_len, src_seq_len], where True indicates allowed attention and False indicates masking. 
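+    For example, with seq_length=4, past_key_values_length=0 and window_size=2, row i may attend to positions
+    {max(0, i - 1), ..., i}, i.e. itself and at most one previous token.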
+    """
+    batch_size, seq_length = input_shape
+    # Total sequence length = historical sequence length + current sequence length (for generating complete mask)
+    total_length = past_key_values_length + seq_length
+
+    # Initialize mask with all False values
+    mask = paddle.zeros((seq_length, total_length), dtype=paddle.bool)
+
+    for i in range(seq_length):
+        # Absolute position of current location in the total sequence (including historical sequence)
+        current_pos = past_key_values_length + i
+        # Window start position: max(0, current position - window size + 1)
+        start = max(0, current_pos - window_size + 1)
+        # Window end position: current position (causal mask restriction, cannot exceed self)
+        end = current_pos + 1  # slicing is left-closed / right-open, hence the +1
+        # Mark window range as True (allow attention)
+        mask[i, start:end] = True
+
+    # Expand dimensions to [bsz, 1, tgt_seq_len, src_seq_len]
+    mask = mask.unsqueeze(0).unsqueeze(0)
+    # Copy to each sample in batch_size
+    mask = paddle.tile(mask, repeat_times=[batch_size, 1, 1, 1])
+    return mask
+
+
+def _prepare_decoder_attention_mask(
+    attention_mask, input_shape, past_key_values_length, dtype, sliding_window_size=None  # newly added: sliding window size; None disables the sliding window
+):
+    # Step 1: Process input mask to generate basic expanded mask
+    if attention_mask is not None:
+        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+        if len(attention_mask.shape) == 2:
+            expanded_attn_mask = _expand_2d_mask(attention_mask, dtype, tgt_length=input_shape[-1])
+            # When not doing single-step generation, combine the causal mask and the sliding window mask
+            if input_shape[-1] > 1:
+                # Generate basic causal mask (prevent future information leakage)
+                causal_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length)
+                # Generate sliding window mask (limit historical attention range)
+                if sliding_window_size is not None and sliding_window_size > 0:
+                    window_mask = _make_sliding_window_mask(
+                        input_shape, past_key_values_length=past_key_values_length, window_size=sliding_window_size
+                    )
+                    # Take intersection of sliding window mask and causal mask (satisfy both restrictions)
+                    combined_attention_mask = causal_mask & window_mask
+                else:
+                    combined_attention_mask = causal_mask  # Use causal mask directly when sliding window is disabled
+
+                # Combine with user-provided mask (e.g., padding mask)
+                if get_env_device() in ["npu", "mlu", "intel_hpu"]:
+                    expanded_attn_mask = expanded_attn_mask.astype("bool") & combined_attention_mask.astype("bool")
+                else:
+                    expanded_attn_mask = expanded_attn_mask & combined_attention_mask
+        # [bsz, seq_len, seq_len] -> [bsz, 1, seq_len, seq_len]
+        elif len(attention_mask.shape) == 3:
+            expanded_attn_mask = attention_mask.unsqueeze(1).astype("bool")
+        # 4D mask is used directly
+        else:
+            expanded_attn_mask = attention_mask
+    else:
+        # When no input mask is given, generate causal mask + sliding window mask (if enabled)
+        causal_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length)
+        if sliding_window_size is not None and sliding_window_size > 0:
+            window_mask = _make_sliding_window_mask(
+                input_shape, past_key_values_length=past_key_values_length, window_size=sliding_window_size
+            )
+            expanded_attn_mask = causal_mask & window_mask
+        else:
+            expanded_attn_mask = causal_mask  # Use causal mask directly when sliding window is disabled
+
+    # Step 2: Convert boolean mask to numerical mask (adapt to different devices)
+    if get_env_device() in ["npu", "mlu", "intel_hpu"]:
+        x = paddle.to_tensor(0.0, dtype="float32")
+        y = paddle.to_tensor(paddle.finfo(dtype).min, dtype="float32")
+        expanded_attn_mask = paddle.where(expanded_attn_mask.cast("bool"), x, y).astype(dtype)
+    elif get_env_device() == "xpu":
+        x = paddle.to_tensor(0.0, dtype="float32")
+        y = paddle.to_tensor(-1.7005809656952787e38, dtype="float32")
+        expanded_attn_mask = paddle.where(expanded_attn_mask.cast("bool"), x, y)
+    elif get_env_device() == "gcu":
+        min_val = paddle.finfo(dtype).min
+        x = paddle.to_tensor(0.0, dtype=dtype)
+        y = paddle.to_tensor(min_val, dtype=dtype)
+        expanded_attn_mask = paddle.where(expanded_attn_mask.cast("bool"), x, y).astype(dtype)
+    else:
+        expanded_attn_mask = paddle.where(expanded_attn_mask.cast("bool"), 0.0, paddle.finfo(dtype).min)
+        expanded_attn_mask = expanded_attn_mask.astype(dtype)
+    return expanded_attn_mask
+
+
+@register_base_model
+class GptOssModel(GptOssPreTrainedModel):
+    """
+    Args:
+        config: GptOssConfig
+    """
+
+    _no_split_modules = ["GptOssDecoderLayer"]
+
+    def __init__(self, config: GptOssConfig):
+        super().__init__(config)
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+        self.hidden_size = config.hidden_size
+        self.sequence_parallel = config.sequence_parallel
+        self.recompute_granularity = config.recompute_granularity
+        self.no_recompute_layers = config.no_recompute_layers if config.no_recompute_layers is not None else []
+        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+
+        self.embed_tokens = GeneralEmbedding.create(
+            config=config, num_embeddings=config.vocab_size, embedding_dim=config.hidden_size
+        )
+
+        self.layers = nn.LayerList(
+            [
+                GptOssDecoderLayer(
+                    config=config,
+                    layer_idx=layer_idx,
+                )
+                for layer_idx in range(config.num_hidden_layers)
+            ]
+        )
+        self.norm = GptOssRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.rotary_emb = GptOssRotaryEmbedding(config=config)
+        if config.sequence_parallel:
+            self.norm.enable_sequence_parallel()
+
+    @paddle.jit.not_to_static
+    def recompute_training_full(
+        self,
+        layer_module: nn.Layer,
+        hidden_states: Tensor,
+        position_ids: Optional[Tensor],
+        attention_mask: Tensor,
+        output_attentions: bool,
+        output_router_logits: bool,
+        past_key_value: Tensor,
+        use_cache: bool,
+        position_embedding: Optional[Tuple[paddle.Tensor, paddle.Tensor]] = None,
+        attn_mask_start_row_indices=None,
+    ):
+        def create_custom_forward(module):
+            def custom_forward(*inputs):
+                return module(*inputs)
+
+            return custom_forward
+
+        hidden_states = recompute(
+            create_custom_forward(layer_module),
+            hidden_states,
+            position_ids,
+            attention_mask,
+            output_attentions,
+            # output_router_logits,
+            past_key_value,
+            use_cache,
+            position_embedding,
+            attn_mask_start_row_indices,
+            # use_reentrant=self.config.recompute_use_reentrant,
+        )
+
+        return hidden_states
+
+    def forward(
+        self,
+        input_ids: paddle.Tensor = None,
+        position_ids: Optional[paddle.Tensor] = None,
+        attention_mask: Optional[paddle.Tensor] = None,
+        inputs_embeds: Optional[paddle.Tensor] = None,
+        use_cache: Optional[bool] = None,
+        past_key_values: Optional[List[paddle.Tensor]] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        output_router_logits: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        attn_mask_start_row_indices=None,
+        **kwargs,
+    ) -> Union[Tuple, MoEModelOutputWithPast]:
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+
+        output_router_logits = (
+            output_router_logits
if output_router_logits is not None else self.config.output_router_logits + ) + + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + seq_length_with_past = seq_length + cache_length = 0 + if past_key_values is None: + past_key_values = tuple([None] * len(self.layers)) + else: + cache_length = past_key_values[0][0].shape[1] + seq_length_with_past += cache_length + + if inputs_embeds is None: + # [bs, seq_len, dim] + inputs_embeds = self.embed_tokens(input_ids) + + if self.sequence_parallel: + # [bs, seq_len, num_head * head_dim] -> [bs * seq_len, num_head * head_dim] + bs, seq_len, hidden_size = inputs_embeds.shape + inputs_embeds = paddle.reshape_(inputs_embeds, [bs * seq_len, hidden_size]) + # [seq_len * bs / n, num_head * head_dim] (n is mp parallelism) + inputs_embeds = ScatterOp.apply(inputs_embeds) + + # embed positions + if attn_mask_start_row_indices is not None or get_use_casual_mask(): + attention_mask = None + else: + # [bs, seq_len] + attention_mask = ( + paddle.ones((batch_size, seq_length_with_past), dtype=paddle.bool) + if attention_mask is None + else attention_mask + ) + causal_mask_mapping = {} + + # full_attention + causal_mask = _prepare_decoder_attention_mask( + attention_mask=attention_mask, + input_shape=(batch_size, seq_length), + past_key_values_length=cache_length, + dtype=inputs_embeds.dtype, + ) # [bs, 1, seq_len, seq_len] + if self.config.use_flash_attention: + causal_mask = None if is_casual_mask(causal_mask) else causal_mask + causal_mask_mapping["full_attention"] = causal_mask + + # sliding_attention + causal_mask = _prepare_decoder_attention_mask( + attention_mask=attention_mask, + input_shape=(batch_size, seq_length), + past_key_values_length=cache_length, + dtype=inputs_embeds.dtype, + sliding_window_size=self.config.sliding_window, + ) + if self.config.use_flash_attention: + causal_mask = None if is_casual_mask(causal_mask) else causal_mask + causal_mask_mapping["sliding_attention"] = causal_mask + + if position_ids is None: + position_ids = paddle.arange(seq_length, dtype="int64").expand((batch_size, seq_length)) + + hidden_states = inputs_embeds + position_embedding = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_router_logits = () if output_router_logits else None + next_decoder_cache = () if use_cache else None + + for idx, (decoder_layer) in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + past_key_value = past_key_values[idx] if past_key_values is not None else None + + has_gradient = not hidden_states.stop_gradient + if self.config.recompute and self.config.recompute_granularity == "full" and has_gradient: + layer_outputs = self.recompute_training_full( + 
layer_module=decoder_layer, + hidden_states=hidden_states, + attention_mask=causal_mask_mapping[decoder_layer.attention_type], + attn_mask_start_row_indices=attn_mask_start_row_indices, + position_ids=position_ids, + output_attentions=output_attentions, + output_router_logits=output_router_logits, + past_key_value=past_key_value, + use_cache=use_cache, + position_embedding=position_embedding, + ) + else: + layer_outputs = decoder_layer( + hidden_states=hidden_states, + attention_mask=causal_mask_mapping[decoder_layer.attention_type], + attn_mask_start_row_indices=attn_mask_start_row_indices, + position_ids=position_ids, + output_attentions=output_attentions, + output_router_logits=output_router_logits, + past_key_value=past_key_value, + use_cache=use_cache, + position_embedding=position_embedding, + ) + + # # NOTE: clear outdate cache after it has been used for memory saving + # past_key_value = past_key_values[idx] = None + if isinstance(layer_outputs, (tuple, list)): + hidden_states = layer_outputs[0] + else: + hidden_states = layer_outputs + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_router_logits: + all_router_logits += (layer_outputs[-1],) + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits] + if v is not None + ) + + return MoEModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + router_logits=all_router_logits, + ) + + +def load_balancing_loss_func( + gate_logits: Union[paddle.Tensor, tuple[paddle.Tensor], None], + num_experts: Optional[int] = None, + top_k=2, + attention_mask: Optional[paddle.Tensor] = None, +) -> Union[paddle.Tensor, int]: + r""" + Args: + gate_logits: + Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of + shape [batch_size X sequence_length, num_experts]. + num_experts: + Number of experts + top_k: + The number of experts to route per-token, can be also interpreted as the `top-k` routing + parameter. + attention_mask (`paddle.Tensor`, *optional*): + The attention_mask used in forward function + shape [batch_size X sequence_length] if not None. + Returns: + The auxiliary loss. 
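+        This is the Switch-Transformer-style balancing loss: num_experts * sum_e(f_e * P_e), where f_e is the
+        fraction of tokens that selected expert e among their top-k choices and P_e is the mean router probability
+        assigned to expert e.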
+ """ + if gate_logits is None or not isinstance(gate_logits, tuple): + return 0 + + if isinstance(gate_logits, tuple): + compute_device = gate_logits[0].device + concatenated_gate_logits = paddle.concat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0) + + routing_weights = F.softmax(concatenated_gate_logits, dim=-1) + + _, selected_experts = paddle.topk(routing_weights, top_k, dim=-1) + + expert_mask = F.one_hot(selected_experts, num_experts) + + if attention_mask is None: + # Compute the percentage of tokens routed to each experts + tokens_per_expert = paddle.mean(expert_mask.float(), dim=0) + + # Compute the average probability of routing to these experts + router_prob_per_expert = paddle.mean(routing_weights, dim=0) + else: + batch_size, sequence_length = attention_mask.shape + num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length) + + # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask + expert_attention_mask = ( + attention_mask[None, :, :, None, None] + .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts)) + .reshape(-1, top_k, num_experts) + .to(compute_device) + ) + + # Compute the percentage of tokens routed to each experts + tokens_per_expert = paddle.sum(expert_mask.float() * expert_attention_mask, dim=0) / paddle.sum( + expert_attention_mask, dim=0 + ) + + # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert + router_per_expert_attention_mask = ( + attention_mask[None, :, :, None] + .expand((num_hidden_layers, batch_size, sequence_length, num_experts)) + .reshape(-1, num_experts) + .to(compute_device) + ) + + # Compute the average probability of routing to these experts + router_prob_per_expert = paddle.sum(routing_weights * router_per_expert_attention_mask, dim=0) / paddle.sum( + router_per_expert_attention_mask, dim=0 + ) + + overall_loss = paddle.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) + return overall_loss * num_experts + + +class GptOssForCausalLM(GptOssPreTrainedModel): + enable_to_static_method = True + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config: GptOssConfig): + super().__init__(config) + self.config = config + self.model = GptOssModel(config) + self.lm_head = GeneralLMHead(config) + self.criterion = CriterionLayer(config) + self.router_aux_loss_coef = config.router_aux_loss_coef + self.num_experts = config.num_experts + self.num_experts_per_tok = config.num_experts_per_tok + # Initialize weights and apply final processing + if config.sliding_window: + self.config.sliding_window = False + logger.warning("We do not support sliding window attention for now.") + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def prepare_inputs_for_generation( + self, + input_ids, + use_cache=False, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + output_router_logits=False, + **kwargs, + ): + batch_size, seq_length = input_ids.shape + position_ids = kwargs.get("position_ids", paddle.arange(seq_length).expand((batch_size, seq_length))) + if past_key_values: + input_ids = input_ids[:, -1].unsqueeze(axis=-1) + position_ids = 
position_ids[:, -1].unsqueeze(-1) + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + "output_router_logits": output_router_logits, + } + ) + return model_inputs + + def _get_model_inputs_spec(self, dtype: str): + return { + "input_ids": paddle.static.InputSpec(shape=[None, None], dtype="int64"), + "attention_mask": paddle.static.InputSpec(shape=[None, None], dtype="int64"), + "position_ids": paddle.static.InputSpec(shape=[None, None], dtype="int64"), + } + + @staticmethod + def update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False): + # update cache + if isinstance(outputs, tuple) and len(outputs) > 1 and not isinstance(outputs[1], paddle.Tensor): + model_kwargs["past_key_values"] = outputs[1] + if isinstance(outputs, MoECausalLMOutputWithPast) and "past_key_values" in outputs: + model_kwargs["past_key_values"] = outputs.past_key_values + # update position_ids + if "position_ids" in model_kwargs and model_kwargs["position_ids"] is not None: + position_ids = model_kwargs["position_ids"] + model_kwargs["position_ids"] = paddle.concat([position_ids, position_ids[..., -1:] + 1], axis=-1) + if not is_encoder_decoder and "attention_mask" in model_kwargs: + # TODO: support attention mask for other models + attention_mask = model_kwargs["attention_mask"] + if len(attention_mask.shape) == 2: + model_kwargs["attention_mask"] = paddle.concat( + [attention_mask, paddle.ones([attention_mask.shape[0], 1], dtype=attention_mask.dtype)], + axis=-1, + ) + elif len(attention_mask.shape) == 4: + model_kwargs["attention_mask"] = paddle.concat( + [attention_mask, paddle.ones([*attention_mask.shape[:3], 1], dtype=attention_mask.dtype)], + axis=-1, + )[:, :, -1:, :] + return model_kwargs + + def forward( + self, + input_ids: paddle.Tensor = None, + position_ids: Optional[paddle.Tensor] = None, + attention_mask: Optional[paddle.Tensor] = None, + inputs_embeds: Optional[paddle.Tensor] = None, + labels: Optional[paddle.Tensor] = None, + use_cache: Optional[bool] = None, + past_key_values: Optional[List[paddle.Tensor]] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + attn_mask_startend_row_indices=None, + logits_to_keep: Union[int, paddle.Tensor] = 0, + ): + return_dict = True + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if attn_mask_startend_row_indices is not None and attention_mask is not None: + logger.warning( + "You have provided both attn_mask_startend_row_indices and attention_mask. " + "The attn_mask_startend_row_indices will be used." 
+ ) + attention_mask = None + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, # [bs, seq_len] + position_ids=position_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + past_key_values=past_key_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + output_router_logits=output_router_logits, + return_dict=return_dict, + attn_mask_startend_row_indices=attn_mask_startend_row_indices, + ) + hidden_states = outputs[0] # [bs, seq_len, dim] + # if labels is None,means we need full output, instead of tensor_parallel_output + # tensor_parallel_output is together with ParallelCrossEntropy + tensor_parallel_output = self.config.tensor_parallel_output and self.config.tensor_parallel_degree > 1 + if labels is not None and self.config.use_fused_linear_cross_entropy: + from paddlenlp_kernel.triton.cut_cross_entropy import linear_cross_entropy + + assert ( + self.config.tensor_parallel_degree <= 1 + ), "The argument `use_fused_linear_cross_entropy` is imcompatiable with tensor parallel " + # todo :hidden_states[:, slice_indices, :] + masked_lm_loss = linear_cross_entropy(hidden_states, self.lm_head.weight, targets=labels) + binary_sequence = paddle.where( + masked_lm_loss > 0, paddle.ones_like(masked_lm_loss), paddle.zeros_like(masked_lm_loss) + ) + count = paddle.sum(binary_sequence) + if count == 0: + loss = paddle.sum(masked_lm_loss * binary_sequence) + else: + loss = paddle.sum(masked_lm_loss * binary_sequence) / count + logits = None + else: + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :], tensor_parallel_output=tensor_parallel_output) + loss = None + if labels is not None: + loss, _ = self.criterion(logits, labels) + aux_loss = None + if output_router_logits: + aux_loss = load_balancing_loss_func( + outputs.router_logits if return_dict else outputs[-1], + self.num_experts, + self.num_experts_per_tok, + attention_mask, + ) + if labels is not None: + loss += self.router_aux_loss_coef * aux_loss + if not return_dict: + output = (logits,) + outputs[1:] + if output_router_logits: + output = (aux_loss,) + output + return (loss,) + output if loss is not None else output + + return MoECausalLMOutputWithPast( + loss=loss, + aux_loss=aux_loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + router_logits=outputs.router_logits, + ) + + +__all__ = ["GptOssForCausalLM", "GptOssModel", "GptOssPreTrainedModel"]
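Usage sketch (not part of the patch): the snippet below illustrates how the newly registered GptOss classes are expected to be driven through paddleformers' existing Auto APIs. The checkpoint id is a placeholder, and the exact from_pretrained/generate keyword arguments may vary across paddleformers versions, so treat this as an illustration rather than a tested example.

    from paddleformers.transformers import AutoModelForCausalLM, AutoTokenizer

    # Placeholder checkpoint id; substitute a real converted GPT-OSS checkpoint.
    ckpt = "your-org/gpt-oss-20b-paddle"
    tokenizer = AutoTokenizer.from_pretrained(ckpt)
    model = AutoModelForCausalLM.from_pretrained(ckpt, dtype="bfloat16")

    inputs = tokenizer("The capital of France is", return_tensors="pd")
    # generate() typically returns a (token_ids, scores) tuple in this codebase.
    output = model.generate(**inputs, max_new_tokens=16)
    print(tokenizer.batch_decode(output[0], skip_special_tokens=True))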