
Commit bb6a397

[https://nvbugs/5732942][fix] AutoDeploy: handle transformers 4.57.1 upgrade fixes (#10466)
Signed-off-by: Lucas Liebenwein <[email protected]>
1 parent 00355b2 commit bb6a397

File tree: 6 files changed (+93, -36 lines)
Lines changed: 37 additions & 0 deletions

@@ -0,0 +1,37 @@
+"""Patch to enable torch.nonzero() on meta tensors during export.
+
+This patch addresses an issue where torch.nonzero() raises NotImplementedError
+when tracing models that use nonzero on meta device. The fix sets the config
+flag to assume all elements are non-zero, which enables export to proceed.
+"""
+
+import torch.fx.experimental._config as fx_config
+
+from ..interface import BaseExportPatch, ExportPatchRegistry
+
+
+@ExportPatchRegistry.register("meta_nonzero")
+class MetaNonzeroPatch(BaseExportPatch):
+    """Patch to enable torch.nonzero() meta registration during export.
+
+    This patch sets torch.fx.experimental._config.meta_nonzero_assume_all_nonzero
+    to True, allowing torch.nonzero() to work on meta tensors during tracing.
+    The implementation assumes all elements are non-zero, which is acceptable
+    for tracing purposes where only shapes matter.
+    """
+
+    def _apply_patch(self):
+        """Apply the meta nonzero patch."""
+        # Store original config value
+        self.original_values["meta_nonzero_assume_all_nonzero"] = getattr(
+            fx_config, "meta_nonzero_assume_all_nonzero", False
+        )
+
+        # Enable nonzero on meta tensors
+        fx_config.meta_nonzero_assume_all_nonzero = True
+
+    def _revert_patch(self):
+        """Revert the meta nonzero patch."""
+        fx_config.meta_nonzero_assume_all_nonzero = self.original_values[
+            "meta_nonzero_assume_all_nonzero"
+        ]
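In practice the patch is a scoped flip of one config flag with save-and-restore semantics. Below is a minimal standalone sketch of that pattern, outside the ExportPatchRegistry machinery, using the same getattr fallback as above because older PyTorch builds may not define the flag:

import torch.fx.experimental._config as fx_config

# Remember the current value; default to False when the flag is absent.
original = getattr(fx_config, "meta_nonzero_assume_all_nonzero", False)

fx_config.meta_nonzero_assume_all_nonzero = True  # let nonzero() assume all elements are non-zero
try:
    pass  # run torch.export / tracing of the meta-device model here
finally:
    # Restore the previous value so the change does not leak past export.
    fx_config.meta_nonzero_assume_all_nonzero = original

The registry-based patch above does the same thing, keyed into the export pipeline's apply/revert lifecycle.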

tensorrt_llm/_torch/auto_deploy/models/patches/llama4.py

Lines changed: 2 additions & 2 deletions

@@ -72,11 +72,11 @@ def _forward_with_cond(
         inputs_embeds = self.get_input_embeddings()(input_ids)
 
     def _vision_branch(inputs_embeds, pixel_values, input_ids):
+        # Updated to match transformers 4.57.1+ signature
+        # get_image_features now takes: (self, pixel_values, vision_feature_select_strategy, **kwargs)
         image_features = self.get_image_features(
             pixel_values=pixel_values,
-            vision_feature_layer=vision_feature_layer,
             vision_feature_select_strategy=vision_feature_select_strategy,
-            image_sizes=None,
         )
 
         vision_flat = image_features.view(-1, image_features.size(-1))
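The two-line comment in the diff documents the new transformers 4.57.1+ call signature. If in doubt about the installed version, a quick, purely illustrative check (not part of the patch; it assumes transformers exposes Llama4ForConditionalGeneration at the top level, as it does from the 4.51 releases onward) is to inspect the method directly:

import inspect

from transformers import Llama4ForConditionalGeneration

# Print the signature of get_image_features for the installed transformers.
# Per the diff comment above, 4.57.1+ is expected to show roughly:
#   (self, pixel_values, vision_feature_select_strategy, **kwargs)
# i.e. vision_feature_layer and image_sizes are no longer explicit parameters.
print(inspect.signature(Llama4ForConditionalGeneration.get_image_features))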

tensorrt_llm/_torch/auto_deploy/models/patches/mixtral.py

Lines changed: 27 additions & 8 deletions

@@ -7,19 +7,38 @@
 
 from ...export.interface import BaseExportPatch, ExportPatchRegistry
 
+# Import SiLUActivation for compatibility check
+try:
+    from transformers.activations import SiLUActivation
+
+    _SILU_TYPES = (nn.SiLU, SiLUActivation)
+except ImportError:
+    _SILU_TYPES = (nn.SiLU,)
+
+
+def _is_silu_activation(act_fn) -> bool:
+    """Check if activation function is SiLU or equivalent."""
+    return isinstance(act_fn, _SILU_TYPES)
+
 
 def _forward_moe(self: MixtralSparseMoeBlock, hidden_states: torch.Tensor):
     # check if we can apply the patch
-    use_original_forward = False
-    if not all(isinstance(expert.act_fn, nn.SiLU) for expert in self.experts):
-        use_original_forward = True
+    unsupported_reasons = []
+    if not all(_is_silu_activation(expert.act_fn) for expert in self.experts):
+        unsupported_reasons.append("expert activation is not SiLU")
 
     if any(getattr(mod, "bias", None) is not None for mod in self.experts.modules()):
-        use_original_forward = True
-
-    # rely on original forward instead
-    if use_original_forward:
-        return self._original_forward(hidden_states)
+        unsupported_reasons.append("expert modules have bias")
+
+    # Raise informative error for unsupported configurations
+    # (fallback to original forward is not export-compatible with transformers >= 4.57.1)
+    if unsupported_reasons:
+        raise NotImplementedError(
+            f"MixtralSparseMoeBlock forward patch does not support this model configuration: "
+            f"{', '.join(unsupported_reasons)}. "
+            f"The original transformers forward uses torch.nonzero() and tensor indexing "
+            f"which are not compatible with torch.export on meta tensors."
+        )
 
     batch_size, sequence_length, hidden_dim = hidden_states.shape
     if self.training and self.jitter_noise > 0:
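The helper introduced above can be exercised on its own. A minimal sketch mirroring the try/except import and isinstance check from the diff (nn.GELU appears only as a counter-example):

import torch.nn as nn

try:
    # Older transformers releases expose a SiLUActivation wrapper class.
    from transformers.activations import SiLUActivation

    _SILU_TYPES = (nn.SiLU, SiLUActivation)
except ImportError:
    _SILU_TYPES = (nn.SiLU,)


def _is_silu_activation(act_fn) -> bool:
    """Check if activation function is SiLU or equivalent."""
    return isinstance(act_fn, _SILU_TYPES)


print(_is_silu_activation(nn.SiLU()))  # True  -> the fused MoE patch applies
print(_is_silu_activation(nn.GELU()))  # False -> the patched forward raises NotImplementedError

The same helper is added verbatim to the Qwen3 MoE patch below.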

tensorrt_llm/_torch/auto_deploy/models/patches/qwen3.py

Lines changed: 27 additions & 8 deletions

@@ -7,19 +7,38 @@
 
 from ...export.interface import BaseExportPatch, ExportPatchRegistry
 
+# Import SiLUActivation for compatibility check
+try:
+    from transformers.activations import SiLUActivation
+
+    _SILU_TYPES = (nn.SiLU, SiLUActivation)
+except ImportError:
+    _SILU_TYPES = (nn.SiLU,)
+
+
+def _is_silu_activation(act_fn) -> bool:
+    """Check if activation function is SiLU or equivalent."""
+    return isinstance(act_fn, _SILU_TYPES)
+
 
 def _forward_moe(self: Qwen3MoeSparseMoeBlock, hidden_states: torch.Tensor):
     # check if we can apply the patch
-    use_original_forward = False
-    if not all(isinstance(expert.act_fn, nn.SiLU) for expert in self.experts):
-        use_original_forward = True
+    unsupported_reasons = []
+    if not all(_is_silu_activation(expert.act_fn) for expert in self.experts):
+        unsupported_reasons.append("expert activation is not SiLU")
 
     if any(getattr(mod, "bias", None) is not None for mod in self.experts.modules()):
-        use_original_forward = True
-
-    # rely on original forward instead
-    if use_original_forward:
-        return self._original_forward(hidden_states)
+        unsupported_reasons.append("expert modules have bias")
+
+    # Raise informative error for unsupported configurations
+    # (fallback to original forward is not export-compatible with transformers >= 4.57.1)
+    if unsupported_reasons:
+        raise NotImplementedError(
+            f"Qwen3MoeSparseMoeBlock forward patch does not support this model configuration: "
+            f"{', '.join(unsupported_reasons)}. "
+            f"The original transformers forward uses torch.nonzero() and tensor indexing "
+            f"which are not compatible with torch.export on meta tensors."
+        )
 
     batch_size, sequence_length, hidden_dim = hidden_states.shape
     hidden_states = hidden_states.view(-1, hidden_dim)

tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_llama4_vlm_patch.py

Lines changed: 0 additions & 5 deletions

@@ -1,4 +1,3 @@
-import pytest
 import torch
 from _model_test_utils import get_small_model_config
 from build_and_run_ad import ExperimentConfig
@@ -10,10 +9,6 @@
 
 
 def test_build_run_llama4_vlm():
-    pytest.skip(
-        "Skipping test_build_run_llm4_vlm because Llama4 is giving an error on upgrading transformers version to 4.57.1"
-        "https://nvbugspro.nvidia.com/bug/5732942"
-    )
     atol = 1e-3
     rtol = 1e-3
 

tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_build_small_single.py

Lines changed: 0 additions & 13 deletions

@@ -201,19 +201,6 @@ def _check_ad_config(experiment_config: ExperimentConfig, llm_args: LlmArgs):
     ],
 )
 def test_build_ad(model_hub_id: str, llm_extra_args: dict):
-    if (
-        model_hub_id == "mistralai/Mixtral-8x7B-Instruct-v0.1"
-        and llm_extra_args.get("mode") != "transformers"
-    ):
-        pytest.skip(
-            "Mixtral-8x7B-Instruct-v0.1 is giving an error on upgrading transformers version to 4.57.1"
-            "https://nvbugspro.nvidia.com/bug/5732942"
-        )
-    if model_hub_id == "Qwen/Qwen3-30B-A3B" and llm_extra_args.get("mode") != "transformers":
-        pytest.skip(
-            "Qwen3-30B-A3B is giving an error on upgrading transformers version to 4.57.1"
-            "https://nvbugspro.nvidia.com/bug/5732942"
-        )
     experiment_config = get_small_model_config(model_hub_id, **llm_extra_args)
     experiment_config["args"]["runtime"] = "demollm"  # Default runtime set to demollm
     experiment_config["args"]["world_size"] = 0  # Default world_size set to 0
