2222
2323from optimum .exporters .onnx .base import ConfigBehavior , OnnxConfig , OnnxConfigWithPast , OnnxSeq2SeqConfigWithPast
2424from optimum .exporters .onnx .constants import ONNX_DECODER_MERGED_NAME , ONNX_DECODER_NAME , ONNX_DECODER_WITH_PAST_NAME
25+ from optimum .exporters .tasks import TasksManager
26+ from optimum .onnx import merge_decoders
2527from optimum .utils import (
2628 DummyAudioInputGenerator ,
2729 DummyBboxInputGenerator ,
3133 DummySeq2SeqPastKeyValuesGenerator ,
3234 DummyTextInputGenerator ,
3335 DummyVisionInputGenerator ,
34- is_diffusers_available ,
3536 logging ,
3637)
3738
3839
39- # TODO : moved back onnx imports applied in https://github.com/huggingface/optimum/pull/2114/files after refactorization
40-
41-
4240if TYPE_CHECKING :
4341 from transformers import PretrainedConfig , PreTrainedModel
4442
45- if is_diffusers_available ():
46- from diffusers import ModelMixin
4743
4844logger = logging .get_logger (__name__ )
4945
@@ -110,7 +106,7 @@ def outputs(self) -> dict[str, dict[int, str]]:
110106 def post_process_exported_models (
111107 self ,
112108 path : Path ,
113- models_and_onnx_configs : dict [str , tuple [PreTrainedModel | ModelMixin , OnnxConfig ]],
109+ models_and_onnx_configs : dict [str , tuple [PreTrainedModel , OnnxConfig ]],
114110 onnx_files_subpaths : list [str ],
115111 ):
116112 models_and_onnx_configs , onnx_files_subpaths = super ().post_process_exported_models (
@@ -119,8 +115,6 @@ def post_process_exported_models(
119115
120116 # Attempt to merge only if the decoder-only was exported separately without/with past
121117 if self .use_past is True and len (models_and_onnx_configs ) == 2 :
122- from optimum .onnx import merge_decoders
123-
124118 decoder_path = Path (path , onnx_files_subpaths [0 ])
125119 decoder_with_past_path = Path (path , onnx_files_subpaths [1 ])
126120 decoder_merged_path = Path (path , ONNX_DECODER_MERGED_NAME + ".onnx" )
@@ -171,35 +165,19 @@ class TextSeq2SeqOnnxConfig(OnnxSeq2SeqConfigWithPast):
171165 DummySeq2SeqPastKeyValuesGenerator ,
172166 )
173167
174- @property
175- def torch_to_onnx_input_map (self ) -> dict [str , str ]:
176- if self ._behavior is ConfigBehavior .DECODER :
177- return {
178- "decoder_input_ids" : "input_ids" ,
179- "encoder_outputs" : "encoder_hidden_states" ,
180- "attention_mask" : "encoder_attention_mask" ,
181- }
182- return {}
183-
184168 @property
185169 def inputs (self ) -> dict [str , dict [int , str ]]:
186170 common_inputs = {}
187- if self ._behavior is not ConfigBehavior .DECODER :
171+ if self ._behavior in { ConfigBehavior . ENCODER , ConfigBehavior .MONOLITH } :
188172 common_inputs ["input_ids" ] = {0 : "batch_size" , 1 : "encoder_sequence_length" }
189-
173+ else :
174+ common_inputs ["encoder_outputs" ] = {0 : "batch_size" , 1 : "encoder_sequence_length" }
190175 common_inputs ["attention_mask" ] = {0 : "batch_size" , 1 : "encoder_sequence_length" }
191176
192- if self ._behavior is not ConfigBehavior .ENCODER :
177+ if self ._behavior in {ConfigBehavior .DECODER , ConfigBehavior .MONOLITH }:
178+ common_inputs ["decoder_input_ids" ] = {0 : "batch_size" , 1 : "decoder_sequence_length" }
193179 if self .use_past_in_inputs :
194- # TODO: validate the axis name for attention_mask
195- # common_inputs["attention_mask"][1] = "past_encoder_sequence_length + sequence_length"
196- common_inputs ["decoder_input_ids" ] = {0 : "batch_size" }
197180 self .add_past_key_values (common_inputs , direction = "inputs" )
198- else :
199- common_inputs ["decoder_input_ids" ] = {0 : "batch_size" , 1 : "decoder_sequence_length" }
200-
201- if self ._behavior is ConfigBehavior .DECODER :
202- common_inputs ["encoder_outputs" ] = {0 : "batch_size" , 1 : "encoder_sequence_length" }
203181
204182 return common_inputs
205183
@@ -260,31 +238,18 @@ class AudioToTextOnnxConfig(OnnxSeq2SeqConfigWithPast):
260238 def inputs (self ) -> dict [str , dict [int , str ]]:
261239 common_inputs = {}
262240
263- if self ._behavior is not ConfigBehavior .DECODER :
241+ if self ._behavior in { ConfigBehavior . ENCODER , ConfigBehavior .MONOLITH } :
264242 common_inputs ["input_features" ] = {0 : "batch_size" , 1 : "feature_size" , 2 : "encoder_sequence_length" }
243+ else :
244+ common_inputs ["encoder_outputs" ] = {0 : "batch_size" , 1 : "encoder_sequence_length" }
265245
266- if self ._behavior is not ConfigBehavior .ENCODER :
246+ if self ._behavior in {ConfigBehavior .DECODER , ConfigBehavior .MONOLITH }:
247+ common_inputs ["decoder_input_ids" ] = {0 : "batch_size" , 1 : "decoder_sequence_length" }
267248 if self .use_past_in_inputs :
268- common_inputs ["decoder_input_ids" ] = {0 : "batch_size" }
269249 self .add_past_key_values (common_inputs , direction = "inputs" )
270- else :
271- common_inputs ["decoder_input_ids" ] = {0 : "batch_size" , 1 : "decoder_sequence_length" }
272-
273- if self ._behavior is ConfigBehavior .DECODER :
274- common_inputs ["encoder_outputs" ] = {0 : "batch_size" , 1 : "encoder_sequence_length" }
275250
276251 return common_inputs
277252
278- @property
279- def torch_to_onnx_input_map (self ) -> dict [str , str ]:
280- if self ._behavior is ConfigBehavior .DECODER :
281- return {
282- "decoder_input_ids" : "input_ids" ,
283- "encoder_outputs" : "encoder_hidden_states" ,
284- "attention_mask" : "encoder_attention_mask" ,
285- }
286- return {}
287-
288253
289254class EncoderDecoderBaseOnnxConfig (OnnxSeq2SeqConfigWithPast ):
290255 DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator ,)
@@ -313,8 +278,6 @@ def __init__(
313278 legacy = legacy ,
314279 )
315280
316- from optimum .exporters .tasks import TasksManager
317-
318281 self .is_decoder_with_past = False
319282
320283 # Set up the encoder ONNX config.
@@ -382,41 +345,19 @@ def __init__(
382345 @property
383346 def inputs (self ) -> dict [str , dict [int , str ]]:
384347 common_inputs = {}
385- if self ._behavior is not ConfigBehavior .DECODER :
348+ if self ._behavior in { ConfigBehavior . ENCODER , ConfigBehavior .MONOLITH } :
386349 common_inputs ["input_ids" ] = {0 : "batch_size" , 1 : "encoder_sequence_length" }
387-
350+ else :
351+ common_inputs ["encoder_outputs" ] = {0 : "batch_size" , 1 : "encoder_sequence_length" }
388352 common_inputs ["attention_mask" ] = {0 : "batch_size" , 1 : "encoder_sequence_length" }
389353
390- if self ._behavior is not ConfigBehavior .ENCODER :
391- # TODO: it is likely this pop() is unwanted as we then always hit
392- # https://github.com/huggingface/transformers/blob/v4.26.0/src/transformers/models/t5/modeling_t5.py#L965-L969
393- common_inputs .pop ("attention_mask" )
394-
395- if self .use_past_in_inputs :
396- # TODO: validate the axis name for attention_mask
397- # common_inputs["attention_mask"][1] = "past_encoder_sequence_length + sequence_length"
398- common_inputs ["decoder_input_ids" ] = {0 : "batch_size" }
399- else :
400- common_inputs ["decoder_input_ids" ] = {0 : "batch_size" , 1 : "decoder_sequence_length" }
401-
354+ if self ._behavior in {ConfigBehavior .DECODER , ConfigBehavior .MONOLITH }:
355+ common_inputs ["decoder_input_ids" ] = {0 : "batch_size" , 1 : "decoder_sequence_length" }
402356 if self .use_past_in_inputs :
403357 self .add_past_key_values (common_inputs , direction = "inputs" )
404358
405- if self ._behavior is ConfigBehavior .DECODER :
406- common_inputs ["encoder_outputs" ] = {0 : "batch_size" , 1 : "encoder_sequence_length" }
407-
408359 return common_inputs
409360
410- @property
411- def torch_to_onnx_input_map (self ) -> dict [str , str ]:
412- if self ._behavior is ConfigBehavior .DECODER :
413- return {
414- "decoder_input_ids" : "input_ids" ,
415- "encoder_outputs" : "encoder_hidden_states" ,
416- "attention_mask" : "encoder_attention_mask" ,
417- }
418- return {}
419-
420361 def add_past_key_values (self , inputs_or_outputs : dict [str , dict [int , str ]], direction : str ):
421362 if self .is_decoder_with_past :
422363 return self ._decoder_onnx_config .add_past_key_values (inputs_or_outputs , direction )
@@ -429,26 +370,34 @@ def flatten_output_collection_property(self, name: str, field: Iterable[Any]) ->
429370 return self ._decoder_onnx_config .flatten_output_collection_property (name , field )
430371
431372 def generate_dummy_inputs_for_validation (
432- self , reference_model_inputs : dict [str , Any ], onnx_input_names : list [str ] | None = None
373+ self , reference_model_inputs : dict [str , Any ], onnx_input_names : list [str ]
433374 ) -> dict [str , Any ]:
434375 if self ._behavior is ConfigBehavior .ENCODER :
435- return self ._encoder_onnx_config .generate_dummy_inputs_for_validation (reference_model_inputs )
376+ return self ._encoder_onnx_config .generate_dummy_inputs_for_validation (
377+ reference_model_inputs , onnx_input_names
378+ )
436379 else :
437380 if self ._behavior is ConfigBehavior .DECODER :
438- reference_model_inputs ["input_ids" ] = reference_model_inputs .pop ("decoder_input_ids" )
439-
440- if "encoder_outputs" in reference_model_inputs :
441- if "encoder_hidden_states" in onnx_input_names :
442- reference_model_inputs ["encoder_hidden_states" ] = reference_model_inputs .pop ("encoder_outputs" )[0 ]
443- else :
444- reference_model_inputs .pop ("encoder_outputs" )
445-
446- return self ._decoder_onnx_config .generate_dummy_inputs_for_validation (reference_model_inputs )
381+ if "decoder_input_ids" in reference_model_inputs :
382+ reference_model_inputs ["input_ids" ] = reference_model_inputs .pop ("decoder_input_ids" )
383+ if "attention_mask" in reference_model_inputs :
384+ reference_model_inputs ["encoder_attention_mask" ] = reference_model_inputs .pop ("attention_mask" )
385+ if "encoder_outputs" in reference_model_inputs :
386+ if "encoder_hidden_states" in onnx_input_names :
387+ reference_model_inputs ["encoder_hidden_states" ] = reference_model_inputs .pop (
388+ "encoder_outputs"
389+ )[0 ]
390+ else :
391+ reference_model_inputs .pop ("encoder_outputs" )
392+
393+ return self ._decoder_onnx_config .generate_dummy_inputs_for_validation (
394+ reference_model_inputs , onnx_input_names
395+ )
447396
448397 def post_process_exported_models (
449398 self ,
450399 path : Path ,
451- models_and_onnx_configs : dict [str , tuple [PreTrainedModel | ModelMixin , OnnxConfig ]],
400+ models_and_onnx_configs : dict [str , tuple [PreTrainedModel , OnnxConfig ]],
452401 onnx_files_subpaths : list [str ],
453402 ):
454403 models_and_onnx_configs , onnx_files_subpaths = super ().post_process_exported_models (
0 commit comments