Commit 5ab5bf3

Transformers 4.54 and 4.55 and gpt-oss (#52)
fixes #55 and #46
Parent commit: c6e3326

19 files changed: +2375 −1953 lines

.github/workflows/test_exporters_common.yml

Lines changed: 7 additions & 5 deletions
@@ -11,15 +11,18 @@ concurrency:
   cancel-in-progress: true

 env:
+  UV_SYSTEM_PYTHON: true
+  UV_TORCH_BACKEND: auto
   TRANSFORMERS_IS_CI: true
+  HF_HUB_READ_TOKEN: ${{ secrets.HF_HUB_READ_TOKEN }}

 jobs:
   build:
     strategy:
       fail-fast: false
       matrix:
-        runs-on: [ubuntu-22.04]
         python-version: [3.9]
+        runs-on: [ubuntu-22.04]

     runs-on: ${{ matrix.runs-on }}

@@ -34,10 +37,9 @@ jobs:

       - name: Install dependencies
         run: |
-          pip install --upgrade pip
-          pip install --no-cache-dir torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
-          pip install .[tests,onnxruntime]
+          pip install --upgrade pip uv
+          uv pip install .[tests,onnxruntime]

       - name: Test with pytest
         run: |
-          pytest tests/exporters/common -vvvv --durations=0 -n auto
+          pytest tests/exporters/common -vvvv -n auto

.github/workflows/test_exporters_onnx.yml

Lines changed: 6 additions & 4 deletions
@@ -11,7 +11,10 @@ concurrency:
   cancel-in-progress: true

 env:
+  UV_SYSTEM_PYTHON: true
+  UV_TORCH_BACKEND: auto
   TRANSFORMERS_IS_CI: true
+  HF_HUB_READ_TOKEN: ${{ secrets.HF_HUB_READ_TOKEN }}

 jobs:
   build:

@@ -34,10 +37,9 @@ jobs:

       - name: Install dependencies
         run: |
-          pip install --upgrade pip
-          pip install --no-cache-dir torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
-          pip install .[tests,onnxruntime] diffusers
+          pip install --upgrade pip uv
+          uv pip install .[tests,onnxruntime] diffusers

       - name: Test with pytest
         run: |
-          pytest tests/exporters/onnx/test_export.py -vvvv --durations=0 -n auto
+          pytest tests/exporters/onnx/test_export.py -vvvv -n auto

.github/workflows/test_exporters_onnx_cli.yml

Lines changed: 8 additions & 6 deletions
@@ -11,17 +11,20 @@ concurrency:
   cancel-in-progress: true

 env:
+  UV_SYSTEM_PYTHON: true
+  UV_TORCH_BACKEND: auto
   TRANSFORMERS_IS_CI: true
+  HF_HUB_READ_TOKEN: ${{ secrets.HF_HUB_READ_TOKEN }}

 jobs:
   build:
     strategy:
       fail-fast: false
       matrix:
         python-version: [3.9]
-        os: [ubuntu-22.04]
+        runs-on: [ubuntu-22.04]

-    runs-on: ${{ matrix.os }}
+    runs-on: ${{ matrix.runs-on }}

     steps:
       - name: Checkout repository

@@ -34,10 +37,9 @@ jobs:

       - name: Install dependencies
         run: |
-          pip install --upgrade pip
-          pip install --no-cache-dir torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
-          pip install .[tests,onnxruntime] diffusers
+          pip install --upgrade pip uv
+          uv pip install .[tests,onnxruntime] diffusers

       - name: Test with pytest
         run: |
-          pytest tests/exporters/onnx/test_export_cli.py -vvvv --durations=0 -n auto
+          pytest tests/exporters/onnx/test_export_cli.py -vvvv -n auto

.github/workflows/test_onnxruntime.yml

Lines changed: 23 additions & 14 deletions
@@ -12,7 +12,10 @@ concurrency:
   cancel-in-progress: true

 env:
+  UV_SYSTEM_PYTHON: true
+  UV_TORCH_BACKEND: auto
   TRANSFORMERS_IS_CI: true
+  HF_HUB_READ_TOKEN: ${{ secrets.HF_HUB_READ_TOKEN }}

 jobs:
   build:

@@ -24,22 +27,19 @@ jobs:
         transformers_version: [latest, 4.36.*, 4.45.*]
         test_file:
           [
-            test_timm.py,
             test_decoder.py,
-            test_modeling.py,
             test_diffusion.py,
+            test_modeling.py,
             test_optimization.py,
             test_quantization.py,
+            test_seq2seq.py,
+            test_timm.py,
             test_utils.py,
           ]

     runs-on: ${{ matrix.runs-on }}

     steps:
-      - name: Free Disk Space (Ubuntu)
-        if: matrix.test_file == 'test_modeling.py'
-        uses: jlumbroso/free-disk-space@main
-
       - name: Checkout code
         uses: actions/checkout@v5

@@ -50,20 +50,29 @@ jobs:

       - name: Install dependencies
         run: |
-          pip install --upgrade pip
-          pip install --no-cache-dir torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
-          pip install .[tests,onnxruntime] diffusers
+          pip install --upgrade pip uv
+          uv pip install .[tests,onnxruntime]
+
+      - name: Install diffusers for transformers ${{ matrix.transformers-version }}
+        run: |
+          if [ "${{ matrix.transformers_version }}" == '4.36.*' ]; then
+            uv pip install "diffusers<0.32.0"
+          elif [ "${{ matrix.transformers_version }}" == '4.45.*' ]; then
+            uv pip install "diffusers<0.33.0"
+          else
+            uv pip install diffusers
+          fi

       - name: Install transformers ${{ matrix.transformers-version }}
         run: |
           if [ "${{ matrix.transformers_version }}" == '4.36.*' ]; then
-            pip install "transformers==4.36.*" "diffusers<0.32.0" "pytest<8.0.0"
+            uv pip install "transformers==4.36.*" "pytest<8.0.0"
           elif [ "${{ matrix.transformers_version }}" == '4.45.*' ]; then
-            pip install "transformers==4.45.*" "diffusers<0.33.0"
+            uv pip install "transformers==4.45.*"
+          elif [ "${{ matrix.transformers_version }}" != 'latest' ]; then
+            uv pip install "transformers==${{ matrix.transformers_version }}"
           fi

       - name: Test with pytest
         run: |
-          pytest tests/onnxruntime/${{ matrix.test_file }} --durations=0 -vvvv -n auto
-        env:
-          HF_HUB_READ_TOKEN: ${{ secrets.HF_HUB_READ_TOKEN }}
+          pytest tests/onnxruntime/${{ matrix.test_file }} -vvvv -n auto
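
Note: the new diffusers install step encodes a small compatibility matrix — transformers 4.36.* is tested against diffusers<0.32.0, 4.45.* against diffusers<0.33.0, and anything else against unpinned diffusers. A minimal Python sketch of the same mapping; the helper name is hypothetical and not part of this commit:

# Hypothetical helper mirroring the shell branches in the workflow above.
def diffusers_pin(transformers_version: str) -> str:
    pins = {
        "4.36.*": "diffusers<0.32.0",  # older transformers needs older diffusers
        "4.45.*": "diffusers<0.33.0",
    }
    return pins.get(transformers_version, "diffusers")  # unpinned for 'latest'

assert diffusers_pin("4.36.*") == "diffusers<0.32.0"
assert diffusers_pin("latest") == "diffusers"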

optimum/exporters/onnx/_traceable_cache.py

Lines changed: 13 additions & 0 deletions
@@ -1,3 +1,16 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 from __future__ import annotations

 import logging
Lines changed: 168 additions & 0 deletions
@@ -0,0 +1,168 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import warnings
+from collections import defaultdict
+from functools import wraps
+
+from transformers.utils.generic import _CAN_RECORD_REGISTRY, OutputRecorder, logger
+
+
+# This is a fixed version of transformers.utils.generic.check_model_inputs
+# that fixes issues related to ONNX export and tracing:
+# - adds support for positional args (use_cache), without which use_cache ends up being passed twice
+# - fixes an issue with default capture_flags being None for some models
+def traceable_check_model_inputs(func):
+    @wraps(func)
+    def wrapper(self, *args, **kwargs):
+        use_cache = (
+            kwargs["use_cache"] if kwargs.get("use_cache") is not None else getattr(self.config, "use_cache", None)
+        )
+        if use_cache is not None:
+            if getattr(self, "gradient_checkpointing", False) and self.training and use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
+                )
+                use_cache = False
+
+            # Prevent passing use_cache twice
+            if "use_cache" in func.__code__.co_varnames:
+                use_cache_idx = func.__code__.co_varnames.index("use_cache") - 1  # minus 1 for 'self'
+                if len(args) > use_cache_idx:
+                    args = list(args)
+                    args[use_cache_idx] = use_cache
+                    args = tuple(args)
+                else:
+                    kwargs["use_cache"] = use_cache
+
+        return_dict = kwargs.pop("return_dict", None)
+        if return_dict is None:
+            return_dict = getattr(self.config, "return_dict", True)
+
+        all_args = kwargs.copy()
+        if "kwargs" in all_args:
+            for k, v in all_args["kwargs"].items():
+                all_args[k] = v
+
+        capture_flags = _CAN_RECORD_REGISTRY.get(str(self.__class__)) or {}  # there is a weak ref for executorch
+
+        recordable_keys = {
+            f"output_{k}": all_args.get(
+                f"output_{k}",
+                getattr(
+                    self.config,
+                    f"output_{k}",
+                    all_args.get("output_attentions", getattr(self.config, "output_attentions", False)),
+                ),
+            )
+            for k in capture_flags
+        }
+
+        # We let cross attentions be saved separately because some models add a `cross-attn` layer
+        # when certain conditions are met. Output cross attentions if attentions are requested (for BC)
+        if "output_attentions" in recordable_keys:
+            recordable_keys["output_cross_attentions"] = recordable_keys["output_attentions"]
+
+        collected_outputs = defaultdict(tuple)
+        monkey_patched_layers = []
+
+        # Check that the attention implementation is properly set for capturing attention outputs
+        if recordable_keys.get("output_attentions", False):
+            supported_attn = ["eager", "eager_paged", "flex_attention"]
+            config_attn = getattr(self.config, "_attn_implementation", None)
+            sub_configs = [getattr(self.config, key, None) for key in self.config.sub_configs]
+            sub_configs_attn = [
+                getattr(config, "_attn_implementation", None) for config in sub_configs if config is not None
+            ]
+            if config_attn not in supported_attn or any(attn not in supported_attn for attn in sub_configs_attn):
+                warnings.warn(
+                    f"`output_attentions=True` is not supported with `attn_implementation` other than {supported_attn}. "
+                    "Please use `model.set_attn_implementation('eager')` to enable capturing attention outputs.",
+                    UserWarning,
+                    stacklevel=2,
+                )
+
+        def make_capture_wrapper(module, orig_forward, key, index):
+            @wraps(orig_forward)
+            def wrapped_forward(*args, **kwargs):
+                if key == "hidden_states" and len(collected_outputs[key]) == 0:
+                    collected_outputs[key] += (args[0],)
+                output = orig_forward(*args, **kwargs)
+                if not isinstance(output, tuple):
+                    collected_outputs[key] += (output,)
+                elif output[index] is not None:
+                    if key not in collected_outputs:
+                        collected_outputs[key] = (output[index],)
+                    else:
+                        collected_outputs[key] += (output[index],)
+                return output
+
+            return wrapped_forward
+
+        if any(recordable_keys.values()):
+            capture_tasks = []
+            for key, layer_specs in capture_flags.items():
+                if not recordable_keys.get(f"output_{key}", False):
+                    continue
+                if not isinstance(layer_specs, list):
+                    layer_specs = [layer_specs]
+                for specs in layer_specs:
+                    if not isinstance(specs, OutputRecorder):
+                        index = 0 if "hidden_states" in key else 1
+                        class_name = None if not isinstance(specs, str) else specs
+                        target_class = specs if not isinstance(specs, str) else None
+                        specs = OutputRecorder(target_class=target_class, index=index, class_name=class_name)
+                    capture_tasks.append((key, specs))
+
+            for name, module in self.named_modules():
+                for key, specs in capture_tasks:
+                    # The second check is for multimodal models where only the backbone layer suffix is available
+                    if (specs.target_class is not None and isinstance(module, specs.target_class)) or (
+                        specs.class_name is not None and name.endswith(specs.class_name)
+                    ):
+                        if specs.layer_name is not None and specs.layer_name not in name:
+                            continue
+                        # Monkey patch forward
+                        original_forward = module.forward
+                        module.forward = make_capture_wrapper(module, original_forward, key, specs.index)
+                        monkey_patched_layers.append((module, original_forward))
+
+        outputs = func(self, *args, **kwargs)
+        # Restore original forward methods
+        for module, original_forward in monkey_patched_layers:
+            module.forward = original_forward
+
+        # Inject collected outputs into model output
+        for key in collected_outputs:
+            if key == "hidden_states":
+                if hasattr(outputs, "vision_hidden_states"):
+                    collected_outputs[key] = collected_outputs[key][:-1]
+                    collected_outputs[key] += (outputs.vision_hidden_states,)
+                elif hasattr(outputs, "last_hidden_state"):
+                    collected_outputs[key] = collected_outputs[key][:-1]
+                    collected_outputs[key] += (outputs.last_hidden_state,)

+                outputs[key] = collected_outputs[key]
+            elif key == "attentions":
+                if isinstance(capture_flags[key], list) and len(capture_flags[key]) == 2:
+                    outputs[key] = collected_outputs[key][0::2]
+                    outputs["cross_" + key] = collected_outputs[key][1::2]
+                else:
+                    outputs[key] = collected_outputs[key]
+            else:
+                outputs[key] = collected_outputs[key]
+        if return_dict is False:
+            outputs = outputs.to_tuple()
+        return outputs
+
+    return wrapper
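
Note: the positional-argument handling in the wrapper above is the heart of the fix — if `use_cache` already arrived positionally, also writing it into kwargs would make Python raise "got multiple values for argument 'use_cache'" during export. A self-contained sketch of that pattern, with toy names not taken from this commit:

from functools import wraps

def resolve_use_cache(func):
    # Minimal stand-in for traceable_check_model_inputs' use_cache handling.
    @wraps(func)
    def wrapper(self, *args, **kwargs):
        use_cache = kwargs.get("use_cache", True)  # stand-in for the config lookup
        if "use_cache" in func.__code__.co_varnames:
            idx = func.__code__.co_varnames.index("use_cache") - 1  # minus 1 for 'self'
            if len(args) > idx:
                args = list(args)
                args[idx] = use_cache  # overwrite the positional slot in place
                args = tuple(args)
            else:
                kwargs["use_cache"] = use_cache  # safe: not passed positionally
        return func(self, *args, **kwargs)
    return wrapper

class ToyModel:
    @resolve_use_cache
    def forward(self, input_ids, use_cache=None):
        return input_ids, use_cache

print(ToyModel().forward("ids", True))  # positional use_cache: no duplicate kwarg
print(ToyModel().forward("ids"))        # kwarg path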
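
Note: the `make_capture_wrapper` machinery follows a patch-collect-restore pattern — each matching submodule's `forward` is temporarily replaced with a recording wrapper, the real forward runs once, and the original methods are restored before returning. A stripped-down sketch of that pattern on a toy torch model (illustrative only, not the commit's code):

import torch

collected = []  # outputs recorded from patched submodules

def make_recorder(orig_forward):
    def wrapped(*args, **kwargs):
        out = orig_forward(*args, **kwargs)
        collected.append(out)  # record the submodule output
        return out
    return wrapped

model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 4))

patched = []
for module in model:  # patch every submodule's forward
    original = module.forward
    module.forward = make_recorder(original)
    patched.append((module, original))

model(torch.randn(1, 4))

for module, original in patched:  # restore, as the wrapper above does
    module.forward = original

print(len(collected))  # 2: one recorded output per Linear layer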
