Commit d81b9b0

Fix dummy_inputs and dynamic_shapes inconsistency for optional inputs (#2317)
## Describe your changes

Fix dummy_inputs and dynamic_shapes inconsistency for optional inputs

## Checklist before requesting a review

- [ ] Add unit tests for this change.
- [ ] Make sure all tests can pass.
- [ ] Update documents if necessary.
- [ ] Lint and apply fixes to your code by running `lintrunner -a`
- [ ] Is this a user-facing change? If yes, give a description of this change to be included in the release notes.

## (Optional) Issue link
1 parent ca37f23 commit d81b9b0

File tree: 6 files changed, +23 −8 lines changed
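For context, a minimal sketch of the inconsistency this commit addresses (the dictionaries and the alignment loop below are hypothetical, not taken from Olive): when an optional model input is present in the dummy inputs but has been filtered out of `dynamic_shapes` (or vice versa), the two structures no longer share the same keys and the export step can fail on the mismatch.

```python
# Hypothetical illustration of the dummy_inputs / dynamic_shapes mismatch
# for optional inputs; names and values are made up for this sketch.
dummy_kwargs = {
    "input_ids": [[1, 2, 3]],        # regular input
    "attention_mask": [[1, 1, 1]],   # optional input, present in dummy inputs
}
dynamic_shapes = {
    "input_ids": {0: "batch", 1: "seq"},
    # "attention_mask" was filtered out upstream -> keys no longer match
}

# Aligning the two structures only for keys that exist in both avoids a
# KeyError / tree-spec mismatch at export time.
aligned = {k: dynamic_shapes[k] for k in dummy_kwargs if k in dynamic_shapes}
print(aligned)  # {'input_ids': {0: 'batch', 1: 'seq'}}
```

The per-file changes below make the filtering of output names consistent and guard the `dynamic_shapes` lookup accordingly.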

docs/source/features/model-splitting.md

Lines changed: 1 addition & 1 deletion

````diff
@@ -37,7 +37,7 @@ Number splits

 **`cost-model`**

-Let's now split the model using a cost model. Please refer to the [pre-generated cost models](https://github.com/microsoft/Olive/blob/main/assets/cost_models/Phi-3.5-mini.csv) in the Olive repository for an example a cost model csv.
+Let's now split the model using a cost model. Please refer to the [pre-generated cost models](https://github.com/microsoft/Olive/blob/main/olive/assets/cost_models/Phi-3.5-mini.csv) in the Olive repository for an example a cost model csv.

 ```bash
 olive auto-opt -m microsoft/Phi-3.5-mini-instruct --precision fp16 --provider CUDAExecutionProvider --memory 2GB --cost-model phi-3.5-cost.csv -o models/phi-costsplit
````

olive/common/hf/io_config/task_config.py

Lines changed: 10 additions & 4 deletions

```diff
@@ -84,7 +84,7 @@ def get_io_config(
     _add_present_outputs(dynamic_axes, config)

     # Order inputs according to model forward signature
-    ordered_inputs = _order_inputs(dynamic_axes, model)
+    ordered_inputs = _order_inputs(dynamic_axes, model, set(outputs.keys()))

     # Separate input and output names
     input_names = [name for name in ordered_inputs if not name.startswith("present.")]
@@ -202,6 +202,7 @@ def _add_present_outputs(
 def _order_inputs(
     dynamic_axes: dict,
     model: PreTrainedModel | None,
+    output_names: set[str] | None = None,
 ) -> OrderedDict:
     """Order inputs according to model forward signature.

@@ -210,18 +211,23 @@ def _order_inputs(
     Args:
         dynamic_axes: Dict of all dynamic axes (inputs and outputs).
         model: Optional model for forward signature inspection.
+        output_names: Set of output names to exclude from input ordering.

     Returns:
         OrderedDict of input names to dynamic axes, ordered by forward signature.

     """
     import re

-    # Filter to only input names (not outputs like present.*)
+    if output_names is None:
+        output_names = set()
+
+    # Filter to only input names (not outputs like present.* or explicit output names)
     input_axes = OrderedDict()
     for name, axes in dynamic_axes.items():
-        if not name.startswith("present.") and not name.startswith("logits"):
-            input_axes[name] = axes
+        if name.startswith("present.") or name in output_names:
+            continue
+        input_axes[name] = axes

     if model is None:
         return input_axes
```
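A hedged sketch of the filtering rule this hunk changes: instead of special-casing names that start with `logits`, the output names reported by the ONNX config are now excluded explicitly. The standalone function below mirrors that logic for illustration only; it is not the Olive implementation.

```python
from collections import OrderedDict


def filter_input_axes(dynamic_axes: dict, output_names: set | None = None) -> OrderedDict:
    """Illustrative re-implementation of the new filtering rule: drop
    `present.*` entries and any name that is a declared output, rather
    than hard-coding a `logits` prefix check."""
    output_names = output_names or set()
    input_axes = OrderedDict()
    for name, axes in dynamic_axes.items():
        if name.startswith("present.") or name in output_names:
            continue
        input_axes[name] = axes
    return input_axes


axes = {
    "input_ids": {0: "batch"},
    "logits": {0: "batch"},          # output, now excluded by name
    "present.0.key": {0: "batch"},   # cache output, excluded by prefix
}
print(list(filter_input_axes(axes, {"logits"})))  # ['input_ids']
```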

olive/common/hf/model_io.py

Lines changed: 8 additions & 0 deletions

```diff
@@ -91,13 +91,15 @@ def get_model_io_config(
 def get_model_dummy_input(
     model_name: str,
     task: str,
+    model: Optional["PreTrainedModel"] = None,
     **kwargs,
 ) -> Optional[dict[str, Any]]:
     """Get dummy inputs for the model and task.

     Args:
         model_name: The model name or path.
         task: The task type.
+        model: Optional loaded model for input signature inspection.
         **kwargs: Additional arguments including use_cache, batch_size, sequence_length.

     Returns:
@@ -133,10 +135,16 @@ def get_model_dummy_input(
     # Get model config (handles MLflow paths)
     model_config = get_model_config(model_name, **kwargs)

+    # Handle PEFT models
+    actual_model = model
+    if model is not None and is_peft_model(model):
+        actual_model = model.get_base_model()
+
     try:
         return generate_dummy_inputs(
             model_name_or_config=model_config,
             task=actual_task,
+            model=actual_model,
             use_past=use_past,
             use_past_in_inputs=use_past_in_inputs,
         )
```
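As a hedged illustration of the PEFT handling added above (assuming the `peft` package is available; `unwrap_for_dummy_inputs` is a made-up helper, not Olive's API): a PEFT-wrapped model is unwrapped to its base model before it is used for dummy-input generation.

```python
def unwrap_for_dummy_inputs(model):
    """Illustrative sketch: return the underlying base model when given a
    PEFT-wrapped model, mirroring the 'Handle PEFT models' block above."""
    try:
        from peft import PeftModel
    except ImportError:
        return model  # peft not installed -> nothing to unwrap
    if model is not None and isinstance(model, PeftModel):
        # PeftModel.get_base_model() returns the wrapped transformers model
        return model.get_base_model()
    return model
```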

olive/model/handler/mixin/hf.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -121,6 +121,7 @@ def get_hf_dummy_inputs(self) -> Optional[dict[str, Any]]:
         return get_model_dummy_input(
             self.model_path,
             self.task,
+            model=self.load_model(),
             **self.get_load_kwargs(),
         )

```

olive/passes/onnx/conversion.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -853,7 +853,7 @@ def _validate_dynamic_shapes(dynamic_shapes, dummy_inputs, dummy_kwargs, model):
     # Align tree spec only for not transformers.Cache.
     if len(dummy_inputs) == 0:
         for k, v in dummy_kwargs.items():
-            if not isinstance(v, transformers.Cache):
+            if not isinstance(v, transformers.Cache) and k in dynamic_shapes:
                 input_tree_spec = _pytree.tree_flatten(v)[1]
                 flatten_dynamic_shapes = get_the_flattened_and_tree_spec(dynamic_shapes[k], leaf_is_str=False)[0]
                 dynamic_shapes[k] = _pytree.tree_unflatten(flatten_dynamic_shapes, input_tree_spec)
```
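A hedged sketch of why the extra `k in dynamic_shapes` guard matters: when an optional dummy kwarg has no corresponding `dynamic_shapes` entry, looking it up unconditionally raises a KeyError. The names below are illustrative stand-ins, not the structures used in Olive's conversion pass.

```python
# Made-up minimal reproduction of the guarded lookup.
dummy_kwargs = {"input_ids": [[1, 2]], "token_type_ids": [[0, 0]]}
dynamic_shapes = {"input_ids": {0: "batch", 1: "seq"}}  # optional input missing

for k, v in dummy_kwargs.items():
    if k in dynamic_shapes:  # without this check, dynamic_shapes[k] raises KeyError
        print(k, "->", dynamic_shapes[k])
    else:
        print(k, "-> skipped (no dynamic shape declared)")
```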

test/model/test_hf_model.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -4,7 +4,7 @@
 # --------------------------------------------------------------------------
 import json
 from pathlib import Path
-from unittest.mock import patch
+from unittest.mock import ANY, patch

 import huggingface_hub
 import pytest
@@ -222,5 +222,5 @@ def test_hf_onnx_config_dummy_inputs(self, get_model_dummy_input):
         # get dummy inputs
         dummy_inputs = olive_model.get_dummy_inputs()

-        get_model_dummy_input.assert_called_once_with(self.model_name, self.task)
+        get_model_dummy_input.assert_called_once_with(self.model_name, self.task, model=ANY)
         assert dummy_inputs == 1
```
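For readers unfamiliar with `unittest.mock.ANY` used in the updated assertion: it compares equal to any value, which is convenient when the exact object passed (here, the loaded model) is irrelevant to the test. A small self-contained sketch, unrelated to Olive's test fixtures:

```python
from unittest.mock import ANY, MagicMock

fn = MagicMock()
fn("name", "task", model=object())  # the model instance itself does not matter

# ANY matches anything, so this assertion passes regardless of which
# model object was passed as the keyword argument.
fn.assert_called_once_with("name", "task", model=ANY)
```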
