|
13 | 13 | # limitations under the License. |
14 | 14 |
|
15 | 15 | import enum |
| 16 | +import logging |
16 | 17 | from copy import deepcopy |
17 | 18 | from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union |
18 | 19 |
|
|
138 | 139 | QwenModelPatcher, |
139 | 140 | SanaTextEncoderModelPatcher, |
140 | 141 | XverseModelPatcher, |
| 142 | + Zamba2ModelPatcher, |
141 | 143 | ) |
142 | 144 |
|
143 | 145 |
|
| 146 | +logger = logging.getLogger(__name__) |
| 147 | + |
144 | 148 | if TYPE_CHECKING: |
145 | 149 | from transformers.modeling_utils import PreTrainedModel # noqa: F811 |
146 | 150 |
|
@@ -4278,3 +4282,107 @@ class GPT2OpenVINOConfig(GPT2OnnxConfig): |
4278 | 4282 | ) |
4279 | 4283 | class VisionEncoderDecoderOpenVINOConfig(VisionEncoderDecoderOnnxConfig): |
4280 | 4284 | _MODEL_PATCHER = OVSeq2SeqModelPatcher |
| 4285 | + |
| 4286 | + |
| 4287 | +class Zamba2DummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator): |
| 4288 | + """ |
| 4289 | + Generates dummy cache_params inputs for Zamba2 architectures. |
| 4290 | + """ |
| 4291 | + |
| 4292 | + SUPPORTED_INPUT_NAMES = ("cache_params",) |
| 4293 | + |
| 4294 | + def __init__( |
| 4295 | + self, |
| 4296 | + task: str, |
| 4297 | + normalized_config, |
| 4298 | + batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"], |
| 4299 | + sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"], |
| 4300 | + **kwargs, |
| 4301 | + ): |
| 4302 | + super().__init__( |
| 4303 | + task=task, |
| 4304 | + normalized_config=normalized_config, |
| 4305 | + batch_size=batch_size, |
| 4306 | + sequence_length=sequence_length, |
| 4307 | + **kwargs, |
| 4308 | + ) |
| 4309 | + |
| 4310 | + config = normalized_config.config |
| 4311 | + self.intermediate_size = int(config.mamba_expand * config.hidden_size) |
| 4312 | + self.ssm_state_size = config.mamba_d_state |
| 4313 | + self.conv_kernel_size = config.mamba_d_conv |
| 4314 | + self.n_mamba_heads = config.n_mamba_heads |
| 4315 | + self.mamba_ngroups = config.mamba_ngroups |
| 4316 | + self.mamba_d_state = config.mamba_d_state |
| 4317 | + self.mamba_headdim = config.mamba_headdim |
| 4318 | + self.head_dim = config.attention_head_dim |
| 4319 | + self.hybrid_layer_ids = config.hybrid_layer_ids |
| 4320 | + logger.warning( |
| 4321 | + "The current support for the 'Zamba2' model type is experimental. " |
| 4322 | + "Performance is not optimal and memory consumption is high. "
| 4323 | + "Optimizations and improved support will be available in a future OpenVINO release." |
| 4324 | + ) |
| 4325 | + |
| 4326 | + def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): |
| 4327 | + past_key_values = [] |
| 4328 | + for _ in range(self.num_layers):
| 4329 | + conv_state_shape = ( |
| 4330 | + self.batch_size, |
| 4331 | + self.intermediate_size + 2 * self.mamba_ngroups * self.mamba_d_state, |
| 4332 | + self.conv_kernel_size, |
| 4333 | + ) |
| 4334 | + conv_state = self.random_float_tensor(conv_state_shape, framework=framework, dtype=float_dtype) |
| 4335 | + past_key_values.append(conv_state) |
| 4336 | + ssm_state_shape = (self.batch_size, self.n_mamba_heads, self.mamba_headdim, self.ssm_state_size) |
| 4337 | + ssm_state = self.random_float_tensor(ssm_state_shape, framework=framework, dtype=float_dtype) |
| 4338 | + past_key_values.append(ssm_state) |
| 4339 | + |
| 4340 | + for _ in range(len(self.hybrid_layer_ids)):
| 4341 | + kv_shape = (self.batch_size, self.num_attention_heads, self.sequence_length, self.head_dim) |
| 4342 | + k = self.random_float_tensor(kv_shape, framework=framework, dtype=float_dtype) |
| 4343 | + v = self.random_float_tensor(kv_shape, framework=framework, dtype=float_dtype) |
| 4344 | + past_key_values.append(k) |
| 4345 | + past_key_values.append(v) |
| 4346 | + |
| 4347 | + return past_key_values |
| 4348 | + |
| 4349 | + |
| 4350 | +@register_in_tasks_manager("zamba2", *["text-generation", "text-generation-with-past"], library_name="transformers") |
| 4351 | +class Zamba2OpenVINOConfig(MambaOpenVINOConfig): |
| 4352 | + PAD_ATTENTION_MASK_TO_PAST = False |
| 4353 | + DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, Zamba2DummyPastKeyValuesGenerator) |
| 4354 | + DUMMY_PKV_GENERATOR_CLASS = Zamba2DummyPastKeyValuesGenerator |
| 4355 | + NORMALIZED_CONFIG_CLASS = NormalizedTextConfig |
| 4356 | + MIN_TRANSFORMERS_VERSION = "4.49.0" |
| 4357 | + _MODEL_PATCHER = Zamba2ModelPatcher |
| 4358 | + |
| 4359 | + def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], direction: str): |
| 4360 | + if direction not in ["inputs", "outputs"]: |
| 4361 | + raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given') |
| 4362 | + |
| 4363 | + if direction == "inputs": |
| 4364 | + decoder_sequence_name = "past_sequence_length" |
| 4365 | + cache_name_prefix = "cache_params.past" |
| 4366 | + else: |
| 4367 | + decoder_sequence_name = "past_sequence_length + sequence_length" |
| 4368 | + cache_name_prefix = "cache_params.present" |
| 4369 | + |
| 4370 | + for i in range(self._normalized_config.num_layers): |
| 4371 | + # [batch_size, conv_kernel_size - 1, d_model] |
| 4372 | + inputs_or_outputs[f"{cache_name_prefix}.conv.{i}"] = {0: "batch_size"} |
| 4373 | + # [batch_size, d_state, d_model] |
| 4374 | + inputs_or_outputs[f"{cache_name_prefix}.ssm.{i}"] = {0: "batch_size"} |
| 4375 | + |
| 4376 | + for i in range(len(self._normalized_config.hybrid_layer_ids)): |
| 4377 | + inputs_or_outputs[f"{cache_name_prefix}.key.{i}"] = {0: "batch_size", 2: decoder_sequence_name} |
| 4378 | + inputs_or_outputs[f"{cache_name_prefix}.value.{i}"] = {0: "batch_size", 2: decoder_sequence_name} |
| 4379 | + |
| 4380 | + @property |
| 4381 | + def inputs(self) -> Dict[str, Dict[int, str]]: |
| 4382 | + common_inputs = { |
| 4383 | + "input_ids": {0: "batch_size", 1: "sequence_length"}, |
| 4384 | + "attention_mask": {0: "batch_size", 1: "sequence_length"}, |
| 4385 | + } |
| 4386 | + if self.use_past_in_inputs: |
| 4387 | + self.add_past_key_values(common_inputs, direction="inputs") |
| 4388 | + return common_inputs |
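
With this registration in place, a Zamba2 checkpoint can be exported through the usual optimum-intel entry points: the "zamba2" model type resolves to Zamba2OpenVINOConfig for the "text-generation" and "text-generation-with-past" tasks. Below is a minimal usage sketch, assuming optimum-intel is installed with the openvino extra and using an illustrative checkpoint id (Zyphra/Zamba2-1.2B is only an example and is not verified by this patch):

from optimum.intel import OVModelForCausalLM
from transformers import AutoTokenizer

model_id = "Zyphra/Zamba2-1.2B"  # illustrative checkpoint id, an assumption of this sketch
tokenizer = AutoTokenizer.from_pretrained(model_id)

# export=True runs the OpenVINO export path, which looks up the "zamba2" entry
# registered above and builds dummy cache_params inputs with
# Zamba2DummyPastKeyValuesGenerator before conversion.
model = OVModelForCausalLM.from_pretrained(model_id, export=True)

inputs = tokenizer("Hello, Zamba2!", return_tensors="pt")
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=8)[0]))

The same export should also be reachable from the command line as: optimum-cli export openvino --model <model_id> --task text-generation-with-past <output_dir>, matching the tasks registered for "zamba2" above.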