From e081a827f766931a0ea31694b758bf8ba0a72c1d Mon Sep 17 00:00:00 2001 From: eaidova Date: Fri, 7 Mar 2025 18:49:20 +0400 Subject: [PATCH 1/2] support mamba models --- optimum/exporters/openvino/model_configs.py | 130 ++++++++ optimum/exporters/openvino/model_patcher.py | 317 ++++++++++++++++++++ optimum/exporters/openvino/stateful.py | 27 ++ optimum/intel/openvino/modeling_decoder.py | 214 ++++++++++++- 4 files changed, 685 insertions(+), 3 deletions(-) diff --git a/optimum/exporters/openvino/model_configs.py b/optimum/exporters/openvino/model_configs.py index 863c5e2383..40e80089f9 100644 --- a/optimum/exporters/openvino/model_configs.py +++ b/optimum/exporters/openvino/model_configs.py @@ -95,6 +95,7 @@ LlamaModelPatcher, LlavaImageEmbeddingModelPatcher, LlavaQwen2ImageEmbeddingsModelPatcher, + MambaPatcher, MiniCPM3Patcher, MiniCPMModelPatcher, MiniCPMVImageEmbeddingsModelPatcher, @@ -2880,3 +2881,132 @@ def patch_model_for_export( self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None ) -> "ModelPatcher": return DeepseekPatcher(self, model, model_kwargs=model_kwargs) + + +class MambaCacheDummyInputGenerator(DummyInputGenerator): + """ + Generates dummy past_key_values inputs for seq2seq architectures. 
+ """ + + SUPPORTED_INPUT_NAMES = ("past_ssm_states", "past_conv_states", "cache_position") + + def __init__( + self, + task: str, + normalized_config, + batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"], + sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"], + **kwargs, + ): + self.normalized_config = normalized_config + self.batch_size = batch_size + self.sequence_length = sequence_length + self.intermediate_size = self.normalized_config.config.intermediate_size + self.ssm_state_size = self.normalized_config.config.state_size + self.conv_kernel_size = self.normalized_config.config.conv_kernel + + def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): + if input_name == "past_ssm_states": + ssm_shape = [self.batch_size, self.intermediate_size, self.ssm_state_size] + return [ + self.random_float_tensor(ssm_shape, framework=framework, dtype=float_dtype) + for _ in range(self.normalized_config.num_layers) + ] + + elif input_name == "past_conv_states": + conv_shape = [self.batch_size, self.intermediate_size, self.conv_kernel_size] + return [ + self.random_float_tensor(conv_shape, framework=framework, dtype=float_dtype) + for _ in range(self.normalized_config.num_layers) + ] + + elif input_name == "cache_position": + return self.random_int_tensor( + shape=[self.conv_kernel_size], + max_value=self.sequence_length, + framework=framework, + dtype=int_dtype, + ) + + raise ValueError(f"Unsupported input name {input_name}") + + +@register_in_tasks_manager("mamba", *["text-generation", "text-generation-with-past"], library_name="transformers") +class MambaOpenVINOConfig(TextDecoderOnnxConfig): + DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, MambaCacheDummyInputGenerator) + DUMMY_PKV_GENERATOR_CLASS = MambaCacheDummyInputGenerator + NORMALIZED_CONFIG_CLASS = NormalizedTextConfig + + @property + def inputs(self) -> Dict[str, Dict[int, str]]: + if self.use_past_in_inputs: + common_inputs = 
{"input_ids": {0: "batch_size", 1: "sequence_length"}} + self.add_past_key_values(common_inputs, direction="inputs") + # common_inputs["attention_mask"] = {0: "batch_size", 1: "past_sequence_length + 1"} + common_inputs["cache_position"] = {0: "cache_sequence_length"} + else: + common_inputs = { + "input_ids": {0: "batch_size", 1: "sequence_length"}, + # "attention_mask": {0: "batch_size", 1: "sequence_length"}, + "cache_position": {0: "cache_sequence_length"}, + } + return common_inputs + + def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], direction: str): + """ + Fills `input_or_outputs` mapping with past_key_values dynamic axes considering the direction. + + Args: + inputs_or_outputs (`Dict[str, Dict[int, str]]`): + The mapping to fill. + direction (`str`): + either "inputs" or "outputs", it specifies whether `input_or_outputs` is the input mapping or the + output mapping, this is important for axes naming. + """ + if direction not in ["inputs", "outputs"]: + raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given') + + if direction == "inputs": + ssm_name = "past_ssm_states" + conv_name = "past_conv_states" + else: + ssm_name = "present_ssm_states" + conv_name = "present_conv_states" + + for i in range(self._normalized_config.num_layers): + inputs_or_outputs[f"{ssm_name}.{i}"] = {0: "batch_size"} + + for i in range(self._normalized_config.num_layers): + inputs_or_outputs[f"{conv_name}.{i}"] = {0: "batch_size"} + + def patch_model_for_export( + self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None + ): + return MambaPatcher(self, model, model_kwargs) + + def generate_dummy_inputs(self, framework: str = "pt", **kwargs): + dummy_inputs_generators = self._create_dummy_input_generator_classes(**kwargs) + + dummy_inputs = {} + input_names = [key for key in self.inputs.keys() if not key.startswith("past_")] + if self.use_past_in_inputs and 
self.use_cache_branch is not False: + input_names.extend(["past_ssm_states", "past_conv_states"]) + + for input_name in input_names: + input_was_inserted = False + for dummy_input_gen in dummy_inputs_generators: + if dummy_input_gen.supports_input(input_name): + dummy_inputs[input_name] = self.overwrite_shape_and_generate_input( + dummy_input_gen, + input_name, + framework, + input_shapes=kwargs, + ) + input_was_inserted = True + break + if not input_was_inserted: + raise RuntimeError( + f'Could not generate dummy input for "{input_name}". Try adding a proper dummy input generator to the model ONNX config.' + ) + + return dummy_inputs diff --git a/optimum/exporters/openvino/model_patcher.py b/optimum/exporters/openvino/model_patcher.py index a28c8d5adc..33f49bbedd 100644 --- a/optimum/exporters/openvino/model_patcher.py +++ b/optimum/exporters/openvino/model_patcher.py @@ -4367,3 +4367,320 @@ def __init__( layer.mlp.down_proj.to(torch.float32) super().__init__(config, model, model_kwargs) + + +class ConvSequenceTransform(torch.nn.Module): + def __init__(self, conv_kernel_size, use_conv_bias, conv1, act, conv_bias): + super().__init__() + self.conv_kernel_size = conv_kernel_size + self.use_conv_bias = use_conv_bias + self.conv1d = conv1 + self.act = act + self.conv_bias = conv_bias + + def update_conv_state( + self, conv_state, new_conv_state: torch.Tensor, cache_position: torch.LongTensor + ) -> torch.Tensor: + conv_state_1 = conv_state.roll(shifts=-1, dims=-1) + upd_conv_state = conv_state_1.scatter(2, cache_position, new_conv_state) + return upd_conv_state + + def get_positions(self, conv_state, cache_position): + cache_position_clamped = cache_position.clamp(0, self.conv_kernel_size - 1) + positions = cache_position_clamped.expand(conv_state.shape[0], conv_state.shape[1], -1) + return positions + + def forward(self, hidden_states, cache_position, conv_state): + pad_value = (self.conv_kernel_size - hidden_states.shape[-1]) * ( + cache_position.shape[0] == 
self.conv_kernel_size + ) + new_conv_state = torch.nn.functional.pad(hidden_states, (pad_value, 0)) + upd_cache_position = self.get_positions(conv_state, cache_position) + upd_conv_state = self.update_conv_state(conv_state, new_conv_state, upd_cache_position) + if cache_position.shape[0] == self.conv_kernel_size: + hidden_states = self.conv1d(hidden_states)[:, :, : hidden_states.shape[-1]] + else: + hidden_states = torch.sum(upd_conv_state * self.conv1d.weight[:, 0, :], dim=-1) + hidden_states += self.conv_bias + hidden_states = hidden_states.unsqueeze(-1) + hidden_states = self.act(hidden_states) # [batch, intermediate_size, seq_len] + return hidden_states, upd_conv_state + + +class SelectiveScan(torch.nn.Module): + def forward(self, ssm, u, dt, A, B, C, D): + dA = torch.einsum("bld,dn->bldn", dt, A) + dB_u = torch.einsum("bld,bld,bln->bldn", dt, u, B) + dA_cumsum = torch.nn.functional.pad(dA[:, 1:], (0, 0, 0, 0, 0, 1)).flip(1).cumsum(1).exp().flip(1) + x = dB_u * dA_cumsum + (ssm.unsqueeze(1) * dA[:, :1].exp()) + x = x.cumsum(1) / (dA_cumsum + 1e-12) + y = torch.einsum("bldn,bln->bld", x, C) + return y + u * D, x[:, -1, :, :] + + +def mamba_mixer_forward( + self, + input_states, + cache_params=None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, +): + batch_size, seq_len, _ = input_states.shape + dtype = input_states.dtype + # 1. Gated MLP's linear projection + projected_states = self.in_proj(input_states).transpose(1, 2) # [batch, 2 * intermediate_size, seq_len] + hidden_states, gate = projected_states.chunk(2, dim=1) + + if attention_mask is not None: + hidden_states = hidden_states * attention_mask.unsqueeze(1) + + # 2. 
Convolution sequence transformation + if cache_params is not None: + ssm_state = cache_params.ssm_states[self.layer_idx].clone() + ssm_state = ssm_state.to(hidden_states.device) + # use `cache_position.shape[0]` to check whether we are in prefill + # stage, it's equivalent to check `cache_position[0] == 0`, which + # breaks dynamo fullgraph constraints + hidden_states, conv_state = self.conv_sequence_transform( + hidden_states, cache_position, cache_params.conv_states[self.layer_idx] + ) + cache_params.conv_states[self.layer_idx] = conv_state + else: + ssm_state = torch.zeros( + (batch_size, self.intermediate_size, self.ssm_state_size), device=hidden_states.device, dtype=dtype + ) + hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # [batch, intermediate_size, seq_len] + + if attention_mask is not None: + hidden_states = hidden_states * attention_mask.unsqueeze(1) + + # 3. State Space Model sequence transformation + # 3.a. Selection: [batch, seq_len, self.time_step_rank + self.ssm_state_size * 2] + ssm_parameters = self.x_proj(hidden_states.transpose(1, 2)) + time_step, B, C = torch.split( + ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1 + ) + discrete_time_step = self.dt_proj(time_step) # [batch, seq_len, intermediate_size] + + # DIFF + # discrete_time_step = nn.functional.softplus(discrete_time_step).transpose(1, 2) # [batch, intermediate_size, seq_len] + + # # 3.b. 
Discretization: B and C to [batch, seq_len, intermediate_size, ssm_state_size] (SRAM) + # A = -torch.exp(self.A_log.float()) # [intermediate_size, ssm_state_size] + # discrete_A = torch.exp(A[None, :, None, :] * discrete_time_step[:, :, :, None]) # [batch, intermediate_size, seq_len, ssm_state_size] + # discrete_B = discrete_time_step[:, :, :, None] * B[:, None, :, :].float() # [batch, intermediate_size, seq_len, ssm_state_size] + # deltaB_u = discrete_B * hidden_states[:, :, :, None].float() + + # # 3.c perform the recurrence y ← SSM(A, B, C)(x) + # if self.use_mambapy and self.training and cache_params is None: + # hs = pscan(discrete_A.transpose(1, 2), deltaB_u.transpose(1, 2)) # [batch, seq_len, intermediate_size, ssm_state_size] + + # scan_output = (hs @ C.unsqueeze(-1)).squeeze(3).transpose(1, 2) # [batch, intermediate_size, seq_len] + # scan_output = scan_output + hidden_states * self.D[None, :, None] + # scan_output = scan_output * self.act(gate) + # else: + # scan_outputs = [] + # for i in range(seq_len): + # ssm_state = discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :] # [batch, intermediade_size, ssm_state] + # scan_output = torch.matmul(ssm_state.to(dtype), C[:, i, :].unsqueeze(-1)) # [batch, intermediade_size, 1] + # scan_outputs.append(scan_output[:, :, 0]) + # scan_output = torch.stack(scan_outputs, dim=-1) # [batch, seq_len, intermediade_size] + # scan_output = scan_output + (hidden_states * self.D[None, :, None]) + # scan_output = (scan_output * self.act(gate)) + + # if cache_params is not None: + # cache_params.ssm_states[self.layer_idx].copy_(ssm_state) + + discrete_time_step = torch.nn.functional.softplus(discrete_time_step) # [batch, intermediate_size, seq_len] + A = -torch.exp(self.A_log.float()) + B = B.float() + D = self.D.float() + + scan_output, ssm_state = self.selective_scan( + ssm_state, hidden_states.float().transpose(1, 2), discrete_time_step, A, B, C, D + ) + scan_output = scan_output.transpose(1, 2) + scan_output = 
scan_output * self.act(gate) + + if cache_params is not None: + cache_params.ssm_states[self.layer_idx].copy_(ssm_state) + + # 4. Final linear projection + contextualized_states = self.out_proj(scan_output.transpose(1, 2)) # [batch, seq_len, hidden_size] + return contextualized_states + + +class MambaPatcher(ModelPatcher): + def __init__( + self, + config: "OnnxConfig", + model: Union["PreTrainedModel", "TFPreTrainedModel"], + model_kwargs: Optional[Dict[str, Any]] = None, + ): + self._patching_specs = [] + from transformers.cache_utils import MambaCache + + class MambaCacheWrap(MambaCache): + def __init__( + self, + config: "PretrainedConfig", + batch_size: int = None, + dtype: torch.dtype = torch.float32, + device: Optional[Union[torch.device, str]] = None, + max_batch_size: Optional[int] = None, + conv_states: Optional[List[torch.Tensor]] = None, + ssm_states: Optional[List[torch.Tensor]] = None, + ): + self.dtype = dtype + self.max_batch_size = batch_size or max_batch_size + self.intermediate_size = config.intermediate_size + self.ssm_state_size = config.state_size + self.conv_kernel_size = config.conv_kernel + self.device = torch.device(device) if device is not None else torch.device("cpu") + # print(config.num_hidden_layers) + + if conv_states is not None: + self.conv_states = conv_states + else: + self.conv_states = [] + for _ in range(config.num_hidden_layers): + conv_state: torch.Tensor = torch.zeros( + self.max_batch_size, + self.intermediate_size, + self.conv_kernel_size, + device=self.device, + dtype=dtype, + ) + self.conv_states.append(conv_state) + + if ssm_states is not None: + self.ssm_states = ssm_states + else: + self.ssm_states: List[torch.Tensor] = [] + for _ in range(config.num_hidden_layers): + ssm_state: torch.Tensor = torch.zeros( + self.max_batch_size, + self.intermediate_size, + self.ssm_state_size, + device=self.device, + dtype=dtype, + ) + + self.ssm_states.append(ssm_state) + + def update_conv_state( + self, layer_idx: int, 
new_conv_state: torch.Tensor, cache_position: torch.LongTensor + ) -> torch.Tensor: + conv_state = self.conv_states[layer_idx] + cache_position = cache_position.clamp(0, self.conv_kernel_size - 1) + + conv_state = conv_state.roll(shifts=-1, dims=-1) + conv_state = conv_state.scatter( + 2, cache_position.expand(conv_state.shape[0], conv_state.shape[1], -1), new_conv_state + ) + self.conv_states[layer_idx] = conv_state + return self.conv_states[layer_idx] + + def forward_wrap( + self, input_ids, attention_mask=None, cache_position=None, past_ssm_states=None, past_conv_states=None + ): + use_cache = False + cache_params = None + if past_ssm_states is not None and past_conv_states is not None: + use_cache = True + cache_params = MambaCacheWrap( + self.config, + input_ids.shape[0], + conv_states=list(past_conv_states), + ssm_states=list(past_ssm_states), + ) + result = self.__orig_forward( + input_ids=input_ids, cache_position=cache_position, cache_params=cache_params, use_cache=use_cache + ) + if use_cache: + return { + "logits": result.logits, + "ssm_states": result.cache_params.ssm_states, + "conv_states": result.cache_params.conv_states, + } + return result + + model.__orig_forward = model.forward + model.forward = types.MethodType(forward_wrap, model) + self._model = model + + self.orig_forward_name = "forward" if hasattr(self._model, "forward") else "call" + self.orig_forward = getattr(self._model, self.orig_forward_name) + + self.model_kwargs = model_kwargs if model_kwargs is not None else {} + self.real_config = config + + allow_past_in_outputs = hasattr(self.real_config, "use_past") and self.real_config.use_past + + @functools.wraps(self.orig_forward) + def patched_forward(*args, **kwargs): + signature = inspect.signature(self.orig_forward) + args, kwargs = override_arguments(args, kwargs, signature, model_kwargs=self.model_kwargs) + + outputs = self.orig_forward(*args, **kwargs) + + # This code block handles different cases of the filterd_outputs input to 
align it with the expected + # format of outputs. It is common for the output type of a model to vary, such as tensor, list, + # tuple, etc. For Transformers models, the output is encapsulated in a ModelOutput object that + # contains the output names of the model. In the case of Timm classification models, the output + # is of type tensor. By default, it is assumed that the output names mentioned in the ONNX config + # match the outputs in order. + filterd_outputs = {} + if isinstance(outputs, dict): + for name, value in outputs.items(): + output_name = config.torch_to_onnx_output_map.get(name, name) + if ( + output_name in config.outputs + or ( + allow_past_in_outputs and (name.startswith("ssm_states") or name.startswith("conv_states")) + ) + or any(key.startswith(output_name) for key in config.outputs.keys()) + ): + filterd_outputs[name] = value + elif isinstance(outputs, (list, tuple)): + outputs_list = list(config.outputs.keys()) + filterd_outputs = dict(zip(outputs_list, outputs)) + else: + if len(config.outputs) > 1: + num_outputs = len(config.outputs) + outputs_str = ", ".join(config.outputs.keys()) + raise ValueError( + f"config.outputs should have only one outputs, but it has {num_outputs} keys: {outputs_str}" + ) + else: + name = list(config.outputs.keys())[0] + filterd_outputs[name] = outputs + name = list(config.outputs.keys())[0] + filterd_outputs[name] = outputs + + return filterd_outputs + + self.patched_forward = patched_forward + + def __enter__(self): + super().__enter__() + selective_scan = SelectiveScan() + + for layer in self._model.backbone.layers: + layer.mixer.selective_scan = selective_scan + layer.mixer._orig_forward = layer.mixer.forward + layer.mixer.forward = types.MethodType(mamba_mixer_forward, layer.mixer) + conv_transform = ConvSequenceTransform( + layer.mixer.conv_kernel_size, + layer.mixer.use_conv_bias, + layer.mixer.conv1d, + layer.mixer.act, + layer.mixer.conv1d.bias, + ) + layer.mixer.conv_sequence_transform = 
torch.jit.script(conv_transform) + + def __exit__(self, exc_type, exc_value, traceback): + super().__exit__(exc_type, exc_value, traceback) + self._model.forward = self._model.__orig_forward + for layer in self._model.backbone.layers: + layer.mixer.forward = layer.mixer._orig_forward diff --git a/optimum/exporters/openvino/stateful.py b/optimum/exporters/openvino/stateful.py index a367ea8f00..a28b0134fe 100644 --- a/optimum/exporters/openvino/stateful.py +++ b/optimum/exporters/openvino/stateful.py @@ -272,9 +272,36 @@ def insert_state_for_nodes(model: ov.Model, nodes): model.add_sinks([assign]) +def patch_stateful_ssm(config, ov_model): + cache_input_names = [key_name for key in ov_model.inputs for key_name in key.get_names() if "past_" in key_name] + cache_output_names = [ + key_name for key in ov_model.outputs for key_name in key.get_names() if "present" in key_name + ] + + print(cache_output_names) + print(ov_model.outputs) + if not cache_input_names or not cache_output_names: + return + + batch_dim = 0 + + from openvino._offline_transformations import apply_make_stateful_transformation + + input_output_map = {} + for cache_name_pair in zip(cache_input_names, cache_output_names): + input_output_map[cache_name_pair[0]] = cache_name_pair[1] + + print(input_output_map) + + apply_make_stateful_transformation(ov_model, input_output_map) + build_state_initializer(ov_model, batch_dim) + + def patch_stateful(config: PretrainedConfig, ov_model: ov.Model): if config.is_encoder_decoder and model_has_input_output_name(ov_model, "encoder_hidden_states"): return patch_stateful_encoder_decoder(config, ov_model) + if config.model_type == "mamba": + return patch_stateful_ssm(config, ov_model) return patch_stateful_decoder(config, ov_model) diff --git a/optimum/intel/openvino/modeling_decoder.py b/optimum/intel/openvino/modeling_decoder.py index ebe634a54e..d5a6630f68 100644 --- a/optimum/intel/openvino/modeling_decoder.py +++ b/optimum/intel/openvino/modeling_decoder.py @@ 
-14,6 +14,7 @@ import copy import logging import os +from dataclasses import dataclass from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union @@ -55,6 +56,11 @@ ) +if is_transformers_version(">=", "4.43"): + from transformers.cache_utils import MambaCache +else: + MambaCache = object() + if TYPE_CHECKING: try: from transformers.generation.streamers import BaseStreamer @@ -95,6 +101,15 @@ """ +def has_cache_inputs(model): + return any( + "past_key_values" in key.get_any_name() + or "past_ssm" in key.get_any_name() + or "past_conv" in key.get_any_name() + for key in model.inputs + ) + + @add_start_docstrings( """ Base OVBaseDecoderModel class. @@ -139,13 +154,17 @@ def __init__( self.is_dynamic = dynamic_shapes use_cache = kwargs.pop("use_cache", True) model_has_sinks = model_has_state(self.model) - self.use_cache = any("past_key_values" in key.get_any_name() for key in model.inputs) or model_has_sinks + self.use_cache = has_cache_inputs(model) or model_has_sinks stateful = kwargs.pop("stateful", None) # stateful model only if it is converted with stateful=True self.stateful = model_has_sinks self.main_input_name = "input_ids" self.num_pkv = 2 self.key_value_input_names = [key for key in self.input_names if "key_values" in key] - self.key_value_output_names = [key for key in self.output_names if "present" in key] + self.key_value_output_names = [key for key in self.output_names if "present." 
in key] + self.ssm_cache_input_names = [key for key in self.input_names if "past_ssm_states" in key] + self.conv_cache_input_names = [key for key in self.input_names if "past_conv_states" in key] + self.ssm_cache_output_names = [key for key in self.output_names if "present_ssm_states" in key] + self.conv_cache_output_names = [key for key in self.output_names if "present_conv_states" in key] # Keeping the original model for serialization self._original_model = self.model.clone() if not compile_only else None self._pkv_precision = Type.f32 @@ -382,7 +401,7 @@ def _reshape( shapes[inputs][1] = -1 else: shapes[inputs][2] = -1 - elif input_name.startswith("beam_idx"): + elif input_name.startswith("beam_idx") or input_name.startswith("cache_position"): shapes[inputs][0] = -1 else: shapes[inputs][1] = -1 @@ -839,6 +858,8 @@ def _from_pretrained( init_cls = OVBloomForCausalLM elif model_type == "gpt-bigcode": init_cls = OVGPTBigCodeForCausalLM + elif model_type == "mamba": + init_cls = OVMambaForCausalLM else: init_cls = cls @@ -1014,3 +1035,190 @@ def _reorder_cache( return past_key_values else: return tuple(np.take(layer_past, beam_idx, 0) for layer_past in past_key_values) + + +class OVMambaCache(MambaCache): + def __init__( + self, + config: "PretrainedConfig", + batch_size: int = None, + dtype: torch.dtype = torch.float32, + device: Optional[Union[torch.device, str]] = None, + max_batch_size: Optional[int] = None, + conv_states: Optional[List[torch.Tensor]] = None, + ssm_states: Optional[List[torch.Tensor]] = None, + ): + self.dtype = dtype + self.max_batch_size = batch_size or max_batch_size + self.intermediate_size = config.intermediate_size + self.ssm_state_size = config.state_size + self.conv_kernel_size = config.conv_kernel + self.device = torch.device(device) if device is not None else torch.device("cpu") + + if conv_states is not None: + self.conv_states = conv_states + else: + self.conv_states = [] + for _ in range(config.num_hidden_layers): + conv_state: 
torch.Tensor = torch.zeros( + self.max_batch_size, self.intermediate_size, self.conv_kernel_size, device=self.device, dtype=dtype + ) + self.conv_states.append(conv_state) + + if ssm_states is not None: + self.ssm_states = ssm_states + else: + self.ssm_states: List[torch.Tensor] = [] + for _ in range(config.num_hidden_layers): + ssm_state: torch.Tensor = torch.zeros( + self.max_batch_size, + self.intermediate_size, + self.ssm_state_size, + device=self.device, + dtype=dtype, + ) + + self.ssm_states.append(ssm_state) + + +@dataclass +class MambaOutput(ModelOutput): + """ + Class for the MAMBA model outputs. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + cache_params (`MambaCache`): + The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to + avoid providing the old `input_ids`. + + Includes both the State space model state matrices after the selective scan, and the Convolutional states + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. 
+ """ + + logits: Optional[torch.FloatTensor] = None + cache_params: Optional[OVMambaCache] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + + +class OVMambaForCausalLM(OVModelForCausalLM): + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + cache_params=None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, + **kwargs, + ): + inputs = {"input_ids": input_ids.cpu().numpy()} + if "cache_position" in self.input_names: + inputs["cache_position"] = cache_position.cpu().numpy() + if "attention_mask" in self.input_names: + inputs["attention_mask"] = cache_position.cpu().numpy() + + if not self.stateful: + if cache_params is None and self.ssm_cache_input_names and self.conv_cache_input_names: + cache_params = OVMambaCache(self.config, input_ids.shape[0]) + ssm_cache = cache_params.ssm_states + conv_cache = cache_params.conv_states + + inputs.update(zip(self.ssm_cache_input_names, ssm_cache)) + inputs.update(zip(self.conv_cache_input_names, conv_cache)) + else: + if cache_params is None: + # This is the first iteration in a sequence, reset all states + if self.request is not None: + self.request.reset_state() + self._past_length = 0 + + ssm_states, conv_states = [], [] + print(inputs.keys()) + + self.request.start_async(inputs, share_inputs=True) + self.request.wait() + logits = torch.from_numpy(self.request.get_tensor("logits").data).to(self.device) + + if self.stateful: + self._past_length += input_ids.shape[1] + else: + print(self.ssm_cache_output_names) + print(self.conv_cache_output_names) + ssm_states = [self.request.get_tensor(key).data for key in self.ssm_cache_output_names] + conv_states = [self.request.get_tensor(key).data for key in self.conv_cache_output_names] + cache_params = OVMambaCache(self.config, input_ids.shape[0], conv_states=conv_states, ssm_states=ssm_states) + + return MambaOutput(logits=logits, cache_params=cache_params) 
+ + def _update_model_kwargs_for_generation( + self, outputs: ModelOutput, model_kwargs: Dict[str, Any], num_new_tokens: int = 1, **kwargs + ) -> Dict[str, Any]: + model_kwargs["cache_params"] = outputs.get("cache_params", None) + print(model_kwargs["cache_params"]) + if ( + model_kwargs.get("use_cache", True) + and "cache_position" in model_kwargs + and model_kwargs["cache_position"] is not None + ): + model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens + + if "attention_mask" in model_kwargs: + attention_mask = model_kwargs["attention_mask"] + model_kwargs["attention_mask"] = torch.cat( + [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1 + ) + + return model_kwargs + + def prepare_inputs_for_generation( + self, + input_ids, + inputs_embeds=None, + use_cache=None, + cache_params=None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + **kwargs, + ): + # Overwitten -- uses `cache_params` as opposed to `past_key_values` + + if use_cache: + # `cache_position` should have been initialized in `generate` + if cache_position is None: + raise ValueError( + "`cache_position` should not be None as it should have been initialized in " + "`model.generate`, you are responsible for passing in a valid `cache_position` if " + "you are calling `prepare_inputs_for_generation` directly with `use_cache=True`" + ) + if cache_position[0] > 0: + input_ids = input_ids[:, -1].unsqueeze(-1) + + if attention_mask is not None: + attention_mask = None + + else: + # we initialize the `cache_position` to full size of `conv_states` at prefill stage + # considering padding will be applied when input length is shorter, and truncation + # will be applied when it is longer, so it will be equivalent to always have it match + # the length of `cache_params.conv_states`, which is `config.conv_kernel` + cache_position = torch.arange(0, self.config.conv_kernel, device=input_ids.device) 
+ + if inputs_embeds is not None and cache_params is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids.contiguous()} + + model_inputs.update( + { + "cache_params": cache_params, + "use_cache": use_cache, + "cache_position": cache_position, + "attention_mask": attention_mask, + } + ) + return model_inputs From 4178fa67f066ad3920a3e0d8a75efbf401f35a4b Mon Sep 17 00:00:00 2001 From: eaidova Date: Fri, 7 Mar 2025 19:43:53 +0400 Subject: [PATCH 2/2] add falcon mamba support --- optimum/exporters/openvino/model_configs.py | 11 +++ optimum/exporters/openvino/model_patcher.py | 88 +++++++++++++++++++++ optimum/exporters/openvino/stateful.py | 5 +- optimum/exporters/openvino/utils.py | 2 + optimum/intel/openvino/modeling_decoder.py | 12 +-- 5 files changed, 109 insertions(+), 9 deletions(-) diff --git a/optimum/exporters/openvino/model_configs.py b/optimum/exporters/openvino/model_configs.py index 40e80089f9..a78da28472 100644 --- a/optimum/exporters/openvino/model_configs.py +++ b/optimum/exporters/openvino/model_configs.py @@ -77,6 +77,7 @@ DeciLMModelPatcher, DeepseekPatcher, FalconModelPatcher, + FalconMambaPatcher, FluxTransfromerModelPatcher, Gemma2ModelPatcher, GptBigCodeModelPatcher, @@ -3010,3 +3011,13 @@ def generate_dummy_inputs(self, framework: str = "pt", **kwargs): ) return dummy_inputs + + +@register_in_tasks_manager( + "falcon-mamba", *["text-generation", "text-generation-with-past"], library_name="transformers" +) +class FalconMambaOpenVINOConfig(MambaOpenVINOConfig): + def patch_model_for_export( + self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None + ): + return FalconMambaPatcher(self, model, model_kwargs) diff --git a/optimum/exporters/openvino/model_patcher.py b/optimum/exporters/openvino/model_patcher.py index 33f49bbedd..ec33514ca8 100644 --- a/optimum/exporters/openvino/model_patcher.py +++ b/optimum/exporters/openvino/model_patcher.py @@ 
# ---- optimum/exporters/openvino/model_patcher.py ----


def falcon_mamba_mixer_forward(
    self,
    input_states,
    cache_params=None,
    cache_position: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.LongTensor] = None,
):
    """Replacement ``forward`` bound onto each mixer layer by ``FalconMambaPatcher``.

    The function projects the input, runs the convolution step (through
    ``self.conv_sequence_transform`` when a cache is supplied, through
    ``self.conv1d`` otherwise), normalizes the SSM parameters with
    ``rms_forward``, runs ``self.selective_scan`` and applies the output
    projection.

    Args:
        input_states: 3-D tensor unpacked as ``(batch_size, seq_len, _)``.
        cache_params: Optional cache object exposing ``ssm_states`` and
            ``conv_states`` indexable by ``self.layer_idx``. Both entries for
            this layer are updated as a side effect.
        cache_position: Forwarded unchanged to ``self.conv_sequence_transform``.
        attention_mask: Optional mask multiplied into the hidden states after
            ``unsqueeze(1)`` (assumed ``[batch, seq_len]`` - TODO confirm).

    Returns:
        The result of ``self.out_proj`` (``[batch, seq_len, hidden_size]``).
    """
    from transformers.models.falcon_mamba.modeling_falcon_mamba import rms_forward

    batch_size, seq_len, _ = input_states.shape
    dtype = input_states.dtype
    # 1. Gated MLP's linear projection
    projected_states = self.in_proj(input_states).transpose(1, 2)  # [batch, 2 * intermediate_size, seq_len]
    hidden_states, gate = projected_states.chunk(2, dim=1)

    if attention_mask is not None:
        hidden_states = hidden_states * attention_mask.unsqueeze(1)

    # 2. Convolution sequence transformation
    if cache_params is not None:
        # Work on a copy so the cached tensor is only overwritten once, at the end.
        ssm_state = cache_params.ssm_states[self.layer_idx].clone()
        ssm_state = ssm_state.to(hidden_states.device)
        # The convolution is delegated to `self.conv_sequence_transform`
        # (installed by the patcher); it receives `cache_position` together with
        # the cached conv state and hands back the updated state.
        hidden_states, conv_state = self.conv_sequence_transform(
            hidden_states, cache_position, cache_params.conv_states[self.layer_idx]
        )
        cache_params.conv_states[self.layer_idx] = conv_state
    else:
        ssm_state = torch.zeros(
            (batch_size, self.intermediate_size, self.ssm_state_size), device=hidden_states.device, dtype=dtype
        )
        hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len])  # [batch, intermediate_size, seq_len]

    if attention_mask is not None:
        hidden_states = hidden_states * attention_mask.unsqueeze(1)

    # 3. State Space Model sequence transformation
    # 3.a. Selection: [batch, seq_len, self.time_step_rank + self.ssm_state_size * 2]
    ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
    time_step, B, C = torch.split(
        ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1
    )

    # B, C and the time step all go through rms_forward before use.
    B = rms_forward(B, variance_epsilon=self.rms_eps)
    C = rms_forward(C, variance_epsilon=self.rms_eps)
    time_step = rms_forward(time_step, variance_epsilon=self.rms_eps)
    discrete_time_step = self.dt_proj(time_step)  # [batch, seq_len, intermediate_size]

    # softplus is element-wise and nothing is transposed here, so the layout is unchanged.
    discrete_time_step = torch.nn.functional.softplus(discrete_time_step)  # [batch, seq_len, intermediate_size]
    A = -torch.exp(self.A_log.float())
    B = B.float()
    D = self.D.float()

    scan_output, ssm_state = self.selective_scan(
        ssm_state, hidden_states.float().transpose(1, 2), discrete_time_step, A, B, C, D
    )
    scan_output = scan_output.transpose(1, 2)
    scan_output = scan_output * self.act(gate)

    if cache_params is not None:
        # In-place copy keeps the cache's existing tensor object.
        cache_params.ssm_states[self.layer_idx].copy_(ssm_state)

    # 4. Final linear projection
    contextualized_states = self.out_proj(scan_output.transpose(1, 2))  # [batch, seq_len, hidden_size]
    return contextualized_states


class FalconMambaPatcher(MambaPatcher):
    """``MambaPatcher`` variant that installs ``falcon_mamba_mixer_forward`` on every mixer."""

    def __enter__(self):
        """Apply the parent patches, then swap each mixer's forward for the Falcon-Mamba one."""
        mixers = [layer.mixer for layer in self._model.backbone.layers]
        # The inherited __exit__ restores `layer.mixer.forward` from
        # `layer.mixer._orig_forward`, so that attribute has to hold the genuine
        # pre-patch forward. Capture it *before* delegating to the parent:
        # reading `mixer.forward` afterwards would pick up whatever the parent
        # installed (presumably its own patched forward - its __enter__ is not
        # shown here) and the model would stay patched after export.
        original_forwards = [mixer.forward for mixer in mixers]

        super().__enter__()
        selective_scan = SelectiveScan()

        for mixer, original_forward in zip(mixers, original_forwards):
            mixer.selective_scan = selective_scan
            mixer._orig_forward = original_forward
            mixer.forward = types.MethodType(falcon_mamba_mixer_forward, mixer)
            conv_transform = ConvSequenceTransform(
                mixer.conv_kernel_size,
                mixer.use_conv_bias,
                mixer.conv1d,
                mixer.act,
                mixer.conv1d.bias,
            )
            mixer.conv_sequence_transform = torch.jit.script(conv_transform)


# ---- optimum/exporters/openvino/stateful.py ----

# Model types (dash-separated spelling) routed to the SSM stateful patching path.
# NOTE(review): the same list is also added to exporters/openvino/utils.py by this
# patch; consider keeping a single definition and importing it.
SSM_MODELS = ["mamba", "falcon-mamba"]


def patch_stateful(config: PretrainedConfig, ov_model: ov.Model):
    """Dispatch ``ov_model`` to the stateful-patching routine matching ``config``.

    Encoder-decoder models that expose ``encoder_hidden_states`` are handled
    first, then model types listed in ``SSM_MODELS``; every other model goes
    through the decoder routine.
    """
    if config.is_encoder_decoder and model_has_input_output_name(ov_model, "encoder_hidden_states"):
        return patch_stateful_encoder_decoder(config, ov_model)
    # model_type may be spelled with underscores; SSM_MODELS uses dashes.
    if config.model_type.replace("_", "-") in SSM_MODELS:
        return patch_stateful_ssm(config, ov_model)
    return patch_stateful_decoder(config, ov_model)
b/optimum/exporters/openvino/utils.py @@ -229,6 +229,8 @@ def get_submodels(model): "qwen2-5-vl", ] +SSM_MODELS = ["mamba", "falcon-mamba"] + def save_config(config, save_dir): try: diff --git a/optimum/intel/openvino/modeling_decoder.py b/optimum/intel/openvino/modeling_decoder.py index d5a6630f68..ee15716488 100644 --- a/optimum/intel/openvino/modeling_decoder.py +++ b/optimum/intel/openvino/modeling_decoder.py @@ -14,9 +14,9 @@ import copy import logging import os -from dataclasses import dataclass from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union +from dataclasses import dataclass import numpy as np import openvino @@ -37,6 +37,7 @@ from ...exporters.openvino import ensure_stateful_is_available, main_export, patch_stateful from ...exporters.openvino.stateful import model_has_state +from ...exporters.openvino.utils import SSM_MODELS from ..utils.import_utils import compare_versions, is_nncf_available, is_transformers_version from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS from .configuration import ( @@ -59,7 +60,7 @@ if is_transformers_version(">=", "4.43"): from transformers.cache_utils import MambaCache else: - MambaCache = object() + MambaCache = object if TYPE_CHECKING: try: @@ -858,7 +859,7 @@ def _from_pretrained( init_cls = OVBloomForCausalLM elif model_type == "gpt-bigcode": init_cls = OVGPTBigCodeForCausalLM - elif model_type == "mamba": + elif model_type in SSM_MODELS: init_cls = OVMambaForCausalLM else: init_cls = cls @@ -1138,8 +1139,6 @@ def forward( self._past_length = 0 ssm_states, conv_states = [], [] - print(inputs.keys()) - self.request.start_async(inputs, share_inputs=True) self.request.wait() logits = torch.from_numpy(self.request.get_tensor("logits").data).to(self.device) @@ -1147,8 +1146,6 @@ def forward( if self.stateful: self._past_length += input_ids.shape[1] else: - print(self.ssm_cache_output_names) - print(self.conv_cache_output_names) ssm_states = 
[self.request.get_tensor(key).data for key in self.ssm_cache_output_names] conv_states = [self.request.get_tensor(key).data for key in self.conv_cache_output_names] cache_params = OVMambaCache(self.config, input_ids.shape[0], conv_states=conv_states, ssm_states=ssm_states) @@ -1159,7 +1156,6 @@ def _update_model_kwargs_for_generation( self, outputs: ModelOutput, model_kwargs: Dict[str, Any], num_new_tokens: int = 1, **kwargs ) -> Dict[str, Any]: model_kwargs["cache_params"] = outputs.get("cache_params", None) - print(model_kwargs["cache_params"]) if ( model_kwargs.get("use_cache", True) and "cache_position" in model_kwargs