@@ -1864,18 +1864,49 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int
         return self.random_int_tensor(shape, max_value=self.vocab_size, framework=framework, dtype=int_dtype)
 
 
+class DummyUnetEncoderInputGenerator(DummySeq2SeqDecoderTextInputGenerator):
+    def __init__(
+        self,
+        task: str,
+        normalized_config: NormalizedTextConfig,
+        batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
+        sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"],
+        num_choices: int = DEFAULT_DUMMY_SHAPES["num_choices"],
+        random_batch_size_range: Optional[Tuple[int, int]] = None,
+        random_sequence_length_range: Optional[Tuple[int, int]] = None,
+        random_num_choices_range: Optional[Tuple[int, int]] = None,
+        **kwargs,
+    ):
+        super().__init__(
+            task,
+            normalized_config,
+            batch_size=batch_size,
+            sequence_length=sequence_length,
+            num_choices=num_choices,
+            random_batch_size_range=random_batch_size_range,
+            random_sequence_length_range=random_sequence_length_range,
+            random_num_choices_range=random_num_choices_range,
+            **kwargs,
+        )
+        if hasattr(normalized_config.config, "model_max_length"):
+            self.sequence_length = normalized_config.config.model_max_length
+
+
 @register_in_tasks_manager("unet", *["semantic-segmentation"], library_name="diffusers")
 @register_in_tasks_manager("unet-2d-condition", *["semantic-segmentation"], library_name="diffusers")
 class UNetOpenVINOConfig(UNetOnnxConfig):
     DUMMY_INPUT_GENERATOR_CLASSES = (
         DummyUnetVisionInputGenerator,
         DummyUnetTimestepInputGenerator,
-    ) + UNetOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES[2:]
+        DummyUnetEncoderInputGenerator,
+    )
 
     @property
     def inputs(self) -> Dict[str, Dict[int, str]]:
         common_inputs = super().inputs
         common_inputs["timestep"] = {0: "batch_size"}
+        if hasattr(self._normalized_config.config, "model_max_length"):
+            common_inputs["encoder_hidden_states"] = {0: "batch_size"}
         return common_inputs
 
 
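For reference, a minimal standalone sketch of the behaviour this hunk introduces: when the text encoder config exposes `model_max_length` (e.g. CLIP's 77), the dummy `encoder_hidden_states` tensor is generated with that fixed sequence length, and only its batch axis is left dynamic in the export input spec. The `ToyTextConfig`, `dummy_encoder_hidden_states`, and `encoder_hidden_states_axes` names below are hypothetical helpers for illustration only; they do not appear in the patch or in `optimum`.

```python
# Standalone sketch (no optimum imports) of the model_max_length handling above.
from dataclasses import dataclass
from typing import Dict, Optional

import numpy as np


@dataclass
class ToyTextConfig:
    hidden_size: int = 768
    model_max_length: Optional[int] = 77  # e.g. the CLIP text encoder limit


def dummy_encoder_hidden_states(
    config: ToyTextConfig, batch_size: int = 2, sequence_length: int = 16
) -> np.ndarray:
    # Mirrors DummyUnetEncoderInputGenerator: if the text config exposes
    # model_max_length, the dummy tensor uses that fixed sequence length.
    if config.model_max_length is not None:
        sequence_length = config.model_max_length
    return np.random.rand(batch_size, sequence_length, config.hidden_size).astype(np.float32)


def encoder_hidden_states_axes(config: ToyTextConfig) -> Dict[int, str]:
    # Mirrors UNetOpenVINOConfig.inputs: only the batch axis stays dynamic
    # when the sequence length is known to be fixed at export time.
    if config.model_max_length is not None:
        return {0: "batch_size"}
    return {0: "batch_size", 1: "sequence_length"}


if __name__ == "__main__":
    cfg = ToyTextConfig()
    print(dummy_encoder_hidden_states(cfg).shape)  # (2, 77, 768)
    print(encoder_hidden_states_axes(cfg))         # {0: 'batch_size'}
```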