diff --git a/mindone/transformers/__init__.py b/mindone/transformers/__init__.py index a70bcee195..335da8ba1c 100644 --- a/mindone/transformers/__init__.py +++ b/mindone/transformers/__init__.py @@ -1087,7 +1087,10 @@ ) from .models.phi4_multimodal import ( Phi4MultimodalFeatureExtractor, + Phi4MultimodalForCausalLM, Phi4MultimodalImageProcessorFast, + Phi4MultimodalModel, + Phi4MultimodalPreTrainedModel, Phi4MultimodalProcessor, ) from .models.phimoe import PhimoeForCausalLM, PhimoeForSequenceClassification, PhimoeModel, PhimoePreTrainedModel diff --git a/mindone/transformers/models/auto/configuration_auto.py b/mindone/transformers/models/auto/configuration_auto.py index 80f81d9d14..e22d7b9bfb 100644 --- a/mindone/transformers/models/auto/configuration_auto.py +++ b/mindone/transformers/models/auto/configuration_auto.py @@ -201,6 +201,7 @@ ("perceiver", "PerceiverConfig"), ("phi", "PhiConfig"), ("phi3", "Phi3Config"), + ("phi4_multimodal", "Phi4MultimodalConfig"), ("phimoe", "PhimoeConfig"), ("pix2struct", "Pix2StructConfig"), ("pixtral", "PixtralVisionConfig"), @@ -473,6 +474,7 @@ ("persimmon", "Persimmon"), ("phi", "Phi"), ("phi3", "Phi3"), + ("phi4_multimodal", "Phi4Multimodal"), ("phimoe", "Phimoe"), ("pegasus", "Pegasus"), ("pegasus_x", "PEGASUS-X"), diff --git a/mindone/transformers/models/auto/modeling_auto.py b/mindone/transformers/models/auto/modeling_auto.py index 57a3e36cb9..7d15c3a46d 100644 --- a/mindone/transformers/models/auto/modeling_auto.py +++ b/mindone/transformers/models/auto/modeling_auto.py @@ -185,6 +185,7 @@ ("persimmon", "PersimmonModel"), ("phi", "PhiModel"), ("phi3", "Phi3Model"), + ("phi4_multimodal", "Phi4MultimodalModel"), ("pixtral", "PixtralVisionModel"), ("plbart", "PLBartModel"), ("poolformer", "PoolFormerModel"), @@ -480,6 +481,7 @@ ("persimmon", "PersimmonForCausalLM"), ("phi", "PhiForCausalLM"), ("phi3", "Phi3ForCausalLM"), + ("phi4_multimodal", "Phi4MultimodalForCausalLM"), ("pegasus", "PegasusForCausalLM"), ("plbart", "PLBartForCausalLM"), ("prophetnet", "ProphetNetForCausalLM"), diff --git a/mindone/transformers/models/phi4_multimodal/__init__.py b/mindone/transformers/models/phi4_multimodal/__init__.py index 3b4d98f65f..2e8470ef9c 100644 --- a/mindone/transformers/models/phi4_multimodal/__init__.py +++ b/mindone/transformers/models/phi4_multimodal/__init__.py @@ -15,6 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from .modeling_phi4_multimodal import * from .processing_phi4_multimodal import Phi4MMAudioFeatureExtractor as Phi4MultimodalFeatureExtractor from .processing_phi4_multimodal import Phi4MMImageProcessor as Phi4MultimodalImageProcessorFast from .processing_phi4_multimodal import Phi4MMProcessor as Phi4MultimodalProcessor diff --git a/mindone/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py b/mindone/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py index 0e8a92ee0e..c82e156f64 100644 --- a/mindone/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +++ b/mindone/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py @@ -6,6 +6,9 @@ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # Copyright 2025 Microsoft and the HuggingFace Inc. team. All rights reserved. # +# This code is adapted from https://github.com/huggingface/transformers +# with modifications to run transformers on mindspore. +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at @@ -19,56 +22,638 @@ # limitations under the License. import math -from typing import Callable, Optional +import warnings +from typing import Callable, Optional, Union import numpy as np from transformers.models.phi4_multimodal.configuration_phi4_multimodal import ( Phi4MultimodalAudioConfig, Phi4MultimodalConfig, + Phi4MultimodalVisionConfig, ) -from transformers.utils import auto_docstring -import mindspore as ms -import mindspore.mint.nn.functional as F -from mindspore import Parameter, mint, nn +import mindspore from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache +from ...generation import GenerationMixin +from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPast, + BaseModelOutputWithPooling, + CausalLMOutputWithPast, +) +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils.generic import TransformersKwargs +from ...utils import can_return_tuple +from ...utils.generic import TransformersKwargs, check_model_inputs + + +class Phi4MultimodalVisionMLP(mindspore.nn.Cell): + def __init__(self, config): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = mindspore.mint.nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = mindspore.mint.nn.Linear(config.intermediate_size, config.hidden_size) + + def construct(self, hidden_states: mindspore.Tensor) -> mindspore.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states def simple_eager_attention_forward( - module: nn.Cell, - query_states: ms.Tensor, - key_states: ms.Tensor, - value_states: ms.Tensor, - attention_mask: Optional[ms.Tensor], + module: mindspore.nn.Cell, + query_states: mindspore.Tensor, + key_states: mindspore.Tensor, + value_states: mindspore.Tensor, + attention_mask: Optional[mindspore.Tensor], scaling: float, dropout: float = 0.0, **kwargs: Unpack[TransformersKwargs], ): - attn_weights = mint.matmul(query_states, key_states.transpose(2, 3)) * scaling + attn_weights = mindspore.mint.matmul(query_states, key_states.transpose(2, 3)) * scaling if attention_mask is not None: causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask - attn_weights = mint.nn.functional.softmax(attn_weights, dim=-1, dtype=ms.float32).to(query_states.dtype) - attn_weights = mint.nn.functional.dropout(attn_weights, p=dropout, training=module.training) - attn_output = mint.matmul(attn_weights, value_states) + attn_weights = mindspore.mint.nn.functional.softmax(attn_weights, dim=-1, dtype=mindspore.float32).to( + query_states.dtype + ) + attn_weights = mindspore.mint.nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = mindspore.mint.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2).contiguous() return attn_output, attn_weights -class Phi4MultimodalAudioMLP(nn.Cell): +class Phi4MultimodalVisionAttention(mindspore.nn.Cell): + def __init__(self, config: 
Phi4MultimodalVisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + self.scaling = self.head_dim**-0.5 + self.is_causal = True + self.attention_dropout = config.attention_dropout + + self.k_proj = mindspore.mint.nn.Linear(config.hidden_size, config.hidden_size) + self.v_proj = mindspore.mint.nn.Linear(config.hidden_size, config.hidden_size) + self.q_proj = mindspore.mint.nn.Linear(config.hidden_size, config.hidden_size) + self.out_proj = mindspore.mint.nn.Linear(config.hidden_size, config.hidden_size) + + def construct( + self, + hidden_states: mindspore.Tensor, + attention_mask: Optional[mindspore.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[mindspore.Tensor, Optional[mindspore.Tensor]]: + """Input shape: Batch x Time x Channel""" + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + attention_interface: Callable = simple_eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1) + attn_output = self.out_proj(attn_output) + return attn_output, attn_weights + + +class Phi4MultimodalVisionEncoderLayer(GradientCheckpointingLayer): + def __init__(self, config: Phi4MultimodalVisionConfig): + super().__init__() + self.embed_dim = config.hidden_size + self.layer_norm1 = mindspore.mint.nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.self_attn = Phi4MultimodalVisionAttention(config) + self.layer_norm2 = mindspore.mint.nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.mlp = Phi4MultimodalVisionMLP(config) + + def construct( + self, + hidden_states: mindspore.Tensor, + attention_mask: mindspore.Tensor, + **kwargs: Unpack[TransformersKwargs], + ) -> mindspore.Tensor: + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + **kwargs, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class Phi4MultimodalVisionEncoder(mindspore.nn.Cell): + """ + Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a + [`Phi4MultimodalVisionEncoderLayer`]. 
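+
+    Example (an illustrative sketch; assumes the default config values):
+        >>> cfg = Phi4MultimodalVisionConfig()
+        >>> encoder = Phi4MultimodalVisionEncoder(cfg)
+        >>> embeds = mindspore.mint.zeros((1, 16, cfg.hidden_size))
+        >>> out = encoder(embeds)  # BaseModelOutput; last_hidden_state keeps the input shape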
+ + Args: + config: Phi4MultimodalVisionConfig + """ + + def __init__(self, config: Phi4MultimodalVisionConfig): + super().__init__() + self.config = config + self.layers = mindspore.nn.CellList( + [Phi4MultimodalVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)] + ) + self.gradient_checkpointing = False + + # Ignore copy + + def construct( + self, + inputs_embeds, + attention_mask: Optional[mindspore.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutput: + hidden_states = inputs_embeds + for encoder_layer in self.layers: + hidden_states = encoder_layer( + hidden_states, + attention_mask, + **kwargs, + ) + + return BaseModelOutput(last_hidden_state=hidden_states) + + +def _trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn( + "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2, + ) + + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + value_l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * value_l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.0)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + + +def trunc_normal_tf_( + tensor: mindspore.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0 +) -> mindspore.Tensor: + """Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \\leq \text{mean} \\leq b`. + + NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the + bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0 + and the result is subsequently scaled and shifted by the mean and std args. 
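+
+    Example (an illustrative sketch; the tensor shape and std value are arbitrary):
+        >>> w = mindspore.mint.zeros((64, 64))
+        >>> trunc_normal_tf_(w, std=0.02)  # sampled in (-2, 2) at std=1, then scaled by 0.02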
+ + Args: + tensor: an n-dimensional `mindspore.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + """ + _trunc_normal_(tensor, 0, 1.0, a, b) + tensor.mul_(std).add_(mean) + + +def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"): + fan_in, fan_out = 0.5, 0.5 # _calculate_fan_in_and_fan_out(tensor) + if mode == "fan_in": + denom = fan_in + elif mode == "fan_out": + denom = fan_out + elif mode == "fan_avg": + denom = (fan_in + fan_out) / 2 + + variance = scale / denom + + if distribution == "truncated_normal": + # constant is stddev of standard normal truncated to (-2, 2) + trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978) + elif distribution == "normal": + with mindspore._no_grad(): + tensor.normal_(std=math.sqrt(variance)) + elif distribution == "uniform": + bound = math.sqrt(3 * variance) + with mindspore._no_grad(): + tensor.uniform_(-bound, bound) + else: + raise ValueError(f"invalid distribution {distribution}") + + +def lecun_normal_(tensor): + variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal") + + +def default_flax_embed_init(tensor): + variance_scaling_(tensor, mode="fan_in", distribution="normal") + + +class Phi4MultimodalVisionPreTrainedModel(PreTrainedModel): + config: Phi4MultimodalVisionConfig + base_model_prefix = "phi4_vision" + supports_gradient_checkpointing = True + + _no_split_modules = ["Phi4MultimodalVisionEncoderLayer"] + _supports_flash_attn = True + _supports_sdpa = True + _supports_flex_attn = True + _supports_attention_backend = True + + _can_record_outputs = { + "hidden_states": Phi4MultimodalVisionEncoderLayer, + "attentions": Phi4MultimodalVisionAttention, + } + + def _init_weights(self, module): + """Initialize the weights""" + pass + + +class Phi4MultimodalVisionEmbeddings(mindspore.nn.Cell): + def __init__(self, config: Phi4MultimodalVisionConfig): + super().__init__() + self.config = config + self.patch_size = config.patch_size + self.num_patches_per_side = config.image_size // self.patch_size + + self.patch_embedding = mindspore.mint.nn.Conv2d( + in_channels=config.num_channels, + out_channels=config.hidden_size, + kernel_size=self.patch_size, + stride=self.patch_size, + padding="valid", + ) + self.position_embedding = mindspore.mint.nn.Embedding(self.num_patches_per_side**2, config.hidden_size) + + def interpolate_pos_encoding(self, embeddings: mindspore.Tensor, height: int, width: int) -> mindspore.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution + images. 
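+
+        Example (an illustrative walk-through; the sizes are assumptions): with
+        image_size=448 and patch_size=14 the pre-trained table covers a 32x32 grid
+        (1024 positions); an 896x448 input gives new_height=64 and new_width=32, so
+        the table is reshaped to (1, 32, 32, dim), bicubically resized to 64x32, and
+        flattened back to (1, 2048, dim).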
+ + Adapted from: + - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and + - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211 + """ + + num_patches = embeddings.shape[1] + num_positions = self.position_embedding.weight.shape[0] + + # always interpolate when tracing to ensure the exported model works for dynamic input shapes + if num_patches == num_positions and height == width: + return self.position_embedding(self.position_ids) + + patch_pos_embed = self.position_embedding.weight.unsqueeze(0) + + dim = embeddings.shape[-1] + + new_height = height // self.patch_size + new_width = width // self.patch_size + + sqrt_num_positions = int(num_positions**0.5) + patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + + patch_pos_embed = mindspore.mint.nn.functional.interpolate( + patch_pos_embed, + size=(new_height, new_width), + mode="bicubic", + align_corners=False, + ) + + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return patch_pos_embed + + def construct(self, pixel_values: mindspore.Tensor, patch_attention_mask: mindspore.Tensor) -> mindspore.Tensor: + batch_size = pixel_values.shape[0] + + patch_embeds = self.patch_embedding(pixel_values) + embeddings = patch_embeds.flatten(2).transpose(1, 2) + + max_im_h, max_im_w = pixel_values.shape[2], pixel_values.shape[3] + max_nb_patches_h, max_nb_patches_w = max_im_h // self.patch_size, max_im_w // self.patch_size + boundaries = mindspore.mint.arange(1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side) + position_ids = mindspore.ops.full((batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0) + + for batch_idx, p_attn_mask in enumerate(patch_attention_mask): + nb_patches_h = p_attn_mask[:, 0].sum() + nb_patches_w = p_attn_mask[0].sum() + + fractional_coords_h = mindspore.mint.arange(0, 1 - 1e-6, 1 / nb_patches_h) + fractional_coords_w = mindspore.mint.arange(0, 1 - 1e-6, 1 / nb_patches_w) + + bucket_coords_h = mindspore.ops.bucketize(fractional_coords_h, boundaries, right=True) + bucket_coords_w = mindspore.ops.bucketize(fractional_coords_w, boundaries, right=True) + + pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w).flatten() + position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids + + embeddings = embeddings + self.position_embedding(position_ids) + return embeddings + + +class Phi4MultimodalVisionMultiheadAttentionPoolingHead(mindspore.nn.Cell): + """Multihead Attention Pooling.""" + + def __init__(self, config: Phi4MultimodalVisionConfig): + super().__init__() + + self.probe = mindspore.Parameter(mindspore.mint.randn(1, 1, config.hidden_size)) + self.attention = mindspore.nn.MultiheadAttention( + config.hidden_size, config.num_attention_heads, batch_first=True + ) + self.layernorm = mindspore.mint.nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.mlp = Phi4MultimodalVisionMLP(config) + + def construct(self, hidden_state, attention_mask): + batch_size = hidden_state.shape[0] + probe = self.probe.repeat(batch_size, 1, 1) + + hidden_state = self.attention( + query=probe, key=hidden_state, value=hidden_state, key_padding_mask=~attention_mask + )[0] + + residual = hidden_state + hidden_state = self.layernorm(hidden_state) + hidden_state = residual + self.mlp(hidden_state) + + return 
hidden_state[:, 0] + + +class Phi4MultimodalVisionModel(Phi4MultimodalVisionPreTrainedModel): + config: Phi4MultimodalVisionConfig + main_input_name = "pixel_values" + + def __init__(self, config: Phi4MultimodalVisionConfig): + super().__init__(config) + self.config = config + + self.embeddings = Phi4MultimodalVisionEmbeddings(config) + self.encoder = Phi4MultimodalVisionEncoder(config) + self.post_layernorm = mindspore.mint.nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.head = Phi4MultimodalVisionMultiheadAttentionPoolingHead(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> mindspore.nn.Cell: + return self.embeddings.patch_embedding + + @check_model_inputs + def construct( + self, + pixel_values, + patch_attention_mask: Optional[mindspore.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPooling: + batch_size = pixel_values.shape[0] + if patch_attention_mask is None: + patch_attention_mask = mindspore.mint.ones( + size=( + batch_size, + pixel_values.shape[2] // self.config.patch_size, + pixel_values.shape[3] // self.config.patch_size, + ), + dtype=mindspore.bool_, + ) + + hidden_states = self.embeddings(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask) + + patch_attention_mask = patch_attention_mask.view(batch_size, -1) + # The call to `_upad_input` in `_flash_attention_forward` is expensive + # So when the `patch_attention_mask` is full of 1s (i.e. attending to the whole sequence), + # avoiding passing the attention_mask, which is equivalent to attending to the full sequence + if not mindspore.mint.any(~patch_attention_mask): + attention_mask = None + else: + attention_mask = ( + _prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype) + if self.config._attn_implementation != "flash_attention_2" + else patch_attention_mask + ) + + encoder_outputs: BaseModelOutput = self.encoder( + inputs_embeds=hidden_states, + attention_mask=attention_mask, + **kwargs, + ) + + last_hidden_state = encoder_outputs.last_hidden_state + last_hidden_state = self.post_layernorm(last_hidden_state) + + pooled_output = self.head( + hidden_state=last_hidden_state, + attention_mask=patch_attention_mask, + ) + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + ) + + +class Phi4MultimodalImageEmbedding(mindspore.nn.Cell): + """Image embedding.""" + + def __init__(self, config: Phi4MultimodalConfig): + super().__init__() + self.config = config + self.layer_idx = config.vision_config.feature_layer + self.crop_size = config.vision_config.crop_size + self.image_dim_out = config.vision_config.hidden_size + + n_patches = config.vision_config.image_size // config.vision_config.patch_size + if n_patches % 2 != 0: + self.img_processor_padding = mindspore.mint.nn.ReflectionPad2d((0, 1, 0, 1)) + n_patches += 1 + self.num_img_tokens = (n_patches // 2) ** 2 + + self.drop = mindspore.mint.nn.Dropout(config.embd_pdrop) + self.img_processor = Phi4MultimodalVisionModel._from_config(config.vision_config) + self.image_token_compression = mindspore.mint.nn.AvgPool2d(kernel_size=2, stride=2) + self.img_projection_up = mindspore.mint.nn.Linear(self.image_dim_out, config.hidden_size) + self.img_projection_down = mindspore.mint.nn.Linear(config.hidden_size, config.hidden_size) + self.global_img_feature_extensor = mindspore.Parameter(mindspore.mint.zeros([1, 1, self.image_dim_out])) + self.sub_img_feature_extensor = 
mindspore.Parameter(mindspore.mint.zeros([1, 1, 1, self.image_dim_out])) + + def get_img_features(self, img_embeds: mindspore.Tensor, attention_mask=None) -> mindspore.Tensor: + img_processor_output = self.img_processor( + img_embeds, patch_attention_mask=attention_mask, output_hidden_states=True + ) + img_feature = img_processor_output.hidden_states[self.layer_idx] + + patch_feature = img_feature + # reshape to 2D tensor + width = int(math.sqrt(patch_feature.shape[1])) + patch_feature = patch_feature.view(-1, width, width, patch_feature.shape[-1]) + # convert to NCHW + patch_feature = patch_feature.permute(0, 3, 1, 2) + if getattr(self, "img_processor_padding", None) is not None: + patch_feature = self.img_processor_padding(patch_feature) + patch_feature = self.image_token_compression(patch_feature) + # convert to NHWC + patch_feature = patch_feature.permute(0, 2, 3, 1) + patch_feature = patch_feature.view(-1, patch_feature.shape[1] * patch_feature.shape[2], patch_feature.shape[-1]) + return patch_feature + + def construct( + self, + input_ids: mindspore.Tensor, + inputs_embeds: mindspore.Tensor, + image_pixel_values: mindspore.Tensor, + image_sizes: Optional[mindspore.Tensor] = None, + image_attention_mask: Optional[mindspore.Tensor] = None, + ) -> mindspore.Tensor: + image_pixel_values = image_pixel_values.to(self.img_processor.embeddings.patch_embedding.weight.dtype) + + target_dtype = self.img_projection_up.bias.dtype + + batch_size = image_pixel_values.shape[0] + + img_features = self.get_img_features( + image_pixel_values.flatten(0, 1), + attention_mask=image_attention_mask.flatten(0, 1).to( + dtype=bool, + ), + ) + base_feat_size = int(np.sqrt(img_features.shape[1])) + img_features = img_features.view(batch_size, -1, base_feat_size**2, self.image_dim_out) + image_sizes = image_sizes.view(-1, 2) + + output_imgs = [] + for idx in range(batch_size): + height, width = image_sizes[idx] + height_ratio = height // self.crop_size + width_ratio = width // self.crop_size + area_ratio = height_ratio * width_ratio + + global_img = img_features[idx, :1] + global_img = global_img.reshape(1, base_feat_size, base_feat_size, self.image_dim_out).contiguous() + temporary_extensor = self.sub_img_feature_extensor.repeat(1, base_feat_size, 1, 1) + global_img = mindspore.mint.cat([global_img, temporary_extensor], dim=2).reshape(1, -1, self.image_dim_out) + + sub_img = img_features[idx, 1:] + sub_img = sub_img[:area_ratio] + sub_img = ( + sub_img.reshape(height_ratio, width_ratio, base_feat_size, base_feat_size, self.image_dim_out) + .transpose(1, 2) + .reshape(1, height_ratio * base_feat_size, width_ratio * base_feat_size, self.image_dim_out) + .contiguous() + ) + + if image_attention_mask is not None: + reshaped_image_attention_mask = ( + image_attention_mask[idx, 1 : area_ratio + 1, 0::2, 0::2] + .reshape(height_ratio, width_ratio, base_feat_size, base_feat_size) + .transpose(1, 2) + .reshape(1, height_ratio * base_feat_size, width_ratio * base_feat_size) + ) + useful_height = int(reshaped_image_attention_mask[0, :, 0].sum().item()) + useful_width = int(reshaped_image_attention_mask[0, 0, :].sum().item()) + sub_img = sub_img[:, :useful_height, :useful_width] + temporary_extensor = self.sub_img_feature_extensor.repeat(1, useful_height, 1, 1) + else: + temporary_extensor = self.sub_img_feature_extensor.repeat(1, height_ratio * base_feat_size, 1, 1) + + sub_img = mindspore.mint.cat([sub_img, temporary_extensor], dim=2).reshape(1, -1, self.image_dim_out) + + # Merge global and sub + 
output_imgs.append(mindspore.mint.cat([sub_img, self.global_img_feature_extensor, global_img], dim=1)) + + img_set_tensor = [] + for output_img in output_imgs: + output_img = output_img.to(dtype=target_dtype) + img_feature_proj = self.img_projection_up(output_img) + img_feature_proj = mindspore.mint.nn.functional.gelu(img_feature_proj) + img_feature_proj = self.img_projection_down(img_feature_proj) + img_set_tensor.append(img_feature_proj) + + merged_img_set_tensor = mindspore.mint.cat(img_set_tensor, dim=1).squeeze(0) + merged_img_set_tensor = merged_img_set_tensor.to( + dtype=inputs_embeds.dtype, + ) + + with mindspore._no_grad(): + positions_tuple = mindspore.mint.nonzero( + input_ids == self.config.vision_config.image_token_id, as_tuple=True + ) + + # Temporarily disable autocast to avoid issue on bf16 tensors + # Ref: https://github.com/pytorch/pytorch/issues/132715 + image_embeds = inputs_embeds.index_put(indices=positions_tuple, values=merged_img_set_tensor, accumulate=False) + + image_embeds = self.drop(image_embeds) + + return image_embeds + + +# AUDIO +class Phi4MultimodalAudioMLP(mindspore.nn.Cell): def __init__(self, config: Phi4MultimodalAudioConfig): super().__init__() - self.layer_norm = mint.nn.LayerNorm(config.hidden_size) + self.layer_norm = mindspore.mint.nn.LayerNorm(config.hidden_size) self.act_fn = ACT2FN[config.activation] - self.gate_up_proj = mint.nn.Linear(config.hidden_size, config.intermediate_size * 2) - self.down_proj = mint.nn.Linear(config.intermediate_size, config.hidden_size) - self.dropout = nn.Dropout(config.dropout_rate) + self.gate_up_proj = mindspore.mint.nn.Linear(config.hidden_size, config.intermediate_size * 2) + self.down_proj = mindspore.mint.nn.Linear(config.intermediate_size, config.hidden_size) + self.dropout = mindspore.mint.nn.Dropout(config.dropout_rate) def construct(self, hidden_states): hidden_states = self.layer_norm(hidden_states) @@ -82,7 +667,7 @@ def construct(self, hidden_states): return out -class Phi4MultimodalAudioAttention(nn.Cell): +class Phi4MultimodalAudioAttention(mindspore.nn.Cell): def __init__(self, config: Phi4MultimodalAudioConfig): super().__init__() self.config = config @@ -91,15 +676,23 @@ def __init__(self, config: Phi4MultimodalAudioConfig): self.attention_dropout = config.dropout_rate self.is_causal = True - self.q_proj = mint.nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=True) - self.k_proj = mint.nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=True) - self.v_proj = mint.nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=True) - self.o_proj = mint.nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=True) + self.q_proj = mindspore.mint.nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=True + ) + self.k_proj = mindspore.mint.nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=True + ) + self.v_proj = mindspore.mint.nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=True + ) + self.o_proj = mindspore.mint.nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=True + ) def construct( self, - hidden_states: ms.Tensor, - attention_mask: ms.Tensor, + hidden_states: mindspore.Tensor, + attention_mask: mindspore.Tensor, **kwargs, ): input_shape = hidden_states.shape[:-1] @@ -129,10 +722,10 @@ def construct( return attn_output -class Phi4MultimodalAudioDepthWiseSeparableConv1d(nn.Cell): 
+class Phi4MultimodalAudioDepthWiseSeparableConv1d(mindspore.nn.Cell): def __init__(self, config: Phi4MultimodalAudioConfig, padding: int = 0): super().__init__() - self.dw_conv = mint.nn.Conv1d( + self.dw_conv = mindspore.mint.nn.Conv1d( config.hidden_size, config.hidden_size * config.depthwise_multiplier, config.kernel_size, @@ -140,7 +733,7 @@ def __init__(self, config: Phi4MultimodalAudioConfig, padding: int = 0): padding=padding, groups=config.hidden_size, ) - self.pw_conv = mint.nn.Conv1d( + self.pw_conv = mindspore.mint.nn.Conv1d( config.hidden_size * config.depthwise_multiplier, config.depthwise_separable_out_channel, 1, 1, 0 ) @@ -148,16 +741,18 @@ def construct(self, hidden_states): return self.pw_conv(self.dw_conv(hidden_states)) -class Phi4MultimodalAudioGluPointWiseConv(nn.Cell): +class Phi4MultimodalAudioGluPointWiseConv(mindspore.nn.Cell): def __init__(self, config: Phi4MultimodalAudioConfig): super().__init__() self.config = config self.output_dim = config.ext_pw_out_channel - self.ext_pw_conv_1d = mint.nn.Conv1d(config.hidden_size, config.ext_pw_out_channel * 2, kernel_size=1, stride=1) + self.ext_pw_conv_1d = mindspore.mint.nn.Conv1d( + config.hidden_size, config.ext_pw_out_channel * 2, kernel_size=1, stride=1 + ) self.glu_act = ACT2FN[config.conv_glu_type] - self.b1 = Parameter(mint.zeros((1, config.ext_pw_out_channel, 1))) - self.b2 = Parameter(mint.zeros((1, config.ext_pw_out_channel, 1))) + self.b1 = mindspore.Parameter(mindspore.mint.zeros((1, config.ext_pw_out_channel, 1))) + self.b2 = mindspore.Parameter(mindspore.mint.zeros((1, config.ext_pw_out_channel, 1))) def construct(self, hidden_states): # we assume the input always has the #channel (#dim) in the last dimension of the @@ -169,20 +764,22 @@ def construct(self, hidden_states): return out.permute([0, 2, 1]) -class Phi4MultimodalAudioConvModule(nn.Cell): +class Phi4MultimodalAudioConvModule(mindspore.nn.Cell): def __init__(self, config: Phi4MultimodalAudioConfig): super().__init__() self.config = config self.kernel_size = config.kernel_size - self.layer_norm = mint.nn.LayerNorm(config.hidden_size) + self.layer_norm = mindspore.mint.nn.LayerNorm(config.hidden_size) self.glu = Phi4MultimodalAudioGluPointWiseConv(config) self.dw_sep_conv_1d = Phi4MultimodalAudioDepthWiseSeparableConv1d(config, padding=config.kernel_size - 1) self.act = ACT2FN[config.conv_activation] - self.ext_pw_conv_1d = mint.nn.Conv1d(config.hidden_size, config.ext_pw_out_channel, kernel_size=1, stride=1) - self.dropout = nn.Dropout(config.dropout_rate) + self.ext_pw_conv_1d = mindspore.mint.nn.Conv1d( + config.hidden_size, config.ext_pw_out_channel, kernel_size=1, stride=1 + ) + self.dropout = mindspore.mint.nn.Dropout(config.dropout_rate) - def construct(self, hidden_states: ms.Tensor): + def construct(self, hidden_states: mindspore.Tensor): hidden_states = self.glu(self.layer_norm(hidden_states)) hidden_states = self.dw_sep_conv_1d(hidden_states.permute([0, 2, 1])) @@ -195,7 +792,7 @@ def construct(self, hidden_states: ms.Tensor): return out -class Phi4MultimodalAudioConformerEncoderLayer(nn.Cell): +class Phi4MultimodalAudioConformerEncoderLayer(mindspore.nn.Cell): def __init__(self, config: Phi4MultimodalAudioConfig): super().__init__() @@ -203,13 +800,13 @@ def __init__(self, config: Phi4MultimodalAudioConfig): self.self_attn = Phi4MultimodalAudioAttention(config) self.conv = Phi4MultimodalAudioConvModule(config) self.feed_forward_out = Phi4MultimodalAudioMLP(config) - self.layer_norm_att = mint.nn.LayerNorm(config.hidden_size) - 
self.layer_norm = mint.nn.LayerNorm(config.hidden_size) + self.layer_norm_att = mindspore.mint.nn.LayerNorm(config.hidden_size) + self.layer_norm = mindspore.mint.nn.LayerNorm(config.hidden_size) def construct( self, - hidden_states: ms.Tensor, - attention_mask: ms.Tensor, + hidden_states: mindspore.Tensor, + attention_mask: mindspore.Tensor, ): residual = hidden_states + 0.5 * self.feed_forward_in(hidden_states) hidden_states = self.layer_norm_att(residual) @@ -223,7 +820,7 @@ def construct( return out -class Phi4MultimodalAudioNemoConvSubsampling(nn.Cell): +class Phi4MultimodalAudioNemoConvSubsampling(mindspore.nn.Cell): def __init__(self, config: Phi4MultimodalAudioConfig): super().__init__() self.subsampling_factor = config.time_reduction @@ -232,31 +829,33 @@ def __init__(self, config: Phi4MultimodalAudioConfig): conv_channels = config.nemo_conv_channels layers = [ - mint.nn.Conv2d(1, conv_channels, kernel_size=3, stride=2, padding=1), + mindspore.mint.nn.Conv2d(1, conv_channels, kernel_size=3, stride=2, padding=1), self.act_fn, ] for _ in range(self.sampling_num - 1): layers.extend( [ - mint.nn.Conv2d( + mindspore.mint.nn.Conv2d( conv_channels, conv_channels, kernel_size=3, stride=2, padding=1, groups=conv_channels ), - mint.nn.Conv2d(conv_channels, conv_channels, kernel_size=1, stride=1, padding=0, groups=1), + mindspore.mint.nn.Conv2d( + conv_channels, conv_channels, kernel_size=1, stride=1, padding=0, groups=1 + ), self.act_fn, ] ) # Aggregate the layers - self.conv = ms.nn.SequentialCell(*layers) - self.out = mint.nn.Linear(conv_channels * config.nemo_final_size, config.hidden_size) + self.conv = mindspore.nn.SequentialCell(*layers) + self.out = mindspore.mint.nn.Linear(conv_channels * config.nemo_final_size, config.hidden_size) - def construct(self, hidden_states: ms.Tensor, mask: Optional[ms.Tensor]): + def construct(self, hidden_states: mindspore.Tensor, mask: Optional[mindspore.Tensor]): # Unsqueeze Channel Axis hidden_states = hidden_states.unsqueeze(1) hidden_states = self.conv(hidden_states) # Flatten Channel and Frequency Axes - b, _, t, _ = hidden_states.size() + b, _, t, _ = hidden_states.shape hidden_states = self.out(hidden_states.transpose(1, 2).reshape(b, t, -1)) if mask is None: @@ -264,13 +863,16 @@ def construct(self, hidden_states: ms.Tensor, mask: Optional[ms.Tensor]): max_audio_length = hidden_states.shape[1] feature_lens = mask.sum(1) - padding_length = mint.ceil(feature_lens / self.subsampling_factor) - arange_ = mint.arange(0, max_audio_length) - pad_mask = arange_.expand(padding_length.size(0), -1) < padding_length.unsqueeze(1) + padding_length = mindspore.mint.ceil(feature_lens / self.subsampling_factor) + arange_ = mindspore.mint.arange( + 0, + max_audio_length, + ) + pad_mask = arange_.expand(padding_length.shape[0], -1) < padding_length.unsqueeze(1) return hidden_states, pad_mask.unsqueeze(1) -class Phi4MultimodalAudioRelativeAttentionBias(nn.Cell): +class Phi4MultimodalAudioRelativeAttentionBias(mindspore.nn.Cell): def __init__(self, config: Phi4MultimodalAudioConfig): super().__init__() @@ -279,13 +881,13 @@ def __init__(self, config: Phi4MultimodalAudioConfig): self.num_buckets = self.max_distance if not config.bias_symmetric: self.num_buckets *= 2 - self.bias_values = mint.nn.Embedding(self.num_buckets, config.num_attention_heads) + self.bias_values = mindspore.mint.nn.Embedding(self.num_buckets, config.num_attention_heads) def construct(self, x): # instantiate bias compatible with shape of x max_pos = x.shape[1] - context_position = 
mint.arange(max_pos, dtype=ms.int64)[:, None] - memory_position = mint.arange(max_pos, dtype=ms.int64)[None, :] + context_position = mindspore.mint.arange(max_pos, dtype=mindspore.int64)[:, None] + memory_position = mindspore.mint.arange(max_pos, dtype=mindspore.int64)[None, :] relative_position = memory_position - context_position # clipping to a maximum distance using ops that play well with ONNX export relative_position = relative_position.masked_fill(relative_position < -self.max_distance, -self.max_distance) @@ -303,17 +905,16 @@ def construct(self, x): return att_bias -class Phi4MultimodalAudioMeanVarianceNormLayer(nn.Cell): +class Phi4MultimodalAudioMeanVarianceNormLayer(mindspore.nn.Cell): def __init__(self, config: Phi4MultimodalAudioConfig): super().__init__() - self.register_buffer("global_mean", mint.zeros(config.input_size)) - self.register_buffer("global_invstd", mint.ones(config.input_size)) + self.register_buffer("global_mean", mindspore.mint.zeros(config.input_size)) + self.register_buffer("global_invstd", mindspore.mint.ones(config.input_size)) def construct(self, x): return (x - self.global_mean) * self.global_invstd -@auto_docstring class Phi4MultimodalAudioPreTrainedModel(PreTrainedModel): config: Phi4MultimodalAudioConfig supports_gradient_checkpointing = True @@ -339,7 +940,9 @@ def unfold_tensor(tensor, max_seq_len): _, _, D = tensor.shape tensor = tensor.transpose(-1, -2) # N x D x 1 x T => N x (D x max_seq_len) x T' - tensor = F.unfold(tensor[..., None, :], kernel_size=(1, max_seq_len), stride=(1, max_seq_len)) + tensor = mindspore.mint.nn.functional.unfold( + tensor[..., None, :], kernel_size=(1, max_seq_len), stride=(1, max_seq_len) + ) new_bsz, _, slen = tensor.shape tensor = tensor.view(new_bsz, -1, max_seq_len, slen) @@ -357,18 +960,18 @@ def adaptive_enc_mask(x_len, chunk_start_idx, left_window=0, right_window=0): left_window (int): how many left chunks can be seen right_window (int): how many right chunks can be seen. It is used for chunk overlap model. 
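+
+    Example (an illustrative trace; the numbers reuse the comments in the body below):
+        with x_len=52, chunk_start_idx=[0, 18, 36, 48] and left_window=1, frame 20
+        falls in chunk [18, 36) and may attend to frames [0, 36); with the default
+        windows it is restricted to its own chunk [18, 36).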
Returns: - mask (ms.Tensor): a mask tensor for streaming model + mask (mindspore.tensor): a mask tensor for streaming model """ - chunk_start_idx = ms.Tensor(chunk_start_idx).long() - start_pad = mint.nn.functional.pad( + chunk_start_idx = mindspore.Tensor(chunk_start_idx).long() + start_pad = mindspore.mint.nn.functional.pad( chunk_start_idx, (1, 0) ) # append 0 to the beginning, so it becomes [0, 0, 18, 36, 48] - end_pad = mint.nn.functional.pad( + end_pad = mindspore.mint.nn.functional.pad( chunk_start_idx, (0, 1), value=x_len ) # append x_len to the end, so it becomes [0,18,36,48, x_len] - seq_range = mint.arange(0, x_len).unsqueeze(-1) + seq_range = mindspore.mint.arange(0, x_len).unsqueeze(-1) idx = ((seq_range < end_pad) & (seq_range >= start_pad)).nonzero()[:, 1] - seq_range_expand = mint.arange(0, x_len).unsqueeze(0).expand(x_len, -1) + seq_range_expand = mindspore.mint.arange(0, x_len).unsqueeze(0).expand(x_len, -1) idx_left = idx - left_window idx_left[idx_left < 0] = 0 boundary_left = start_pad[idx_left] @@ -388,7 +991,7 @@ def __init__(self, config: Phi4MultimodalAudioConfig): self.encoder_embedding = Phi4MultimodalAudioMeanVarianceNormLayer(config) self.embed = Phi4MultimodalAudioNemoConvSubsampling(config) self.relative_attention_bias_layer = Phi4MultimodalAudioRelativeAttentionBias(config) - self.encoders = nn.CellList( + self.encoders = mindspore.nn.CellList( [Phi4MultimodalAudioConformerEncoderLayer(config) for _ in range(config.num_blocks)] ) self.gradient_checkpointing = False @@ -416,7 +1019,7 @@ def _streaming_mask(self, seq_len, batch_size, chunk_size, left_chunk): ) return enc_streaming_mask - def construct_embeddings(self, hidden_states, masks): + def forward_embeddings(self, hidden_states, masks): """Forwarding the inputs through the top embedding layers""" seq_len = math.ceil(hidden_states.shape[1] / self.config.time_reduction) if seq_len <= 0: @@ -446,17 +1049,23 @@ def calculate_hs_mask(self, hidden_states, mask): enc_streaming_mask = self._streaming_mask( max_audio_length, batch_size, self.config.chunk_size, self.config.left_chunk ) + enc_streaming_mask = enc_streaming_mask if mask is None: return enc_streaming_mask feature_lens = mask.sum(1) padding_length = feature_lens - pad_mask = mint.arange(0, max_audio_length).expand(padding_length.size(0), -1) < padding_length.unsqueeze(1) + pad_mask = mindspore.mint.arange( + 0, + max_audio_length, + ).expand( + padding_length.shape[0], -1 + ) < padding_length.unsqueeze(1) pad_mask = pad_mask.unsqueeze(1) pad_mask = pad_mask & enc_streaming_mask return pad_mask - def construct(self, hidden_states: ms.Tensor, mask: Optional[ms.Tensor]): + def construct(self, hidden_states: mindspore.Tensor, mask: Optional[mindspore.Tensor]): hidden_states = self.encoder_embedding(hidden_states) hidden_states, hs_mask, mask = self.forward_embeddings(hidden_states, mask) @@ -472,7 +1081,9 @@ def construct(self, hidden_states: ms.Tensor, mask: Optional[ms.Tensor]): else: chunk_pad_size = 0 if chunk_pad_size > 0: - hidden_states_pad = F.pad(hidden_states, (0, 0, 0, chunk_pad_size), "constant", 0) + hidden_states_pad = mindspore.mint.nn.functional.pad( + hidden_states, (0, 0, 0, chunk_pad_size), "constant", 0 + ) hidden_states = hidden_states_pad hidden_states = unfold_tensor(hidden_states, max_seq_len) @@ -480,7 +1091,7 @@ def construct(self, hidden_states: ms.Tensor, mask: Optional[ms.Tensor]): if mask is not None: # revise hs_mask here because the previous calculated hs_mask did not consider extra pad subsampled_pad_mask = 
mask.squeeze(1) # [bz, subsampled_unmask_seq_len] - extra_padded_subsamlped_pad_mask = F.pad( + extra_padded_subsamlped_pad_mask = mindspore.mint.nn.functional.pad( subsampled_pad_mask, (0, chunk_pad_size), "constant", False ) # extra padding to the pad mask extra_padded_subsamlped_pad_mask = extra_padded_subsamlped_pad_mask.unsqueeze(-1).float() @@ -508,34 +1119,36 @@ def construct(self, hidden_states: ms.Tensor, mask: Optional[ms.Tensor]): return hidden_states -class Phi4MultimodalAudioEmbedding(nn.Cell): +class Phi4MultimodalAudioEmbedding(mindspore.nn.Cell): def __init__(self, config: Phi4MultimodalConfig): super().__init__() self.config = config self.layer_idx = config.audio_config.feature_layer - self.drop = nn.Dropout(config.embd_pdrop) + self.drop = mindspore.mint.nn.Dropout(config.embd_pdrop) self.encoder = Phi4MultimodalAudioModel._from_config(config.audio_config) - self.up_proj_for_speech = mint.nn.Linear( + self.up_proj_for_speech = mindspore.mint.nn.Linear( config.audio_config.hidden_size * config.audio_config.downsample_rate, config.hidden_size ) - self.down_proj_for_speech = mint.nn.Linear(config.hidden_size, config.hidden_size) - self.up_proj_for_vision_speech = mint.nn.Linear( + self.down_proj_for_speech = mindspore.mint.nn.Linear(config.hidden_size, config.hidden_size) + self.up_proj_for_vision_speech = mindspore.mint.nn.Linear( config.audio_config.hidden_size * config.audio_config.downsample_rate, config.hidden_size ) - self.down_proj_for_vision_speech = mint.nn.Linear(config.hidden_size, config.hidden_size) + self.down_proj_for_vision_speech = mindspore.mint.nn.Linear(config.hidden_size, config.hidden_size) def construct( self, - input_ids: ms.Tensor, - inputs_embeds: ms.Tensor, - audio_input_features: ms.Tensor, + input_ids: mindspore.Tensor, + inputs_embeds: mindspore.Tensor, + audio_input_features: mindspore.Tensor, audio_embed_sizes=None, audio_attention_mask=None, audio_projection_mode="speech", - ) -> ms.Tensor: - with ms._no_grad(): - positions_tuple = mint.nonzero(input_ids == self.config.audio_config.audio_token_id, as_tuple=True) + ) -> mindspore.Tensor: + with mindspore._no_grad(): + positions_tuple = mindspore.mint.nonzero( + input_ids == self.config.audio_config.audio_token_id, as_tuple=True + ) up_proj = self.up_proj_for_speech if audio_projection_mode == "speech" else self.up_proj_for_vision_speech down_proj = self.down_proj_for_speech if audio_projection_mode == "speech" else self.down_proj_for_vision_speech @@ -546,16 +1159,664 @@ def construct( audio_encoder_hidden_states = self.encoder(audio_input_features, audio_attention_mask) audio_encoder_hidden_states = up_proj(audio_encoder_hidden_states) - audio_encoder_hidden_states = mint.nn.functional.gelu(audio_encoder_hidden_states) + audio_encoder_hidden_states = mindspore.mint.nn.functional.gelu(audio_encoder_hidden_states) audio_embeds = down_proj(audio_encoder_hidden_states) - merged_audio_embeds = mint.cat( + merged_audio_embeds = mindspore.mint.cat( [audio_embeds[i, : audio_embed_sizes[i], :] for i in range(len(audio_embed_sizes))], dim=0 ) - merged_audio_embeds = merged_audio_embeds.to(dtype=inputs_embeds.dtype) - + merged_audio_embeds = merged_audio_embeds.to( + dtype=inputs_embeds.dtype, + ) + # Temporarily disable autocast to avoid issue on bf16 tensors + # Ref: https://github.com/pytorch/pytorch/issues/132715 audio_embeds = inputs_embeds.index_put(indices=positions_tuple, values=merged_audio_embeds, accumulate=False) audio_embeds = self.drop(audio_embeds) return audio_embeds + + +class 
Phi4MultimodalRMSNorm(mindspore.nn.Cell): + def __init__(self, hidden_size, eps=1e-6): + """ + Phi4MultimodalRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = mindspore.Parameter(mindspore.mint.ones(hidden_size)) + self.variance_epsilon = eps + + def construct(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(mindspore.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * mindspore.mint.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class Phi4MultimodalMLP(mindspore.nn.Cell): + def __init__(self, config): + super().__init__() + + self.config = config + self.gate_up_proj = mindspore.mint.nn.Linear(config.hidden_size, 2 * config.intermediate_size, bias=False) + self.down_proj = mindspore.mint.nn.Linear(config.intermediate_size, config.hidden_size, bias=False) + self.activation_fn = ACT2FN[config.hidden_act] + + def construct(self, hidden_states: mindspore.Tensor) -> mindspore.Tensor: + up_states = self.gate_up_proj(hidden_states) + + gate, up_states = up_states.chunk(2, dim=-1) + up_states = up_states * self.activation_fn(gate) + + return self.down_proj(up_states) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return mindspore.mint.cat((-x2, x1), dim=-1) + + +def repeat_kv(hidden_states: mindspore.Tensor, n_rep: int) -> mindspore.Tensor: + """ + This is the equivalent of ms.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: mindspore.nn.Cell, + query: mindspore.Tensor, + key: mindspore.Tensor, + value: mindspore.Tensor, + attention_mask: Optional[mindspore.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs: Unpack[TransformersKwargs], +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = mindspore.mint.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = mindspore.mint.nn.functional.softmax(attn_weights, dim=-1, dtype=mindspore.float32).to(query.dtype) + attn_weights = mindspore.mint.nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = mindspore.mint.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`mindspore.tensor`): The query tensor. + k (`mindspore.tensor`): The key tensor. + cos (`mindspore.tensor`): The cosine part of the rotary embedding. + sin (`mindspore.tensor`): The sine part of the rotary embedding. 
+ position_ids (`mindspore.tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(mindspore.tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + + rotary_dim = cos.shape[-1] + q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:] + k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:] + + q_embed = mindspore.mint.cat([(q_rot * cos) + (rotate_half(q_rot) * sin), q_pass], dim=-1) + k_embed = mindspore.mint.cat([(k_rot * cos) + (rotate_half(k_rot) * sin), k_pass], dim=-1) + return q_embed, k_embed + + +class Phi4MultimodalAttention(mindspore.nn.Cell): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Phi4MultimodalConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.num_key_value_heads = config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + + op_size = config.num_attention_heads * self.head_dim + 2 * (config.num_key_value_heads * self.head_dim) + self.o_proj = mindspore.mint.nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=False + ) + self.qkv_proj = mindspore.mint.nn.Linear(config.hidden_size, op_size, bias=False) + + def construct( + self, + hidden_states: mindspore.Tensor, + position_embeddings: tuple[mindspore.Tensor, mindspore.Tensor], + attention_mask: Optional[mindspore.Tensor], + past_key_values: Optional[Cache] = None, + cache_position: Optional[mindspore.Tensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[mindspore.Tensor, Optional[mindspore.Tensor], Optional[tuple[mindspore.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + qkv = self.qkv_proj(hidden_states) + query_pos = self.config.num_attention_heads * self.head_dim + query_states = qkv[..., :query_pos] + key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim] + value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :] + + query_states = query_states.view(hidden_shape).transpose(1, 2) + key_states = key_states.view(hidden_shape).transpose(1, 2) + value_states = value_states.view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": 
cache_position}
+        key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+        attention_interface: Callable = eager_attention_forward
+        if self.config._attn_implementation != "eager":
+            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+        attn_output, attn_weights = attention_interface(
+            self,
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            dropout=0.0 if not self.training else self.attention_dropout,
+            scaling=self.scaling,
+            sliding_window=getattr(self.config, "sliding_window", None),
+            **kwargs,
+        )
+
+        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+        attn_output = self.o_proj(attn_output)
+        return attn_output, attn_weights
+
+
+class Phi4MultimodalDecoderLayer(GradientCheckpointingLayer):
+    def __init__(self, config: Phi4MultimodalConfig, layer_idx: int):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        self.self_attn = Phi4MultimodalAttention(config=config, layer_idx=layer_idx)
+        self.mlp = Phi4MultimodalMLP(config)
+        self.input_layernorm = Phi4MultimodalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = Phi4MultimodalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.config = config
+        self.resid_attn_dropout = mindspore.mint.nn.Dropout(config.resid_pdrop)
+        self.resid_mlp_dropout = mindspore.mint.nn.Dropout(config.resid_pdrop)
+
+    def construct(
+        self,
+        hidden_states: mindspore.Tensor,
+        attention_mask: Optional[mindspore.Tensor] = None,
+        position_ids: Optional[mindspore.Tensor] = None,
+        past_key_values: Optional[Cache] = None,
+        use_cache: Optional[bool] = False,
+        cache_position: Optional[mindspore.Tensor] = None,
+        position_embeddings: Optional[
+            tuple[mindspore.Tensor, mindspore.Tensor]
+        ] = None,  # necessary, but kept here for BC
+        **kwargs: Unpack[FlashAttentionKwargs],
+    ) -> mindspore.Tensor:
+        residual = hidden_states
+        hidden_states = self.input_layernorm(hidden_states)
+
+        hidden_states, self_attn_weights = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            cache_position=cache_position,
+            position_embeddings=position_embeddings,
+            **kwargs,
+        )
+        hidden_states = residual + self.resid_attn_dropout(hidden_states)  # main diff with Llama
+
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + self.resid_mlp_dropout(hidden_states)  # main diff with Llama
+        return hidden_states
+
+
+class Phi4MultimodalFeatureEmbedding(mindspore.nn.Cell):
+    """Image-audio embedding."""
+
+    def __init__(self, config: Phi4MultimodalConfig) -> None:
+        super().__init__()
+        self.config = config
+        self.image_token_id = config.vision_config.image_token_id
+        self.audio_token_id = config.audio_config.audio_token_id
+        self.image_embed = Phi4MultimodalImageEmbedding(config)
+        self.audio_embed = Phi4MultimodalAudioEmbedding(config)
+
+    def construct(
+        self,
+        input_ids: mindspore.Tensor,
+        inputs_embeds: mindspore.Tensor,
+        image_pixel_values: Optional[mindspore.Tensor] = None,
+        audio_input_features: Optional[mindspore.Tensor] = None,
+        image_sizes=None,
+        image_attention_mask=None,
+        audio_embed_sizes=None,
+        audio_attention_mask=None,
+    ) -> mindspore.Tensor:
+        with mindspore._no_grad():
+            image_position_mask = (input_ids == self.config.vision_config.image_token_id).unsqueeze(-1)
+            non_image_position_mask = ~image_position_mask
+
+        image_embeds = None
+        audio_embeds = None
+        if image_pixel_values is not None and (input_ids == self.image_token_id).any():
+            image_embeds = self.image_embed(
+                input_ids,
+                inputs_embeds,
+                image_pixel_values=image_pixel_values,
+                image_sizes=image_sizes,
+                image_attention_mask=image_attention_mask,
+            )
+        if audio_input_features is not None and (input_ids == self.audio_token_id).any():
+            audio_projection_mode = "vision" if image_pixel_values is not None else "speech"
+            audio_embeds = self.audio_embed(
+                input_ids,
+                inputs_embeds,
+                audio_input_features=audio_input_features,
+                audio_embed_sizes=audio_embed_sizes,
+                audio_attention_mask=audio_attention_mask,
+                audio_projection_mode=audio_projection_mode,
+            )
+
+        # merge image and audio
+        if image_embeds is not None and audio_embeds is not None:
+            inputs_embeds = image_embeds * image_position_mask + audio_embeds * non_image_position_mask
+        elif image_embeds is not None:
+            inputs_embeds = image_embeds
+        elif audio_embeds is not None:
+            inputs_embeds = audio_embeds
+
+        return inputs_embeds
+
+
+class Phi4MultimodalRotaryEmbedding(mindspore.nn.Cell):
+    inv_freq: mindspore.Tensor  # fix linting for `register_buffer`
+
+    def __init__(self, config: Phi4MultimodalConfig):
+        super().__init__()
+        # BC: "rope_type" was originally "type"
+        if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
+            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+        else:
+            self.rope_type = "default"
+        self.max_seq_len_cached = config.max_position_embeddings
+        self.original_max_seq_len = config.max_position_embeddings
+
+        self.config = config
+        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+        inv_freq, self.attention_scaling = self.rope_init_fn(self.config)
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+        self.original_inv_freq = self.inv_freq
+
+    @mindspore._no_grad()
+    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
+    def construct(self, x, position_ids):
+        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+        position_ids_expanded = position_ids[:, None, :].float()
+
+        # Force float32
+        freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+        emb = mindspore.mint.cat((freqs, freqs), dim=-1)
+        cos = emb.cos() * self.attention_scaling
+        sin = emb.sin() * self.attention_scaling
+
+        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+class Phi4MultimodalPreTrainedModel(PreTrainedModel):
+    config: Phi4MultimodalConfig
+    base_model_prefix = "model"
+    supports_gradient_checkpointing = True
+    _no_split_modules = ["Phi4MultimodalDecoderLayer"]
+    _skip_keys_device_placement = ["past_key_values"]
+    _supports_flash_attn = True
+    _supports_sdpa = True
+    _supports_flex_attn = True
+
+    _can_compile_fullgraph = True
+    _supports_attention_backend = True
+    _can_record_outputs = {
+        "hidden_states": Phi4MultimodalDecoderLayer,
+        "attentions": Phi4MultimodalAttention,
+    }
+    _version = "0.0.5"
+
+    def _init_weights(self, module):
+        super()._init_weights(module)
+        if isinstance(module, Phi4MultimodalImageEmbedding):
+            module.global_img_feature_extensor.data.zero_()
+            module.sub_img_feature_extensor.data.zero_()
+
+
+class Phi4MultimodalModel(Phi4MultimodalPreTrainedModel):
+    def __init__(self, config: Phi4MultimodalConfig):
+        super().__init__(config)
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+
+        self.embed_tokens = mindspore.mint.nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+
+        self.layers = mindspore.nn.CellList(
+            [Phi4MultimodalDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+        )
+        self.norm = Phi4MultimodalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.rotary_emb = Phi4MultimodalRotaryEmbedding(config=config)
+
+        self.gradient_checkpointing = False
+        self.embed_dropout = mindspore.mint.nn.Dropout(config.embd_pdrop)
+
+        self.embed_tokens_extend = Phi4MultimodalFeatureEmbedding(config)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @check_model_inputs
+    def construct(
+        self,
+        input_ids: Optional[mindspore.Tensor] = None,
+        attention_mask: Optional[mindspore.Tensor] = None,
+        position_ids: Optional[mindspore.Tensor] = None,
+        past_key_values: Optional[Cache] = None,
+        inputs_embeds: Optional[mindspore.Tensor] = None,
+        image_pixel_values: Optional[mindspore.Tensor] = None,
+        image_sizes: Optional[mindspore.Tensor] = None,
+        image_attention_mask=None,
+        audio_input_features: Optional[mindspore.Tensor] = None,
+        audio_embed_sizes=None,
+        audio_attention_mask=None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        cache_position: Optional[mindspore.Tensor] = None,
+        **kwargs,
+    ) -> BaseModelOutputWithPast:
+        r"""
+        image_pixel_values (`mindspore.tensor`, *optional*):
+            If the input contains images, these correspond to the pixel values after transformations (as returned by
+            the Processor).
+        image_sizes (`mindspore.tensor`, *optional*):
+            If the input contains images, these correspond to the size of each image.
+        image_attention_mask (`mindspore.tensor`, *optional*):
+            Attention mask for the images.
+        audio_input_features (`mindspore.tensor`, *optional*):
+            If the input contains audio samples, these correspond to the values after transformation (as returned by
+            the Processor).
+        audio_embed_sizes (`mindspore.tensor`, *optional*):
+            Size of the audio inputs.
+        audio_attention_mask (`mindspore.tensor`, *optional*):
+            Attention mask for the audio inputs.
+        """
+        if (input_ids is None) ^ (inputs_embeds is not None):
+            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+        if use_cache and past_key_values is None:
+            past_key_values = DynamicCache(config=self.config)
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)
+            inputs_embeds = self.embed_tokens_extend(
+                input_ids,
+                inputs_embeds,
+                image_pixel_values=image_pixel_values,
+                audio_input_features=audio_input_features,
+                image_sizes=image_sizes,
+                image_attention_mask=image_attention_mask,
+                audio_embed_sizes=audio_embed_sizes,
+                audio_attention_mask=audio_attention_mask,
+            )
+
+        if cache_position is None:
+            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+            cache_position = mindspore.mint.arange(
+                past_seen_tokens,
+                past_seen_tokens + inputs_embeds.shape[1],
+            )
+        if position_ids is None:
+            position_ids = cache_position.unsqueeze(0)
+
+        mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask
+        causal_mask = mask_function(
+            config=self.config,
+            input_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+            cache_position=cache_position,
+            past_key_values=past_key_values,
+            position_ids=position_ids,
+        )
+
+        hidden_states = inputs_embeds
+
+        # create position embeddings to be shared across the decoder layers
+        position_embeddings = self.rotary_emb(hidden_states, position_ids)
+        for decoder_layer in self.layers:
+            hidden_states = decoder_layer(
+                hidden_states,
+                attention_mask=causal_mask,
+                position_ids=position_ids,
+                past_key_values=past_key_values,
+                use_cache=use_cache,
+                cache_position=cache_position,
+                position_embeddings=position_embeddings,
+                **kwargs,
+            )
+
+        hidden_states = self.norm(hidden_states)
+        return BaseModelOutputWithPast(
+            last_hidden_state=hidden_states,
+            past_key_values=past_key_values if use_cache else None,
+        )
+
+
+class Phi4MultimodalForCausalLM(Phi4MultimodalPreTrainedModel, GenerationMixin):
+    _tied_weights_keys = ["lm_head.weight"]
+    _tp_plan = {"lm_head": "colwise_rep"}
+    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.model = Phi4MultimodalModel(config)
+        self.vocab_size = config.vocab_size
+        self.lm_head = mindspore.mint.nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @can_return_tuple
+    def construct(
+        self,
+        input_ids: Optional[mindspore.Tensor] = None,
+        attention_mask: Optional[mindspore.Tensor] = None,
+        position_ids: Optional[mindspore.Tensor] = None,
+        past_key_values: Optional[Cache] = None,
+        inputs_embeds: Optional[mindspore.Tensor] = None,
+        image_pixel_values: Optional[mindspore.Tensor] = None,
+        image_sizes: Optional[mindspore.Tensor] = None,
+        image_attention_mask=None,
+        audio_input_features: Optional[mindspore.Tensor] = None,
+        audio_embed_sizes=None,
+        audio_attention_mask=None,
+        labels: Optional[mindspore.Tensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        cache_position: Optional[mindspore.Tensor] = None,
+        logits_to_keep: Union[int, mindspore.Tensor] = 0,
+        **kwargs,
+    ) -> CausalLMOutputWithPast:
+        r"""
+        image_pixel_values (`mindspore.tensor`, *optional*):
+            If the input contains images, these correspond to the pixel values after transformations (as returned by
+            the Processor).
+        image_sizes (`mindspore.tensor`, *optional*):
+            If the input contains images, these correspond to the size of each image.
+        image_attention_mask (`mindspore.tensor`, *optional*):
+            Attention mask for the images.
+        audio_input_features (`mindspore.tensor`, *optional*):
+            If the input contains audio samples, these correspond to the values after transformation (as returned by
+            the Processor).
+        audio_embed_sizes (`mindspore.tensor`, *optional*):
+            Size of the audio inputs.
+        audio_attention_mask (`mindspore.tensor`, *optional*):
+            Attention mask for the audio inputs.
+        labels (`mindspore.tensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+        Example:
+
+        ```python
+        >>> from transformers import AutoTokenizer
+        >>> from mindone.transformers import Phi4MultimodalForCausalLM
+        >>> import mindspore as ms
+
+        >>> model = Phi4MultimodalForCausalLM.from_pretrained("TBA")
+        >>> tokenizer = AutoTokenizer.from_pretrained("TBA")
+
+        >>> prompt = "This is an example script ."
+        >>> inputs = tokenizer(prompt, return_tensors="np")
+
+        >>> # Generate
+        >>> generate_ids = model.generate(ms.tensor(inputs.input_ids), max_length=30)
+        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        'This is an example script .\n Certainly! Below is a sample script that demonstrates a simple task, such as calculating the sum'
+        ```"""
+
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+
+        # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
+        outputs: BaseModelOutputWithPast = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            image_pixel_values=image_pixel_values,
+            image_sizes=image_sizes,
+            image_attention_mask=image_attention_mask,
+            audio_input_features=audio_input_features,
+            audio_embed_sizes=audio_embed_sizes,
+            audio_attention_mask=audio_attention_mask,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            cache_position=cache_position,
+            **kwargs,
+        )
+
+        hidden_states = outputs.last_hidden_state
+        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+        logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+        loss = None
+        if labels is not None:
+            loss = self.loss_function(logits, labels, self.vocab_size)
+
+        return CausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+    def prepare_inputs_for_generation(
+        self,
+        input_ids,
+        past_key_values=None,
+        attention_mask=None,
+        inputs_embeds=None,
+        image_pixel_values=None,
+        image_sizes=None,
+        image_attention_mask=None,
+        audio_input_features=None,
+        audio_embed_sizes=None,
+        audio_attention_mask=None,
+        cache_position=None,
+        position_ids=None,
+        use_cache=True,
+        logits_to_keep=0,
+        **kwargs,
+    ):
+        # Overwritten -- this model may need to switch between short and long rope, invalidating the cache in the
+        # process
+
+        # The first time the input length crosses the long/short factor switching point, force the cache to be
+        # re-computed. This makes generation slower at that single token position, but is better than failing.
+        if (
+            past_key_values
+            and self.config.rope_scaling
+            and input_ids.shape[1] >= self.config.original_max_position_embeddings + 1
+        ):
+            past_length = cache_position[0]
+            if past_length <= self.config.original_max_position_embeddings:
+                past_key_values = None
+
+        model_inputs = super().prepare_inputs_for_generation(
+            input_ids=input_ids,
+            past_key_values=past_key_values,
+            attention_mask=attention_mask,
+            inputs_embeds=inputs_embeds,
+            image_pixel_values=image_pixel_values,
+            image_sizes=image_sizes,
+            image_attention_mask=image_attention_mask,
+            audio_input_features=audio_input_features,
+            audio_embed_sizes=audio_embed_sizes,
+            audio_attention_mask=audio_attention_mask,
+            cache_position=cache_position,
+            position_ids=position_ids,
+            use_cache=use_cache,
+            logits_to_keep=logits_to_keep,
+            **kwargs,
+        )
+        return model_inputs
+
+
+__all__ = [
+    "Phi4MultimodalAudioPreTrainedModel",
+    "Phi4MultimodalAudioModel",
+    "Phi4MultimodalVisionPreTrainedModel",
+    "Phi4MultimodalVisionModel",
+    "Phi4MultimodalPreTrainedModel",
+    "Phi4MultimodalModel",
+    "Phi4MultimodalForCausalLM",
+]
diff --git a/tests/transformers_tests/models/phi4_multimodal/__init__.py b/tests/transformers_tests/models/phi4_multimodal/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/transformers_tests/models/phi4_multimodal/test_modeling_phi4_multimodal.py b/tests/transformers_tests/models/phi4_multimodal/test_modeling_phi4_multimodal.py
new file mode 100644
index 0000000000..288cd78bfa
--- /dev/null
+++ b/tests/transformers_tests/models/phi4_multimodal/test_modeling_phi4_multimodal.py
@@ -0,0 +1,251 @@
+"""Adapted from https://github.com/huggingface/transformers/tree/main/tests/models/phi4_multimodal/test_modeling_phi4_multimodal.py."""
+
+# This module contains test cases structured as lists or tuples like
+# [name, pt_module, ms_module, init_args, init_kwargs, inputs_args, inputs_kwargs, outputs_map].
+#
+# Each defined case corresponds to a pair consisting of PyTorch and MindSpore modules, including their respective
+# initialization parameters and inputs for the forward pass. The testing framework adopted here is designed to
+# generically parse these parameters to assess and compare the precision of forward outcomes between the two
+# frameworks.
+#
+# In cases where models have unique initialization procedures or require testing with specialized output formats,
+# it is necessary to develop distinct, dedicated test cases, as illustrated below.
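+#
+# For illustration only, a minimal case entry for this model is shaped like the following
+# (the concrete values used by this suite are defined in `PHI4_CASES` further down):
+#
+#     [
+#         "Phi4MultimodalModel",                       # display name of the case
+#         "transformers.Phi4MultimodalModel",          # PyTorch reference module
+#         "mindone.transformers.Phi4MultimodalModel",  # MindSpore module under test
+#         (config,),                                   # init_args
+#         {},                                          # init_kwargs
+#         (input_ids,),                                # inputs_args for the forward pass
+#         {"attention_mask": attention_mask},          # inputs_kwargs
+#         {"last_hidden_state": 0},                    # outputs_map: pt attribute -> ms output index
+#     ]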
+
+import inspect
+import logging
+
+import numpy as np
+import pytest
+import torch
+from transformers import Phi4MultimodalAudioConfig, Phi4MultimodalConfig, Phi4MultimodalVisionConfig
+
+import mindspore as ms
+
+from tests.modeling_test_utils import (
+    MS_DTYPE_MAPPING,
+    PT_DTYPE_MAPPING,
+    compute_diffs,
+    generalized_parse_args,
+    get_modules,
+)
+from tests.transformers_tests.models.modeling_common import floats_numpy, ids_numpy
+
+DTYPE_AND_THRESHOLDS = {"fp32": 5e-4, "fp16": 5e-3, "bf16": 5e-3}
+MODES = [1]
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+class Phi4MultimodalModelTester:
+    def __init__(
+        self,
+        batch_size=2,
+        seq_length=12,
+        image_seq_length=275,
+        audio_seq_length=8,
+        is_training=True,
+        num_hidden_layers=2,
+        vocab_size=49,
+        hidden_size=32,
+        intermediate_size=64,
+        num_attention_heads=8,
+        num_key_value_heads=4,
+        bos_token_id=0,
+        eos_token_id=0,
+        pad_token_id=0,
+        image_token_id=1,
+        audio_token_id=2,
+        image_size=16,
+        audio_size=12,
+        audio_config=Phi4MultimodalAudioConfig(
+            num_blocks=2,
+            hidden_size=32,
+            num_attention_heads=8,
+            intermediate_size=48,
+            depthwise_separable_out_channel=128,
+            nemo_conv_channels=128,
+            initializer_range=1e-5,
+        ),
+        vision_config=Phi4MultimodalVisionConfig(
+            num_hidden_layers=2,
+            hidden_size=32,
+            intermediate_size=64,
+            num_attention_heads=8,
+            crop_size=16,
+            initializer_range=1e-5,
+        ),
+    ):
+        self.num_hidden_layers = num_hidden_layers
+        self.vocab_size = vocab_size
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_attention_heads = num_attention_heads
+        self.num_key_value_heads = num_key_value_heads
+        self.bos_token_id = bos_token_id
+        self.pad_token_id = pad_token_id
+        self.eos_token_id = eos_token_id
+        self.image_token_id = image_token_id
+        self.audio_token_id = audio_token_id
+        self.audio_config = audio_config
+        self.vision_config = vision_config
+
+        self.is_training = is_training
+        self.batch_size = batch_size
+        self.seq_length = seq_length + image_seq_length + audio_seq_length
+        self.image_seq_length = image_seq_length
+        self.audio_seq_length = audio_seq_length
+        self.image_size = image_size
+        self.audio_size = audio_size
+        self.num_channels = 3
+
+    def get_config(self):
+        return Phi4MultimodalConfig(
+            num_hidden_layers=self.num_hidden_layers,
+            vocab_size=self.vocab_size,
+            hidden_size=self.hidden_size,
+            intermediate_size=self.intermediate_size,
+            num_attention_heads=self.num_attention_heads,
+            num_key_value_heads=self.num_key_value_heads,
+            bos_token_id=self.bos_token_id,
+            eos_token_id=self.eos_token_id,
+            pad_token_id=self.pad_token_id,
+            vision_config=self.vision_config,
+            audio_config=self.audio_config,
+        )
+
+    # Copied from tests.models.mistral.test_modeling_mistral.MistralModelTester.prepare_config_and_inputs
+    def prepare_config_and_inputs(self):
+        input_ids = ids_numpy([self.batch_size, self.seq_length], self.vocab_size)
+
+        # The shapes correspond to the inputs for an image of size 16x16
+        image_pixel_values = floats_numpy([self.batch_size, 2, self.num_channels, self.image_size, self.image_size])
+        image_attention_mask = np.ones((self.batch_size, 2, 1, 1))
+        image_sizes = ms.tensor([[self.image_size, self.image_size]] * self.batch_size, dtype=ms.int64)
+
+        # Feature sizes returned by an audio input of length 10000
+        audio_input_features = floats_numpy([self.batch_size, 61, 80])
+        audio_embed_sizes = ms.tensor([self.audio_seq_length] * self.batch_size, dtype=ms.int64)
+
+        input_ids[input_ids == self.pad_token_id] = self.pad_token_id + 1  # random value, but not the pad token
+        input_ids[-1, 0] = self.pad_token_id  # mask the first token of the last sample
+        input_ids[:, -self.image_seq_length - self.audio_seq_length : -self.audio_seq_length] = self.image_token_id
+        input_ids[:, -self.audio_seq_length :] = self.audio_token_id
+
+        attention_mask = np.ones_like(input_ids)
+        attention_mask[-1, 0] = 0  # mask the first token of the last sample
+        config = self.get_config()
+
+        return (
+            config,
+            input_ids,
+            attention_mask,
+            image_pixel_values,
+            image_attention_mask,
+            image_sizes,
+            audio_input_features,
+            audio_embed_sizes,
+        )
+
+
+model_tester = Phi4MultimodalModelTester()
+(
+    config,
+    input_ids,
+    attention_mask,
+    image_pixel_values,
+    image_attention_mask,
+    image_sizes,
+    audio_input_features,
+    audio_embed_sizes,
+) = model_tester.prepare_config_and_inputs()
+
+
+PHI4_CASES = [
+    [
+        "Phi4MultimodalModel",
+        "transformers.Phi4MultimodalModel",
+        "mindone.transformers.Phi4MultimodalModel",
+        (config,),
+        {},
+        (input_ids,),
+        {
+            "attention_mask": attention_mask,
+        },
+        {
+            "last_hidden_state": 0,
+        },
+    ],
+]
+
+
+# requires transformers >= 4.41.2
+@pytest.mark.parametrize(
+    "name,pt_module,ms_module,init_args,init_kwargs,inputs_args,inputs_kwargs,outputs_map,dtype,mode",
+    [
+        case + [dtype, mode]
+        for case in PHI4_CASES
+        for dtype in DTYPE_AND_THRESHOLDS.keys()
+        for mode in MODES
+    ],
+)
+def test_named_modules(
+    name,
+    pt_module,
+    ms_module,
+    init_args,
+    init_kwargs,
+    inputs_args,
+    inputs_kwargs,
+    outputs_map,
+    dtype,
+    mode,
+):
+    ms.set_context(mode=mode)
+
+    (
+        pt_model,
+        ms_model,
+        pt_dtype,
+        ms_dtype,
+    ) = get_modules(pt_module, ms_module, dtype, *init_args, **init_kwargs)
+    pt_inputs_args, pt_inputs_kwargs, ms_inputs_args, ms_inputs_kwargs = generalized_parse_args(
+        pt_dtype, ms_dtype, *inputs_args, **inputs_kwargs
+    )
+
+    # set `hidden_dtype` if required, as some modules always compute in float
+    # precision and need a specific `hidden_dtype` to cast to before returning
+    if "hidden_dtype" in inspect.signature(pt_model.forward).parameters:
+        pt_inputs_kwargs.update({"hidden_dtype": PT_DTYPE_MAPPING[pt_dtype]})
+        ms_inputs_kwargs.update({"hidden_dtype": MS_DTYPE_MAPPING[ms_dtype]})
+    with torch.no_grad():
+        pt_outputs = pt_model(*pt_inputs_args, **pt_inputs_kwargs)
+        ms_outputs = ms_model(*ms_inputs_args, **ms_inputs_kwargs)
+
+    if outputs_map:
+        pt_outputs_n = []
+        ms_outputs_n = []
+        for pt_key, ms_idx in outputs_map.items():
+            pt_output = getattr(pt_outputs, pt_key)
+            ms_output = ms_outputs[ms_idx]
+            if isinstance(pt_output, (list, tuple)):
+                pt_outputs_n += list(pt_output)
+                ms_outputs_n += list(ms_output)
+            else:
+                pt_outputs_n.append(pt_output)
+                ms_outputs_n.append(ms_output)
+        diffs = compute_diffs(pt_outputs_n, ms_outputs_n)
+    else:
+        diffs = compute_diffs(pt_outputs, ms_outputs)
+    logger.info(f"Differences: {diffs}")
+    THRESHOLD = DTYPE_AND_THRESHOLDS[ms_dtype]
+    assert (np.array(diffs) < THRESHOLD).all(), (
+        f"ms_dtype: {ms_dtype}, pt_type: {pt_dtype}, "
+        f"Outputs({np.array(diffs).tolist()}) have diffs bigger than {THRESHOLD}"
+    )
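+
+
+# A minimal sketch of how this suite is typically run locally (assuming torch, mindspore,
+# and the repository root on PYTHONPATH are all available; the exact command may differ in CI):
+#
+#   pytest tests/transformers_tests/models/phi4_multimodal/test_modeling_phi4_multimodal.py -v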