Skip to content

Commit fafa2a7

Browse files
ORT Seq2Seq modeling refactorization (#34)
This PR: - Creates and separates the testing of seq2seq / encoder-decoder models (text, speech and vision). - Enables full testing, including batched inference with padded inputs. logits, last hidden state and pkv matching assertions. and beam search, random sampling... - Fixes issues emerging/detected when running the full testing suite, like missing dynamic axis, missing attention masks in the encoder/decoder, untested model types for some tasks, unpatched models, etc.. - Rewrites and simplifies the seq2seq modeling code, removing a lot of hackiness around pkv (self attn vs cross attn) and making the forward pass (especially the seq2seqdecoder forward) more linear, finally enabling cases like use_merged=True and use_cache=False, which weren't supported before. - Rewrite and simplifies the filename inference when loading a seq2seq model. - Documents behaviors like m2m100 consuming the RNG state, bigbird pegasus switching between attention types. - Enables Moonshine speech recognition and patch it for dynamic audio input length (only supported static). - Removes some deprecated stuff, like onnx config with loss which was only relevant for ort training. - Correct the output dynamic axis for different ctc models. - [maybe more 🥲]
1 parent 4f78a83 commit fafa2a7

23 files changed

+2832
-4526
lines changed

optimum/exporters/onnx/__init__.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121

2222
_import_structure = {
23-
"base": ["OnnxConfig", "OnnxConfigWithLoss", "OnnxConfigWithPast", "OnnxSeq2SeqConfigWithPast"],
23+
"base": ["OnnxConfig", "OnnxConfigWithPast", "OnnxSeq2SeqConfigWithPast"],
2424
"config": ["TextDecoderOnnxConfig", "TextEncoderOnnxConfig", "TextSeq2SeqOnnxConfig"],
2525
"convert": [
2626
"export",
@@ -40,12 +40,7 @@
4040

4141
if TYPE_CHECKING:
4242
from optimum.exporters.onnx.__main__ import main_export
43-
from optimum.exporters.onnx.base import (
44-
OnnxConfig,
45-
OnnxConfigWithLoss,
46-
OnnxConfigWithPast,
47-
OnnxSeq2SeqConfigWithPast,
48-
)
43+
from optimum.exporters.onnx.base import OnnxConfig, OnnxConfigWithPast, OnnxSeq2SeqConfigWithPast
4944
from optimum.exporters.onnx.config import TextDecoderOnnxConfig, TextEncoderOnnxConfig, TextSeq2SeqOnnxConfig
5045
from optimum.exporters.onnx.convert import (
5146
export,

optimum/exporters/onnx/base.py

Lines changed: 48 additions & 197 deletions
Large diffs are not rendered by default.

optimum/exporters/onnx/config.py

Lines changed: 38 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222

2323
from optimum.exporters.onnx.base import ConfigBehavior, OnnxConfig, OnnxConfigWithPast, OnnxSeq2SeqConfigWithPast
2424
from optimum.exporters.onnx.constants import ONNX_DECODER_MERGED_NAME, ONNX_DECODER_NAME, ONNX_DECODER_WITH_PAST_NAME
25+
from optimum.exporters.tasks import TasksManager
26+
from optimum.onnx import merge_decoders
2527
from optimum.utils import (
2628
DummyAudioInputGenerator,
2729
DummyBboxInputGenerator,
@@ -31,19 +33,13 @@
3133
DummySeq2SeqPastKeyValuesGenerator,
3234
DummyTextInputGenerator,
3335
DummyVisionInputGenerator,
34-
is_diffusers_available,
3536
logging,
3637
)
3738

3839

39-
# TODO : moved back onnx imports applied in https://github.com/huggingface/optimum/pull/2114/files after refactorization
40-
41-
4240
if TYPE_CHECKING:
4341
from transformers import PretrainedConfig, PreTrainedModel
4442

45-
if is_diffusers_available():
46-
from diffusers import ModelMixin
4743

4844
logger = logging.get_logger(__name__)
4945

@@ -110,7 +106,7 @@ def outputs(self) -> dict[str, dict[int, str]]:
110106
def post_process_exported_models(
111107
self,
112108
path: Path,
113-
models_and_onnx_configs: dict[str, tuple[PreTrainedModel | ModelMixin, OnnxConfig]],
109+
models_and_onnx_configs: dict[str, tuple[PreTrainedModel, OnnxConfig]],
114110
onnx_files_subpaths: list[str],
115111
):
116112
models_and_onnx_configs, onnx_files_subpaths = super().post_process_exported_models(
@@ -119,8 +115,6 @@ def post_process_exported_models(
119115

120116
# Attempt to merge only if the decoder-only was exported separately without/with past
121117
if self.use_past is True and len(models_and_onnx_configs) == 2:
122-
from optimum.onnx import merge_decoders
123-
124118
decoder_path = Path(path, onnx_files_subpaths[0])
125119
decoder_with_past_path = Path(path, onnx_files_subpaths[1])
126120
decoder_merged_path = Path(path, ONNX_DECODER_MERGED_NAME + ".onnx")
@@ -171,35 +165,19 @@ class TextSeq2SeqOnnxConfig(OnnxSeq2SeqConfigWithPast):
171165
DummySeq2SeqPastKeyValuesGenerator,
172166
)
173167

174-
@property
175-
def torch_to_onnx_input_map(self) -> dict[str, str]:
176-
if self._behavior is ConfigBehavior.DECODER:
177-
return {
178-
"decoder_input_ids": "input_ids",
179-
"encoder_outputs": "encoder_hidden_states",
180-
"attention_mask": "encoder_attention_mask",
181-
}
182-
return {}
183-
184168
@property
185169
def inputs(self) -> dict[str, dict[int, str]]:
186170
common_inputs = {}
187-
if self._behavior is not ConfigBehavior.DECODER:
171+
if self._behavior in {ConfigBehavior.ENCODER, ConfigBehavior.MONOLITH}:
188172
common_inputs["input_ids"] = {0: "batch_size", 1: "encoder_sequence_length"}
189-
173+
else:
174+
common_inputs["encoder_outputs"] = {0: "batch_size", 1: "encoder_sequence_length"}
190175
common_inputs["attention_mask"] = {0: "batch_size", 1: "encoder_sequence_length"}
191176

192-
if self._behavior is not ConfigBehavior.ENCODER:
177+
if self._behavior in {ConfigBehavior.DECODER, ConfigBehavior.MONOLITH}:
178+
common_inputs["decoder_input_ids"] = {0: "batch_size", 1: "decoder_sequence_length"}
193179
if self.use_past_in_inputs:
194-
# TODO: validate the axis name for attention_mask
195-
# common_inputs["attention_mask"][1] = "past_encoder_sequence_length + sequence_length"
196-
common_inputs["decoder_input_ids"] = {0: "batch_size"}
197180
self.add_past_key_values(common_inputs, direction="inputs")
198-
else:
199-
common_inputs["decoder_input_ids"] = {0: "batch_size", 1: "decoder_sequence_length"}
200-
201-
if self._behavior is ConfigBehavior.DECODER:
202-
common_inputs["encoder_outputs"] = {0: "batch_size", 1: "encoder_sequence_length"}
203181

204182
return common_inputs
205183

@@ -260,31 +238,18 @@ class AudioToTextOnnxConfig(OnnxSeq2SeqConfigWithPast):
260238
def inputs(self) -> dict[str, dict[int, str]]:
261239
common_inputs = {}
262240

263-
if self._behavior is not ConfigBehavior.DECODER:
241+
if self._behavior in {ConfigBehavior.ENCODER, ConfigBehavior.MONOLITH}:
264242
common_inputs["input_features"] = {0: "batch_size", 1: "feature_size", 2: "encoder_sequence_length"}
243+
else:
244+
common_inputs["encoder_outputs"] = {0: "batch_size", 1: "encoder_sequence_length"}
265245

266-
if self._behavior is not ConfigBehavior.ENCODER:
246+
if self._behavior in {ConfigBehavior.DECODER, ConfigBehavior.MONOLITH}:
247+
common_inputs["decoder_input_ids"] = {0: "batch_size", 1: "decoder_sequence_length"}
267248
if self.use_past_in_inputs:
268-
common_inputs["decoder_input_ids"] = {0: "batch_size"}
269249
self.add_past_key_values(common_inputs, direction="inputs")
270-
else:
271-
common_inputs["decoder_input_ids"] = {0: "batch_size", 1: "decoder_sequence_length"}
272-
273-
if self._behavior is ConfigBehavior.DECODER:
274-
common_inputs["encoder_outputs"] = {0: "batch_size", 1: "encoder_sequence_length"}
275250

276251
return common_inputs
277252

278-
@property
279-
def torch_to_onnx_input_map(self) -> dict[str, str]:
280-
if self._behavior is ConfigBehavior.DECODER:
281-
return {
282-
"decoder_input_ids": "input_ids",
283-
"encoder_outputs": "encoder_hidden_states",
284-
"attention_mask": "encoder_attention_mask",
285-
}
286-
return {}
287-
288253

289254
class EncoderDecoderBaseOnnxConfig(OnnxSeq2SeqConfigWithPast):
290255
DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator,)
@@ -313,8 +278,6 @@ def __init__(
313278
legacy=legacy,
314279
)
315280

316-
from optimum.exporters.tasks import TasksManager
317-
318281
self.is_decoder_with_past = False
319282

320283
# Set up the encoder ONNX config.
@@ -382,41 +345,19 @@ def __init__(
382345
@property
383346
def inputs(self) -> dict[str, dict[int, str]]:
384347
common_inputs = {}
385-
if self._behavior is not ConfigBehavior.DECODER:
348+
if self._behavior in {ConfigBehavior.ENCODER, ConfigBehavior.MONOLITH}:
386349
common_inputs["input_ids"] = {0: "batch_size", 1: "encoder_sequence_length"}
387-
350+
else:
351+
common_inputs["encoder_outputs"] = {0: "batch_size", 1: "encoder_sequence_length"}
388352
common_inputs["attention_mask"] = {0: "batch_size", 1: "encoder_sequence_length"}
389353

390-
if self._behavior is not ConfigBehavior.ENCODER:
391-
# TODO: it is likely this pop() is unwanted as we then always hit
392-
# https://github.com/huggingface/transformers/blob/v4.26.0/src/transformers/models/t5/modeling_t5.py#L965-L969
393-
common_inputs.pop("attention_mask")
394-
395-
if self.use_past_in_inputs:
396-
# TODO: validate the axis name for attention_mask
397-
# common_inputs["attention_mask"][1] = "past_encoder_sequence_length + sequence_length"
398-
common_inputs["decoder_input_ids"] = {0: "batch_size"}
399-
else:
400-
common_inputs["decoder_input_ids"] = {0: "batch_size", 1: "decoder_sequence_length"}
401-
354+
if self._behavior in {ConfigBehavior.DECODER, ConfigBehavior.MONOLITH}:
355+
common_inputs["decoder_input_ids"] = {0: "batch_size", 1: "decoder_sequence_length"}
402356
if self.use_past_in_inputs:
403357
self.add_past_key_values(common_inputs, direction="inputs")
404358

405-
if self._behavior is ConfigBehavior.DECODER:
406-
common_inputs["encoder_outputs"] = {0: "batch_size", 1: "encoder_sequence_length"}
407-
408359
return common_inputs
409360

410-
@property
411-
def torch_to_onnx_input_map(self) -> dict[str, str]:
412-
if self._behavior is ConfigBehavior.DECODER:
413-
return {
414-
"decoder_input_ids": "input_ids",
415-
"encoder_outputs": "encoder_hidden_states",
416-
"attention_mask": "encoder_attention_mask",
417-
}
418-
return {}
419-
420361
def add_past_key_values(self, inputs_or_outputs: dict[str, dict[int, str]], direction: str):
421362
if self.is_decoder_with_past:
422363
return self._decoder_onnx_config.add_past_key_values(inputs_or_outputs, direction)
@@ -429,26 +370,34 @@ def flatten_output_collection_property(self, name: str, field: Iterable[Any]) ->
429370
return self._decoder_onnx_config.flatten_output_collection_property(name, field)
430371

431372
def generate_dummy_inputs_for_validation(
432-
self, reference_model_inputs: dict[str, Any], onnx_input_names: list[str] | None = None
373+
self, reference_model_inputs: dict[str, Any], onnx_input_names: list[str]
433374
) -> dict[str, Any]:
434375
if self._behavior is ConfigBehavior.ENCODER:
435-
return self._encoder_onnx_config.generate_dummy_inputs_for_validation(reference_model_inputs)
376+
return self._encoder_onnx_config.generate_dummy_inputs_for_validation(
377+
reference_model_inputs, onnx_input_names
378+
)
436379
else:
437380
if self._behavior is ConfigBehavior.DECODER:
438-
reference_model_inputs["input_ids"] = reference_model_inputs.pop("decoder_input_ids")
439-
440-
if "encoder_outputs" in reference_model_inputs:
441-
if "encoder_hidden_states" in onnx_input_names:
442-
reference_model_inputs["encoder_hidden_states"] = reference_model_inputs.pop("encoder_outputs")[0]
443-
else:
444-
reference_model_inputs.pop("encoder_outputs")
445-
446-
return self._decoder_onnx_config.generate_dummy_inputs_for_validation(reference_model_inputs)
381+
if "decoder_input_ids" in reference_model_inputs:
382+
reference_model_inputs["input_ids"] = reference_model_inputs.pop("decoder_input_ids")
383+
if "attention_mask" in reference_model_inputs:
384+
reference_model_inputs["encoder_attention_mask"] = reference_model_inputs.pop("attention_mask")
385+
if "encoder_outputs" in reference_model_inputs:
386+
if "encoder_hidden_states" in onnx_input_names:
387+
reference_model_inputs["encoder_hidden_states"] = reference_model_inputs.pop(
388+
"encoder_outputs"
389+
)[0]
390+
else:
391+
reference_model_inputs.pop("encoder_outputs")
392+
393+
return self._decoder_onnx_config.generate_dummy_inputs_for_validation(
394+
reference_model_inputs, onnx_input_names
395+
)
447396

448397
def post_process_exported_models(
449398
self,
450399
path: Path,
451-
models_and_onnx_configs: dict[str, tuple[PreTrainedModel | ModelMixin, OnnxConfig]],
400+
models_and_onnx_configs: dict[str, tuple[PreTrainedModel, OnnxConfig]],
452401
onnx_files_subpaths: list[str],
453402
):
454403
models_and_onnx_configs, onnx_files_subpaths = super().post_process_exported_models(

optimum/exporters/onnx/convert.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -333,9 +333,7 @@ def _run_validation(
333333

334334
# Possibly edit the input for the onnxruntime.InferenceSession, this is for example the case for merged
335335
# models where the input `use_cache_branch` is added
336-
reference_ort_inputs = config.generate_dummy_inputs_for_validation(
337-
reference_model_inputs, onnx_input_names=onnx_input_names
338-
)
336+
reference_ort_inputs = config.generate_dummy_inputs_for_validation(reference_model_inputs, onnx_input_names)
339337

340338
# We flatten potential collection of inputs (i.e. past_keys)
341339
onnx_inputs = {}

optimum/exporters/onnx/input_generators.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,12 @@
1313
# limitations under the License.
1414
from __future__ import annotations
1515

16-
from optimum.utils import DummyPastKeyValuesGenerator, NormalizedTextConfig, is_transformers_version
16+
from optimum.utils import (
17+
DummyAudioInputGenerator,
18+
DummyPastKeyValuesGenerator,
19+
NormalizedTextConfig,
20+
is_transformers_version,
21+
)
1722

1823

1924
class GPTBigCodeDummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator):
@@ -64,3 +69,25 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int
6469
]
6570

6671
return pkv
72+
73+
74+
class DummyMoonshineAudioInputGenerator(DummyAudioInputGenerator):
75+
SUPPORTED_INPUT_NAMES = ("input_values", "attention_mask")
76+
77+
def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
78+
if input_name == "input_values": # raw waveform
79+
return self.random_float_tensor(
80+
shape=[self.batch_size, self.sequence_length],
81+
min_value=-1,
82+
max_value=1,
83+
framework=framework,
84+
dtype=float_dtype,
85+
)
86+
elif input_name == "attention_mask": # attention mask
87+
return self.random_mask_tensor(
88+
shape=[self.batch_size, self.sequence_length],
89+
framework=framework,
90+
dtype=int_dtype,
91+
)
92+
else:
93+
raise ValueError(f"Unsupported input name: {input_name}")

0 commit comments

Comments
 (0)