|  | 
| 39 | 39 |     MPTOnnxConfig, | 
| 40 | 40 |     PhiOnnxConfig, | 
| 41 | 41 |     UNetOnnxConfig, | 
|  | 42 | +    VaeEncoderOnnxConfig, | 
| 42 | 43 |     VisionOnnxConfig, | 
| 43 | 44 | ) | 
| 44 | 45 | from optimum.exporters.onnx.model_patcher import ModelPatcher | 
|  | 
| 54 | 55 |     DummyVisionInputGenerator, | 
| 55 | 56 |     FalconDummyPastKeyValuesGenerator, | 
| 56 | 57 |     MistralDummyPastKeyValuesGenerator, | 
| 57 |  | -    DummySeq2SeqDecoderTextInputGenerator | 
| 58 | 58 | ) | 
| 59 | 59 | from optimum.utils.normalized_config import NormalizedConfig, NormalizedTextConfig, NormalizedVisionConfig | 
| 60 | 60 | 
 | 
| @@ -1889,52 +1889,78 @@ def rename_ambiguous_inputs(self, inputs): | 
| 1889 | 1889 | class T5EncoderOpenVINOConfig(CLIPTextOpenVINOConfig): | 
| 1890 | 1890 |     pass | 
| 1891 | 1891 | 
 | 
|  | 1892 | + | 
| 1892 | 1893 | @register_in_tasks_manager("gemma2-text-encoder", *["feature-extraction"], library_name="diffusers") | 
| 1893 | 1894 | class Gemma2TextEncoderOpenVINOConfig(CLIPTextOpenVINOConfig): | 
| 1894 | 1895 |     @property | 
| 1895 | 1896 |     def inputs(self) -> Dict[str, Dict[int, str]]: | 
| 1896 | 1897 |         return { | 
| 1897 | 1898 |             "input_ids": {0: "batch_size", 1: "sequence_length"}, | 
| 1898 |  | -            "attention_mask": {0: "batch_size", 1: "sequence_length"} | 
|  | 1899 | +            "attention_mask": {0: "batch_size", 1: "sequence_length"}, | 
| 1899 | 1900 |         } | 
| 1900 | 1901 | 
 | 
| 1901 | 1902 | 
 | 
| 1902 |  | -class DummySeq2SeqDecoderTextWithEncMaskInputGenerator(DummySeq2SeqDecoderTextInputGenerator): | 
|  | 1903 | +class DummySanaSeq2SeqDecoderTextWithEncMaskInputGenerator(DummySeq2SeqDecoderTextInputGenerator): | 
| 1903 | 1904 |     SUPPORTED_INPUT_NAMES = ( | 
| 1904 | 1905 |         "decoder_input_ids", | 
| 1905 | 1906 |         "decoder_attention_mask", | 
| 1906 | 1907 |         "encoder_outputs", | 
| 1907 | 1908 |         "encoder_hidden_states", | 
| 1908 |  | -        "encoder_attention_mask" | 
|  | 1909 | +        "encoder_attention_mask", | 
| 1909 | 1910 |     ) | 
| 1910 | 1911 | 
 | 
| 1911 | 1912 | 
 | 
| 1912 |  | -class DummySanaTransformerVisionInputGenerator(DummyVisionInputGenerator): | 
| 1913 |  | -    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): | 
| 1914 |  | -        if input_name not in ["sample", "latent_sample"]: | 
| 1915 |  | -            return super().generate(input_name, framework, int_dtype, float_dtype) | 
| 1916 |  | -        return self.random_float_tensor( | 
| 1917 |  | -            shape=[self.batch_size, self.num_channels, self.height, self.width], | 
| 1918 |  | -            framework=framework, | 
| 1919 |  | -            dtype=float_dtype, | 
| 1920 |  | -        ) | 
|  | 1913 | +class DummySanaTransformerVisionInputGenerator(DummyUnetVisionInputGenerator): | 
|  | 1914 | +    def __init__( | 
|  | 1915 | +        self, | 
|  | 1916 | +        task: str, | 
|  | 1917 | +        normalized_config: NormalizedVisionConfig, | 
|  | 1918 | +        batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"], | 
|  | 1919 | +        num_channels: int = DEFAULT_DUMMY_SHAPES["num_channels"], | 
|  | 1920 | +        width: int = DEFAULT_DUMMY_SHAPES["width"] // 8, | 
|  | 1921 | +        height: int = DEFAULT_DUMMY_SHAPES["height"] // 8, | 
|  | 1922 | +        # Default width/height are divided by 8 to reduce memory usage on conversion | 
|  | 1923 | +        **kwargs, | 
|  | 1924 | +    ): | 
|  | 1925 | +        super().__init__(task, normalized_config, batch_size, num_channels, width=width, height=height, **kwargs) | 
|  | 1926 | + | 
| 1921 | 1927 | 
 | 
| 1922 | 1928 | @register_in_tasks_manager("sana-transformer", *["semantic-segmentation"], library_name="diffusers") | 
| 1923 | 1929 | class SanaTransformerOpenVINOConfig(UNetOpenVINOConfig): | 
| 1924 | 1930 |     NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args( | 
| 1925 | 1931 |         image_size="sample_size", | 
| 1926 | 1932 |         num_channels="in_channels", | 
| 1927 |  | -        hidden_size="cross_attention_dim", | 
|  | 1933 | +        hidden_size="caption_channels", | 
| 1928 | 1934 |         vocab_size="attention_head_dim", | 
| 1929 | 1935 |         allow_new=True, | 
| 1930 | 1936 |     ) | 
| 1931 |  | -    DUMMY_INPUT_GENERATOR_CLASSES = (DummySanaTransformerVisionInputGenerator, DummySeq2SeqDecoderTextWithEncMaskInputGenerator) + UNetOpenVINOConfig.DUMMY_INPUT_GENERATOR_CLASSES[1:-1] | 
|  | 1937 | +    DUMMY_INPUT_GENERATOR_CLASSES = ( | 
|  | 1938 | +        DummySanaTransformerVisionInputGenerator, | 
|  | 1939 | +        DummySanaSeq2SeqDecoderTextWithEncMaskInputGenerator, | 
|  | 1940 | +    ) + UNetOpenVINOConfig.DUMMY_INPUT_GENERATOR_CLASSES[1:-1] | 
|  | 1941 | + | 
| 1932 | 1942 |     @property | 
| 1933 | 1943 |     def inputs(self): | 
| 1934 | 1944 |         common_inputs = super().inputs | 
| 1935 | 1945 |         common_inputs["encoder_attention_mask"] = {0: "batch_size", 1: "sequence_length"} | 
| 1936 | 1946 |         return common_inputs | 
| 1937 | 1947 | 
 | 
|  | 1948 | +    def rename_ambiguous_inputs(self, inputs): | 
|  | 1949 | +        # The input name in the model signature is `hidden_states` rather than `sample`, hence the export input name is updated. | 
|  | 1950 | +        hidden_states = inputs.pop("sample", None) | 
|  | 1951 | +        if hidden_states is not None: | 
|  | 1952 | +            inputs["hidden_states"] = hidden_states | 
|  | 1953 | +        return inputs | 
|  | 1954 | + | 
|  | 1955 | + | 
|  | 1956 | +@register_in_tasks_manager("dcae-encoder", *["semantic-segmentation"], library_name="diffusers") | 
|  | 1957 | +class DcaeEncoderOpenVINOConfig(VaeEncoderOnnxConfig): | 
|  | 1958 | +    @property | 
|  | 1959 | +    def outputs(self) -> Dict[str, Dict[int, str]]: | 
|  | 1960 | +        return { | 
|  | 1961 | +            "latent": {0: "batch_size", 2: "height_latent", 3: "width_latent"}, | 
|  | 1962 | +        } | 
|  | 1963 | + | 
| 1938 | 1964 | 
 | 
| 1939 | 1965 | class DummyFluxTransformerInputGenerator(DummyVisionInputGenerator): | 
| 1940 | 1966 |     SUPPORTED_INPUT_NAMES = ( | 
|  | 
0 commit comments