37 changes: 37 additions & 0 deletions tensorrt_llm/_torch/auto_deploy/export/library/meta_nonzero.py
@@ -0,0 +1,37 @@
"""Patch to enable torch.nonzero() on meta tensors during export.
This patch addresses an issue where torch.nonzero() raises NotImplementedError
when tracing models that use nonzero on meta device. The fix sets the config
flag to assume all elements are non-zero, which enables export to proceed.
"""

import torch.fx.experimental._config as fx_config

from ..interface import BaseExportPatch, ExportPatchRegistry


@ExportPatchRegistry.register("meta_nonzero")
class MetaNonzeroPatch(BaseExportPatch):
"""Patch to enable torch.nonzero() meta registration during export.
This patch sets torch.fx.experimental._config.meta_nonzero_assume_all_nonzero
to True, allowing torch.nonzero() to work on meta tensors during tracing.
The implementation assumes all elements are non-zero, which is acceptable
for tracing purposes where only shapes matter.
"""

def _apply_patch(self):
"""Apply the meta nonzero patch."""
# Store original config value
self.original_values["meta_nonzero_assume_all_nonzero"] = getattr(
fx_config, "meta_nonzero_assume_all_nonzero", False
)

# Enable nonzero on meta tensors
fx_config.meta_nonzero_assume_all_nonzero = True

def _revert_patch(self):
"""Revert the meta nonzero patch."""
fx_config.meta_nonzero_assume_all_nonzero = self.original_values[
"meta_nonzero_assume_all_nonzero"
]
4 changes: 2 additions & 2 deletions tensorrt_llm/_torch/auto_deploy/models/patches/llama4.py
@@ -72,11 +72,11 @@ def _forward_with_cond(
inputs_embeds = self.get_input_embeddings()(input_ids)

def _vision_branch(inputs_embeds, pixel_values, input_ids):
# Updated to match transformers 4.57.1+ signature
# get_image_features now takes: (self, pixel_values, vision_feature_select_strategy, **kwargs)
image_features = self.get_image_features(
pixel_values=pixel_values,
vision_feature_layer=vision_feature_layer,
vision_feature_select_strategy=vision_feature_select_strategy,
image_sizes=None,
)

vision_flat = image_features.view(-1, image_features.size(-1))
35 changes: 27 additions & 8 deletions tensorrt_llm/_torch/auto_deploy/models/patches/mixtral.py
@@ -7,19 +7,38 @@

from ...export.interface import BaseExportPatch, ExportPatchRegistry

# Import SiLUActivation for compatibility check
try:
from transformers.activations import SiLUActivation

_SILU_TYPES = (nn.SiLU, SiLUActivation)
except ImportError:
_SILU_TYPES = (nn.SiLU,)


def _is_silu_activation(act_fn) -> bool:
"""Check if activation function is SiLU or equivalent."""
return isinstance(act_fn, _SILU_TYPES)


def _forward_moe(self: MixtralSparseMoeBlock, hidden_states: torch.Tensor):
# check if we can apply the patch
use_original_forward = False
if not all(isinstance(expert.act_fn, nn.SiLU) for expert in self.experts):
use_original_forward = True
unsupported_reasons = []
if not all(_is_silu_activation(expert.act_fn) for expert in self.experts):
unsupported_reasons.append("expert activation is not SiLU")

if any(getattr(mod, "bias", None) is not None for mod in self.experts.modules()):
use_original_forward = True

# rely on original forward instead
if use_original_forward:
return self._original_forward(hidden_states)
unsupported_reasons.append("expert modules have bias")

# Raise informative error for unsupported configurations
# (fallback to original forward is not export-compatible with transformers >= 4.57.1)
if unsupported_reasons:
raise NotImplementedError(
f"MixtralSparseMoeBlock forward patch does not support this model configuration: "
f"{', '.join(unsupported_reasons)}. "
f"The original transformers forward uses torch.nonzero() and tensor indexing "
f"which are not compatible with torch.export on meta tensors."
)

batch_size, sequence_length, hidden_dim = hidden_states.shape
if self.training and self.jitter_noise > 0:
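For reference, a standalone sketch of the new activation check introduced in this file (the optional SiLUActivation import mirrors the try/except above; nn.GELU appears here only as an example of an unsupported activation):

import torch.nn as nn

try:
    from transformers.activations import SiLUActivation
    _SILU_TYPES = (nn.SiLU, SiLUActivation)
except ImportError:
    _SILU_TYPES = (nn.SiLU,)

def _is_silu_activation(act_fn) -> bool:
    # True for nn.SiLU or transformers' SiLUActivation alias.
    return isinstance(act_fn, _SILU_TYPES)

_is_silu_activation(nn.SiLU())  # True -> patched MoE forward is used
_is_silu_activation(nn.GELU())  # False -> NotImplementedError instead of the old silent fallback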
35 changes: 27 additions & 8 deletions tensorrt_llm/_torch/auto_deploy/models/patches/qwen3.py
@@ -7,19 +7,38 @@

from ...export.interface import BaseExportPatch, ExportPatchRegistry

# Import SiLUActivation for compatibility check
try:
from transformers.activations import SiLUActivation

_SILU_TYPES = (nn.SiLU, SiLUActivation)
except ImportError:
_SILU_TYPES = (nn.SiLU,)


def _is_silu_activation(act_fn) -> bool:
"""Check if activation function is SiLU or equivalent."""
return isinstance(act_fn, _SILU_TYPES)


def _forward_moe(self: Qwen3MoeSparseMoeBlock, hidden_states: torch.Tensor):
# check if we can apply the patch
use_original_forward = False
if not all(isinstance(expert.act_fn, nn.SiLU) for expert in self.experts):
use_original_forward = True
unsupported_reasons = []
if not all(_is_silu_activation(expert.act_fn) for expert in self.experts):
unsupported_reasons.append("expert activation is not SiLU")

if any(getattr(mod, "bias", None) is not None for mod in self.experts.modules()):
use_original_forward = True

# rely on original forward instead
if use_original_forward:
return self._original_forward(hidden_states)
unsupported_reasons.append("expert modules have bias")

# Raise informative error for unsupported configurations
# (fallback to original forward is not export-compatible with transformers >= 4.57.1)
if unsupported_reasons:
raise NotImplementedError(
f"Qwen3MoeSparseMoeBlock forward patch does not support this model configuration: "
f"{', '.join(unsupported_reasons)}. "
f"The original transformers forward uses torch.nonzero() and tensor indexing "
f"which are not compatible with torch.export on meta tensors."
)

batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
@@ -1,4 +1,3 @@
import pytest
import torch
from _model_test_utils import get_small_model_config
from build_and_run_ad import ExperimentConfig
Expand All @@ -10,10 +9,6 @@


def test_build_run_llama4_vlm():
pytest.skip(
"Skipping test_build_run_llm4_vlm because Llama4 is giving an error on upgrading transformers version to 4.57.1"
"https://nvbugspro.nvidia.com/bug/5732942"
)
atol = 1e-3
rtol = 1e-3

Expand Down
@@ -201,19 +201,6 @@ def _check_ad_config(experiment_config: ExperimentConfig, llm_args: LlmArgs):
],
)
def test_build_ad(model_hub_id: str, llm_extra_args: dict):
if (
model_hub_id == "mistralai/Mixtral-8x7B-Instruct-v0.1"
and llm_extra_args.get("mode") != "transformers"
):
pytest.skip(
"Mixtral-8x7B-Instruct-v0.1 is giving an error on upgrading transformers version to 4.57.1"
"https://nvbugspro.nvidia.com/bug/5732942"
)
if model_hub_id == "Qwen/Qwen3-30B-A3B" and llm_extra_args.get("mode") != "transformers":
pytest.skip(
"Qwen3-30B-A3B is giving an error on upgrading transformers version to 4.57.1"
"https://nvbugspro.nvidia.com/bug/5732942"
)
experiment_config = get_small_model_config(model_hub_id, **llm_extra_args)
experiment_config["args"]["runtime"] = "demollm" # Default runtime set to demollm
experiment_config["args"]["world_size"] = 0 # Default world_size set to 0