|
17 | 17 | from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union |
18 | 18 |
|
19 | 19 | from packaging import version |
20 | | -from transformers import PretrainedConfig, PreTrainedModel, TFPreTrainedModel |
| 20 | +from transformers import AutoConfig, PretrainedConfig, PreTrainedModel, TFPreTrainedModel |
21 | 21 | from transformers.utils import is_tf_available |
22 | 22 |
|
23 | 23 | from optimum.exporters.onnx.config import OnnxConfig, TextDecoderOnnxConfig, TextDecoderWithPositionIdsOnnxConfig |
|
75 | 75 | JaisModelPatcher, |
76 | 76 | LlamaModelPatcher, |
77 | 77 | LlavaImageEmbeddingModelPatcher, |
| 78 | + LlavaQwen2ImageEmbeddingsModelPatcher, |
78 | 79 | MiniCPMVImageEmbeddingsModelPatcher, |
79 | 80 | MiniCPMVResamplerModelPatcher, |
80 | 81 | MistralModelPatcher, |
@@ -1579,6 +1580,165 @@ def patch_model_for_export( |
1579 | 1580 | return InternVLChatImageEmbeddingModelPatcher(self, model, model_kwargs) |
1580 | 1581 |
|
1581 | 1582 |
|
@register_in_tasks_manager(
    "llava-qwen2", *["image-text-to-text", "text-generation", "text-generation-with-past"], library_name="transformers"
)
class LlavaQwen2OpenVINOConfig(OnnxConfig):
    """OpenVINO export configuration for `llava-qwen2` multimodal models.

    The model is exported as several sub-models, one per `LlavaConfigBehavior`:
    vision embeddings, text (token) embeddings, and the language model. This
    config instance only describes the vision-embeddings part directly; the
    other behaviors are produced via `with_behavior`, which delegates to the
    export config registered for the underlying language model type.
    """

    SUPPORTS_PAST = True
    MIN_TRANSFORMERS_VERSION = version.parse("4.40.0")
    SUPPORTED_BEHAVIORS = [model_type.value for model_type in LlavaConfigBehavior]
    NORMALIZED_CONFIG_CLASS = NormalizedVisionConfig
    DUMMY_INPUT_GENERATOR_CLASSES = (DummyVisionInputGenerator,)

    def __init__(
        self,
        config: "PretrainedConfig",
        task: str = "feature-extraction",
        int_dtype: str = "int64",
        float_dtype: str = "fp32",
        behavior: LlavaConfigBehavior = LlavaConfigBehavior.VISION_EMBEDDINGS,
        preprocessors: Optional[List[Any]] = None,
        use_past: bool = False,
    ):
        """Store the full multimodal config and, for the vision behavior,
        resolve the vision tower's own config from `config.mm_vision_tower`.

        Args:
            config: The full llava-qwen2 `PretrainedConfig`.
            task: The task the model is exported for.
            int_dtype / float_dtype: Dtypes used for dummy inputs.
            behavior: Which sub-model this instance describes.
            preprocessors: Optional preprocessors forwarded to the base class.
            use_past: Accepted for registration-API compatibility; past state
                is handled by the language-model export config, not here.
        """
        self._behavior = behavior
        # Keep the original multimodal config around: `with_behavior` needs it
        # to build configs for the language/text sub-models.
        self._orig_config = config
        if self._behavior == LlavaConfigBehavior.VISION_EMBEDDINGS:
            # The vision tower is a separate checkpoint referenced by name;
            # load its config (remote code allowed for custom towers).
            config = AutoConfig.from_pretrained(config.mm_vision_tower, trust_remote_code=True)
            if hasattr(config, "vision_config"):
                config = config.vision_config
        super().__init__(
            config=config,
            task=task,
            int_dtype=int_dtype,
            float_dtype=float_dtype,
            preprocessors=preprocessors,
        )

    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        """Model inputs with their dynamic axes (vision behavior only)."""
        if self._behavior != LlavaConfigBehavior.VISION_EMBEDDINGS:
            return {}
        return {"pixel_values": {0: "batch_size", 2: "height", 3: "width"}}

    @property
    def outputs(self) -> Dict[str, Dict[int, str]]:
        """Model outputs with their dynamic axes (vision behavior only)."""
        if self._behavior != LlavaConfigBehavior.VISION_EMBEDDINGS:
            return {}
        return {"last_hidden_state": {0: "batch_size"}}

    def get_model_for_behavior(self, model, behavior: Union[str, LlavaConfigBehavior]):
        """Return the sub-module of `model` corresponding to `behavior`.

        Raises:
            ValueError: If `behavior` is not a recognized `LlavaConfigBehavior`.
        """
        if isinstance(behavior, str) and not isinstance(behavior, LlavaConfigBehavior):
            behavior = LlavaConfigBehavior(behavior)

        if behavior == LlavaConfigBehavior.LANGUAGE:
            # Restore the plain language-model forward, bypassing the
            # multimodal wrapper's forward.
            model.forward = super(type(model), model).forward
            return model

        if behavior == LlavaConfigBehavior.VISION_EMBEDDINGS:
            return model

        if behavior == LlavaConfigBehavior.TEXT_EMBEDDINGS:
            text_embedding = model.model.embed_tokens
            # The export pipeline expects every sub-model to expose a config.
            text_embedding.config = model.model.config
            return text_embedding

        raise ValueError(f"Unsupported model behavior provided: `{behavior}`")

    def _get_internal_text_export_config(self):
        """Build the `text-generation-with-past` export config registered for
        the wrapped language model (shared by the TEXT_EMBEDDINGS and LANGUAGE
        behaviors).

        Raises:
            ValueError: If no export config is registered for the language
                model type or it does not support text generation with past.
        """
        model_type = self._orig_config.model_type.replace("llava-", "").replace("_", "-")

        if model_type not in TasksManager._SUPPORTED_MODEL_TYPE:
            raise ValueError(
                f"Unsupported language model type provided `{model_type}`. Please define custom export config"
            )

        if "text-generation-with-past" not in TasksManager._SUPPORTED_MODEL_TYPE[model_type]["openvino"]:
            raise ValueError(
                f"Export config for text generation for `{model_type}` is not available. Please define custom export config"
            )

        internal_export_config_class = TasksManager._SUPPORTED_MODEL_TYPE[model_type]["openvino"][
            "text-generation-with-past"
        ]
        return internal_export_config_class(
            self._orig_config,
            use_past=True,
            use_past_in_inputs=True,
            int_dtype=self.int_dtype,
            float_dtype=self.float_dtype,
        )

    def with_behavior(
        self,
        behavior: Union[str, LlavaConfigBehavior],
    ):
        """
        Creates a config for a different behavior.

        Args:
            behavior ([`ConfigBehavior`]):
                The behavior to use for the new instance.

        Raises:
            ValueError: If `behavior` is not a recognized `LlavaConfigBehavior`.
        """
        if isinstance(behavior, str) and not isinstance(behavior, LlavaConfigBehavior):
            behavior = LlavaConfigBehavior(behavior)

        if behavior == LlavaConfigBehavior.TEXT_EMBEDDINGS:
            internal_export_config = self._get_internal_text_export_config()
            # NOTE: mutates a class attribute so the embeddings config picks
            # up the language model's normalized config (pre-existing pattern).
            InputEmbedOpenvVINOConfig.NORMALIZED_CONFIG_CLASS = internal_export_config.NORMALIZED_CONFIG_CLASS
            export_config = InputEmbedOpenvVINOConfig(
                self._orig_config,
                task="feature-extraction",
                int_dtype=self.int_dtype,
                float_dtype=self.float_dtype,
            )
            return export_config

        if behavior == LlavaConfigBehavior.LANGUAGE:
            internal_export_config = self._get_internal_text_export_config()
            export_config = LMInputEmbedsConfigHelper(internal_export_config)
            export_config._normalized_config = internal_export_config._normalized_config
            return export_config

        if behavior == LlavaConfigBehavior.VISION_EMBEDDINGS:
            return self.__class__(
                self._orig_config,
                task=self.task,
                int_dtype=self.int_dtype,
                float_dtype=self.float_dtype,
                behavior=behavior,
                preprocessors=self._preprocessors,
            )

        raise ValueError(f"Unsupported behavior provided: `{behavior}`")

    def patch_model_for_export(
        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
    ):
        """Apply the image-embeddings patcher for the vision behavior; defer
        to the base implementation otherwise."""
        model_kwargs = model_kwargs or {}
        if self._behavior != LlavaConfigBehavior.VISION_EMBEDDINGS:
            return super().patch_model_for_export(model, model_kwargs)
        return LlavaQwen2ImageEmbeddingsModelPatcher(self, model, model_kwargs)

    def rename_ambiguous_inputs(self, inputs):
        """The vision tower's forward takes `images`, not `pixel_values`;
        remap the dummy input name accordingly."""
        if self._behavior == LlavaConfigBehavior.VISION_EMBEDDINGS:
            return {"images": inputs["pixel_values"]}
        return super().rename_ambiguous_inputs(inputs)
| 1740 | + |
| 1741 | + |
1582 | 1742 | class PooledProjectionsDummyInputGenerator(DummyInputGenerator): |
1583 | 1743 | SUPPORTED_INPUT_NAMES = ["pooled_projections"] |
1584 | 1744 |
|
|
0 commit comments