docs/source/features/model-splitting.md (2 changes: 1 addition & 1 deletion)
@@ -37,7 +37,7 @@ Number splits

**`cost-model`**

-Let's now split the model using a cost model. Please refer to the [pre-generated cost models](https://github.com/microsoft/Olive/blob/main/assets/cost_models/Phi-3.5-mini.csv) in the Olive repository for an example a cost model csv.
+Let's now split the model using a cost model. Please refer to the [pre-generated cost models](https://github.com/microsoft/Olive/blob/main/olive/assets/cost_models/Phi-3.5-mini.csv) in the Olive repository for an example a cost model csv.

```bash
olive auto-opt -m microsoft/Phi-3.5-mini-instruct --precision fp16 --provider CUDAExecutionProvider --memory 2GB --cost-model phi-3.5-cost.csv -o models/phi-costsplit
```
olive/common/hf/io_config/task_config.py (14 changes: 10 additions & 4 deletions)
@@ -84,7 +84,7 @@ def get_io_config(
_add_present_outputs(dynamic_axes, config)

# Order inputs according to model forward signature
-ordered_inputs = _order_inputs(dynamic_axes, model)
+ordered_inputs = _order_inputs(dynamic_axes, model, set(outputs.keys()))

# Separate input and output names
input_names = [name for name in ordered_inputs if not name.startswith("present.")]
@@ -202,6 +202,7 @@ def _add_present_outputs(
def _order_inputs(
dynamic_axes: dict,
model: PreTrainedModel | None,
+output_names: set[str] | None = None,
) -> OrderedDict:
"""Order inputs according to model forward signature.

@@ -210,18 +211,23 @@ def _order_inputs(
Args:
dynamic_axes: Dict of all dynamic axes (inputs and outputs).
model: Optional model for forward signature inspection.
+output_names: Set of output names to exclude from input ordering.

Returns:
OrderedDict of input names to dynamic axes, ordered by forward signature.

"""
import re

-# Filter to only input names (not outputs like present.*)
+if output_names is None:
+    output_names = set()
+
+# Filter to only input names (not outputs like present.* or explicit output names)
input_axes = OrderedDict()
for name, axes in dynamic_axes.items():
-if not name.startswith("present.") and not name.startswith("logits"):
-    input_axes[name] = axes
+if name.startswith("present.") or name in output_names:
+    continue
+input_axes[name] = axes

if model is None:
return input_axes
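For concreteness, a minimal standalone sketch of the new filtering behavior (a re-implementation for illustration only; the `filter_inputs` name and the sample `dynamic_axes` values are assumptions, not Olive code):

```python
from collections import OrderedDict

def filter_inputs(dynamic_axes, output_names=None):
    # Drop KV-cache outputs ("present.*") and any name explicitly marked as an output.
    output_names = output_names or set()
    input_axes = OrderedDict()
    for name, axes in dynamic_axes.items():
        if name.startswith("present.") or name in output_names:
            continue
        input_axes[name] = axes
    return input_axes

dynamic_axes = OrderedDict([
    ("input_ids", {0: "batch_size", 1: "sequence_length"}),
    ("attention_mask", {0: "batch_size", 1: "sequence_length"}),
    ("logits", {0: "batch_size", 1: "sequence_length"}),
    ("present.0.key", {0: "batch_size", 2: "past_sequence_length"}),
])

print(list(filter_inputs(dynamic_axes, output_names={"logits"})))
# ['input_ids', 'attention_mask']
```

Passing the output names explicitly, rather than special-casing a `logits` prefix, also covers models whose outputs are named differently.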
olive/common/hf/model_io.py (8 changes: 8 additions & 0 deletions)
@@ -91,13 +91,15 @@ def get_model_io_config(
def get_model_dummy_input(
model_name: str,
task: str,
model: Optional["PreTrainedModel"] = None,
**kwargs,
) -> Optional[dict[str, Any]]:
"""Get dummy inputs for the model and task.

Args:
model_name: The model name or path.
task: The task type.
+model: Optional loaded model for input signature inspection.
**kwargs: Additional arguments including use_cache, batch_size, sequence_length.

Returns:
@@ -133,10 +135,16 @@ def get_model_dummy_input(
# Get model config (handles MLflow paths)
model_config = get_model_config(model_name, **kwargs)

+# Handle PEFT models
+actual_model = model
+if model is not None and is_peft_model(model):
+    actual_model = model.get_base_model()
+
try:
return generate_dummy_inputs(
model_name_or_config=model_config,
task=actual_task,
+model=actual_model,
use_past=use_past,
use_past_in_inputs=use_past_in_inputs,
)
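The PEFT handling above boils down to unwrapping adapter models before signature inspection; a hedged sketch of the same idea using the public `peft` API (`is_peft_model` in the diff is an Olive-internal helper, so `isinstance` stands in for it here):

```python
from peft import PeftModel

def unwrap_for_signature(model):
    # Adapter wrappers hide the base model's forward() signature; dummy-input
    # generation should inspect the base model rather than the wrapper.
    if model is not None and isinstance(model, PeftModel):
        return model.get_base_model()
    return model
```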
olive/model/handler/mixin/hf.py (1 change: 1 addition & 0 deletions)
@@ -121,6 +121,7 @@ def get_hf_dummy_inputs(self) -> Optional[dict[str, Any]]:
return get_model_dummy_input(
self.model_path,
self.task,
+model=self.load_model(),
**self.get_load_kwargs(),
)

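For context, a rough sketch of how this surfaces to a caller; the handler class and constructor arguments below are assumptions about Olive's public API, shown only to illustrate the call path:

```python
from olive.model import HfModelHandler  # assumed import path

handler = HfModelHandler(model_path="microsoft/Phi-3.5-mini-instruct", task="text-generation")

# get_hf_dummy_inputs() now forwards the loaded model, so dummy-input generation
# can order inputs by inspecting the model's actual forward() signature.
dummy_inputs = handler.get_hf_dummy_inputs()
```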
olive/passes/onnx/conversion.py (2 changes: 1 addition & 1 deletion)
@@ -853,7 +853,7 @@ def _validate_dynamic_shapes(dynamic_shapes, dummy_inputs, dummy_kwargs, model):
# Align tree spec only for not transformers.Cache.
if len(dummy_inputs) == 0:
for k, v in dummy_kwargs.items():
-if not isinstance(v, transformers.Cache):
+if not isinstance(v, transformers.Cache) and k in dynamic_shapes:
input_tree_spec = _pytree.tree_flatten(v)[1]
flatten_dynamic_shapes = get_the_flattened_and_tree_spec(dynamic_shapes[k], leaf_is_str=False)[0]
dynamic_shapes[k] = _pytree.tree_unflatten(flatten_dynamic_shapes, input_tree_spec)
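The added `k in dynamic_shapes` guard skips dummy kwargs that have no dynamic-shape entry instead of raising `KeyError`; a minimal illustration of the pattern (placeholder data, not the pass's real tensors):

```python
# Placeholder data; in the pass these are real tensors and axis specs.
dummy_kwargs = {"input_ids": [[1, 2, 3]], "position_ids": [[0, 1, 2]]}
dynamic_shapes = {"input_ids": {0: "batch_size", 1: "sequence_length"}}

aligned = {}
for k, v in dummy_kwargs.items():
    if k in dynamic_shapes:  # the guard added in this change
        aligned[k] = dynamic_shapes[k]
    # "position_ids" has no dynamic-shape entry and is skipped instead of failing

print(aligned)  # {'input_ids': {0: 'batch_size', 1: 'sequence_length'}}
```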
test/model/test_hf_model.py (4 changes: 2 additions & 2 deletions)
@@ -4,7 +4,7 @@
# --------------------------------------------------------------------------
import json
from pathlib import Path
-from unittest.mock import patch
+from unittest.mock import ANY, patch

import huggingface_hub
import pytest
@@ -222,5 +222,5 @@ def test_hf_onnx_config_dummy_inputs(self, get_model_dummy_input):
# get dummy inputs
dummy_inputs = olive_model.get_dummy_inputs()

-get_model_dummy_input.assert_called_once_with(self.model_name, self.task)
+get_model_dummy_input.assert_called_once_with(self.model_name, self.task, model=ANY)
assert dummy_inputs == 1
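
`ANY` from `unittest.mock` compares equal to any value, so the updated assertion can require that a `model` keyword was passed without pinning down the exact loaded-model object. A self-contained example of the pattern:

```python
from unittest.mock import ANY, MagicMock

mocked = MagicMock(return_value=1)
mocked("microsoft/phi-2", "text-generation", model=object())

# Passes: positional arguments must match exactly, while model= may be anything.
mocked.assert_called_once_with("microsoft/phi-2", "text-generation", model=ANY)
```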