
Commit 81f28fd

rkazants, echarlaix, and nikita-savelyevv authored

Add OpenVINO Zamba2 support (#1354)
* [OpenVINO] Support Zamba2 by OpenVINO
* Apply suggestions from code review (x9)
* Revert changes in notebooks/openvino/stable_diffusion_hybrid_quantization.ipynb
* Add tests
* Fix formatting
* Re-implement exporting Zamba2 model
* Fix export_cli_int8 test
* Apply suggestion from @rkazants (x3)
* Update optimum/exporters/openvino/model_configs.py (x2)
* Update tests/openvino/test_exporters_cli.py
* Apply suggestion from @rkazants (x2)
* Fix formatting
* Revert "" (this reverts commit b11d517)
* Introduce hybrid cache for both mamba and zamba2 models
* Handle hybrid cache
* Fix model config to set correct dimension for sequence length
* Add patching for zamba2 mamba mixer
* Apply formatting
* Correct utils_tests
* Remove TFPretrainedModel
* Apply formatting
* Remove unused variables
* Fix 4.45
* Apply formatting
* Fix patch for zamba2 mamba mixer
* Simplify var names (x2)
* Clarify patch for zamba2 mamba mixer
* Finalize patcher and add comments
* Apply formatting
* Update tests/openvino/utils_tests.py
* Correct names for mamba inference classes and add decoder test
* Apply formatting
* Correct patching
* Fix export for None cache_params
* Apply formatting
* Apply code-review feedback: optimize Zamba2DummyPastKeyValuesGenerator
* Apply formatting
* Correct patching
* Optimize Zamba2OpenVINOConfig
* Align cache names with Mamba model
* Update number of sdpas in test_decoder
* Fix exit for patcher
* Complete cache reporting result and refactor cache names
* Apply suggestions from code review
* Address comment for DummyPKVGenerator
* Support not stateful with use_cache
* Apply formatting
* Add a comment about the patch
* Update optimum/intel/openvino/modeling_decoder.py
* Recover OVMambaForCausalLM with deprecation
* Remove position_ids input
* Remove redundant keys and values in cache
* Apply formatting
* Update expected sdpa number
* Fix exporting for any tiny model
* Use text-generation-with-past task by default for Zamba2
* Add a warning about zamba2 support with poor performance
* Apply formatting
* Fix inference for stateless with use_cache
* Apply formatting

Signed-off-by: Kazantsev, Roman <[email protected]>
Co-authored-by: Ella Charlaix <[email protected]>
Co-authored-by: Nikita Savelyev <[email protected]>
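Per the commit message, Zamba2 now exports with the text-generation-with-past task by default. A minimal export invocation might look like the following sketch; the checkpoint id `Zyphra/Zamba2-1.2B` and the output directory `zamba2_ov/` are illustrative assumptions, not taken from this commit:

```shell
# Illustrative sketch: export a Zamba2 checkpoint to OpenVINO IR.
# "Zyphra/Zamba2-1.2B" and "zamba2_ov/" are assumed names, not from this commit;
# --task could be omitted since text-generation-with-past is now the default for Zamba2.
optimum-cli export openvino \
  --model Zyphra/Zamba2-1.2B \
  --task text-generation-with-past \
  zamba2_ov/
```

Note that, per the warning added in this commit, Zamba2 support is experimental and performance is not yet optimal.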
1 parent d0a3920 commit 81f28fd

File tree

11 files changed: +803 −58 lines changed


docs/source/openvino/models.mdx

Lines changed: 1 addition & 0 deletions
@@ -148,6 +148,7 @@ Here is the list of the supported architectures :
 - XLM
 - XLM-Roberta
 - XVERSE
+- Zamba2

 ## [Diffusers](https://huggingface.co/docs/diffusers/index)
 - Stable Diffusion

optimum/exporters/openvino/__main__.py

Lines changed: 2 additions & 1 deletion
@@ -93,7 +93,8 @@ def infer_task(
     except KeyError as e:
         try:
             config = AutoConfig.from_pretrained(model_name_or_path)
-            if "MistralForCausalLM" in config.architectures:
+            with_past_arch_list = ["MistralForCausalLM", "Zamba2ForCausalLM"]
+            if any(arch in config.architectures for arch in with_past_arch_list):
                 task = "text-generation-with-past"
         except Exception:
             raise KeyError(
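The change above generalizes the previous single-architecture check to a list, so both Mistral and Zamba2 fall back to the with-past task. A standalone sketch of that fallback (`infer_fallback_task` is an illustrative stand-in for the relevant branch of `infer_task`, operating on the `architectures` list read from the model config):

```python
# Sketch of the fallback logic above; `architectures` stands in for the
# value normally read from AutoConfig.from_pretrained(...).architectures.
WITH_PAST_ARCH_LIST = ["MistralForCausalLM", "Zamba2ForCausalLM"]


def infer_fallback_task(architectures):
    """Return the default export task for architectures that need KV-cache inputs."""
    if any(arch in architectures for arch in WITH_PAST_ARCH_LIST):
        return "text-generation-with-past"
    return None


print(infer_fallback_task(["Zamba2ForCausalLM"]))  # text-generation-with-past
```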

optimum/exporters/openvino/model_configs.py

Lines changed: 108 additions & 0 deletions
@@ -13,6 +13,7 @@
 # limitations under the License.

 import enum
+import logging
 from copy import deepcopy
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

@@ -138,9 +139,12 @@
     QwenModelPatcher,
     SanaTextEncoderModelPatcher,
     XverseModelPatcher,
+    Zamba2ModelPatcher,
 )


+logger = logging.getLogger(__name__)
+
 if TYPE_CHECKING:
     from transformers.modeling_utils import PreTrainedModel  # noqa: F811

@@ -4278,3 +4282,107 @@ class GPT2OpenVINOConfig(GPT2OnnxConfig):
         )
 class VisionEncoderDecoderOpenVINOConfig(VisionEncoderDecoderOnnxConfig):
     _MODEL_PATCHER = OVSeq2SeqModelPatcher
+
+
+class Zamba2DummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator):
+    """
+    Generates dummy cache_params inputs for Zamba2 architectures.
+    """
+
+    SUPPORTED_INPUT_NAMES = ("cache_params",)
+
+    def __init__(
+        self,
+        task: str,
+        normalized_config,
+        batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
+        sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"],
+        **kwargs,
+    ):
+        super().__init__(
+            task=task,
+            normalized_config=normalized_config,
+            batch_size=batch_size,
+            sequence_length=sequence_length,
+            **kwargs,
+        )
+
+        config = normalized_config.config
+        self.intermediate_size = int(config.mamba_expand * config.hidden_size)
+        self.ssm_state_size = config.mamba_d_state
+        self.conv_kernel_size = config.mamba_d_conv
+        self.n_mamba_heads = config.n_mamba_heads
+        self.mamba_ngroups = config.mamba_ngroups
+        self.mamba_d_state = config.mamba_d_state
+        self.mamba_headdim = config.mamba_headdim
+        self.head_dim = config.attention_head_dim
+        self.hybrid_layer_ids = config.hybrid_layer_ids
+        logger.warning(
+            "The current support for the 'Zamba2' model type is experimental. "
+            "Performance is not optimal with high memory consumption. "
+            "Optimizations and improved support will be available in a future OpenVINO release."
+        )
+
+    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
+        past_key_values = []
+        for i in range(self.num_layers):
+            conv_state_shape = (
+                self.batch_size,
+                self.intermediate_size + 2 * self.mamba_ngroups * self.mamba_d_state,
+                self.conv_kernel_size,
+            )
+            conv_state = self.random_float_tensor(conv_state_shape, framework=framework, dtype=float_dtype)
+            past_key_values.append(conv_state)
+            ssm_state_shape = (self.batch_size, self.n_mamba_heads, self.mamba_headdim, self.ssm_state_size)
+            ssm_state = self.random_float_tensor(ssm_state_shape, framework=framework, dtype=float_dtype)
+            past_key_values.append(ssm_state)
+
+        for i in range(len(self.hybrid_layer_ids)):
+            kv_shape = (self.batch_size, self.num_attention_heads, self.sequence_length, self.head_dim)
+            k = self.random_float_tensor(kv_shape, framework=framework, dtype=float_dtype)
+            v = self.random_float_tensor(kv_shape, framework=framework, dtype=float_dtype)
+            past_key_values.append(k)
+            past_key_values.append(v)
+
+        return past_key_values
+
+
+@register_in_tasks_manager("zamba2", *["text-generation", "text-generation-with-past"], library_name="transformers")
+class Zamba2OpenVINOConfig(MambaOpenVINOConfig):
+    PAD_ATTENTION_MASK_TO_PAST = False
+    DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, Zamba2DummyPastKeyValuesGenerator)
+    DUMMY_PKV_GENERATOR_CLASS = Zamba2DummyPastKeyValuesGenerator
+    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
+    MIN_TRANSFORMERS_VERSION = "4.49.0"
+    _MODEL_PATCHER = Zamba2ModelPatcher
+
+    def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], direction: str):
+        if direction not in ["inputs", "outputs"]:
+            raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given')
+
+        if direction == "inputs":
+            decoder_sequence_name = "past_sequence_length"
+            cache_name_prefix = "cache_params.past"
+        else:
+            decoder_sequence_name = "past_sequence_length + sequence_length"
+            cache_name_prefix = "cache_params.present"
+
+        for i in range(self._normalized_config.num_layers):
+            # [batch_size, conv_kernel_size - 1, d_model]
+            inputs_or_outputs[f"{cache_name_prefix}.conv.{i}"] = {0: "batch_size"}
+            # [batch_size, d_state, d_model]
+            inputs_or_outputs[f"{cache_name_prefix}.ssm.{i}"] = {0: "batch_size"}
+
+        for i in range(len(self._normalized_config.hybrid_layer_ids)):
+            inputs_or_outputs[f"{cache_name_prefix}.key.{i}"] = {0: "batch_size", 2: decoder_sequence_name}
+            inputs_or_outputs[f"{cache_name_prefix}.value.{i}"] = {0: "batch_size", 2: decoder_sequence_name}
+
+    @property
+    def inputs(self) -> Dict[str, Dict[int, str]]:
+        common_inputs = {
+            "input_ids": {0: "batch_size", 1: "sequence_length"},
+            "attention_mask": {0: "batch_size", 1: "sequence_length"},
+        }
+        if self.use_past_in_inputs:
+            self.add_past_key_values(common_inputs, direction="inputs")
+        return common_inputs
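The generator above flattens the hybrid cache into a fixed order: a conv state and an SSM state per layer, then a key/value pair per hybrid (attention) layer. A standalone sketch of that shape layout, using a toy config whose field values are illustrative assumptions (they mirror the Zamba2 config attribute names but not any real checkpoint):

```python
# Standalone sketch of the flattened cache layout produced by
# Zamba2DummyPastKeyValuesGenerator; ToyZamba2Config holds made-up toy values.
from dataclasses import dataclass


@dataclass
class ToyZamba2Config:
    hidden_size: int = 16
    mamba_expand: int = 2
    mamba_d_state: int = 8
    mamba_d_conv: int = 4
    n_mamba_heads: int = 2
    mamba_ngroups: int = 1
    mamba_headdim: int = 16
    attention_head_dim: int = 8
    num_attention_heads: int = 2
    num_layers: int = 3
    hybrid_layer_ids: tuple = (1,)


def cache_shapes(cfg, batch_size=1, sequence_length=5):
    """Return the dummy cache tensor shapes in the order the generator emits them."""
    intermediate_size = int(cfg.mamba_expand * cfg.hidden_size)
    shapes = []
    for _ in range(cfg.num_layers):
        # conv state: [batch, intermediate + 2 * ngroups * d_state, conv_kernel]
        shapes.append(
            (batch_size, intermediate_size + 2 * cfg.mamba_ngroups * cfg.mamba_d_state, cfg.mamba_d_conv)
        )
        # ssm state: [batch, n_mamba_heads, headdim, d_state]
        shapes.append((batch_size, cfg.n_mamba_heads, cfg.mamba_headdim, cfg.mamba_d_state))
    for _ in cfg.hybrid_layer_ids:
        # key and value for each hybrid (attention) layer
        kv = (batch_size, cfg.num_attention_heads, sequence_length, cfg.attention_head_dim)
        shapes.extend([kv, kv])
    return shapes


shapes = cache_shapes(ToyZamba2Config())
print(len(shapes))  # 8: 3 layers x 2 states + 1 hybrid layer x 2
```

With real tensors in place of shape tuples, this is the order in which `cache_params` inputs are fed to the exported model.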
