diff --git a/tensorrt_llm/_torch/auto_deploy/export/library/meta_nonzero.py b/tensorrt_llm/_torch/auto_deploy/export/library/meta_nonzero.py
new file mode 100644
index 00000000000..0a1f79f7a44
--- /dev/null
+++ b/tensorrt_llm/_torch/auto_deploy/export/library/meta_nonzero.py
@@ -0,0 +1,37 @@
+"""Patch to enable torch.nonzero() on meta tensors during export.
+
+This patch addresses an issue where torch.nonzero() raises NotImplementedError
+when tracing models that use nonzero on meta device. The fix sets the config
+flag to assume all elements are non-zero, which enables export to proceed.
+"""
+
+import torch.fx.experimental._config as fx_config
+
+from ..interface import BaseExportPatch, ExportPatchRegistry
+
+
+@ExportPatchRegistry.register("meta_nonzero")
+class MetaNonzeroPatch(BaseExportPatch):
+    """Patch to enable torch.nonzero() meta registration during export.
+
+    This patch sets torch.fx.experimental._config.meta_nonzero_assume_all_nonzero
+    to True, allowing torch.nonzero() to work on meta tensors during tracing.
+    The implementation assumes all elements are non-zero, which is acceptable
+    for tracing purposes where only shapes matter.
+    """
+
+    def _apply_patch(self):
+        """Apply the meta nonzero patch."""
+        # Store original config value
+        self.original_values["meta_nonzero_assume_all_nonzero"] = getattr(
+            fx_config, "meta_nonzero_assume_all_nonzero", False
+        )
+
+        # Enable nonzero on meta tensors
+        fx_config.meta_nonzero_assume_all_nonzero = True
+
+    def _revert_patch(self):
+        """Revert the meta nonzero patch."""
+        fx_config.meta_nonzero_assume_all_nonzero = self.original_values[
+            "meta_nonzero_assume_all_nonzero"
+        ]
diff --git a/tensorrt_llm/_torch/auto_deploy/models/patches/llama4.py b/tensorrt_llm/_torch/auto_deploy/models/patches/llama4.py
index f6cfb5a6365..d91d7062d50 100644
--- a/tensorrt_llm/_torch/auto_deploy/models/patches/llama4.py
+++ b/tensorrt_llm/_torch/auto_deploy/models/patches/llama4.py
@@ -72,11 +72,11 @@ def _forward_with_cond(
         inputs_embeds = self.get_input_embeddings()(input_ids)
 
     def _vision_branch(inputs_embeds, pixel_values, input_ids):
+        # Updated to match transformers 4.57.1+ signature
+        # get_image_features now takes: (self, pixel_values, vision_feature_select_strategy, **kwargs)
         image_features = self.get_image_features(
            pixel_values=pixel_values,
-            vision_feature_layer=vision_feature_layer,
             vision_feature_select_strategy=vision_feature_select_strategy,
-            image_sizes=None,
         )
 
         vision_flat = image_features.view(-1, image_features.size(-1))
diff --git a/tensorrt_llm/_torch/auto_deploy/models/patches/mixtral.py b/tensorrt_llm/_torch/auto_deploy/models/patches/mixtral.py
index b759fe6495d..c370de2285a 100644
--- a/tensorrt_llm/_torch/auto_deploy/models/patches/mixtral.py
+++ b/tensorrt_llm/_torch/auto_deploy/models/patches/mixtral.py
@@ -7,19 +7,38 @@
 
 from ...export.interface import BaseExportPatch, ExportPatchRegistry
 
+# Import SiLUActivation for compatibility check
+try:
+    from transformers.activations import SiLUActivation
+
+    _SILU_TYPES = (nn.SiLU, SiLUActivation)
+except ImportError:
+    _SILU_TYPES = (nn.SiLU,)
+
+
+def _is_silu_activation(act_fn) -> bool:
+    """Check if activation function is SiLU or equivalent."""
+    return isinstance(act_fn, _SILU_TYPES)
+
 
 def _forward_moe(self: MixtralSparseMoeBlock, hidden_states: torch.Tensor):
     # check if we can apply the patch
-    use_original_forward = False
-    if not all(isinstance(expert.act_fn, nn.SiLU) for expert in self.experts):
-        use_original_forward = True
+    unsupported_reasons = []
+    if not all(_is_silu_activation(expert.act_fn) for expert in self.experts):
+        unsupported_reasons.append("expert activation is not SiLU")
     if any(getattr(mod, "bias", None) is not None for mod in self.experts.modules()):
-        use_original_forward = True
-
-    # rely on original forward instead
-    if use_original_forward:
-        return self._original_forward(hidden_states)
+        unsupported_reasons.append("expert modules have bias")
+
+    # Raise informative error for unsupported configurations
+    # (fallback to original forward is not export-compatible with transformers >= 4.57.1)
+    if unsupported_reasons:
+        raise NotImplementedError(
+            f"MixtralSparseMoeBlock forward patch does not support this model configuration: "
+            f"{', '.join(unsupported_reasons)}. "
+            f"The original transformers forward uses torch.nonzero() and tensor indexing "
+            f"which are not compatible with torch.export on meta tensors."
+        )
 
     batch_size, sequence_length, hidden_dim = hidden_states.shape
     if self.training and self.jitter_noise > 0:
diff --git a/tensorrt_llm/_torch/auto_deploy/models/patches/qwen3.py b/tensorrt_llm/_torch/auto_deploy/models/patches/qwen3.py
index 3870bc5bfd8..e5ed0c6af29 100644
--- a/tensorrt_llm/_torch/auto_deploy/models/patches/qwen3.py
+++ b/tensorrt_llm/_torch/auto_deploy/models/patches/qwen3.py
@@ -7,19 +7,38 @@
 
 from ...export.interface import BaseExportPatch, ExportPatchRegistry
 
+# Import SiLUActivation for compatibility check
+try:
+    from transformers.activations import SiLUActivation
+
+    _SILU_TYPES = (nn.SiLU, SiLUActivation)
+except ImportError:
+    _SILU_TYPES = (nn.SiLU,)
+
+
+def _is_silu_activation(act_fn) -> bool:
+    """Check if activation function is SiLU or equivalent."""
+    return isinstance(act_fn, _SILU_TYPES)
+
 
 def _forward_moe(self: Qwen3MoeSparseMoeBlock, hidden_states: torch.Tensor):
     # check if we can apply the patch
-    use_original_forward = False
-    if not all(isinstance(expert.act_fn, nn.SiLU) for expert in self.experts):
-        use_original_forward = True
+    unsupported_reasons = []
+    if not all(_is_silu_activation(expert.act_fn) for expert in self.experts):
+        unsupported_reasons.append("expert activation is not SiLU")
     if any(getattr(mod, "bias", None) is not None for mod in self.experts.modules()):
-        use_original_forward = True
-
-    # rely on original forward instead
-    if use_original_forward:
-        return self._original_forward(hidden_states)
+        unsupported_reasons.append("expert modules have bias")
+
+    # Raise informative error for unsupported configurations
+    # (fallback to original forward is not export-compatible with transformers >= 4.57.1)
+    if unsupported_reasons:
+        raise NotImplementedError(
+            f"Qwen3MoeSparseMoeBlock forward patch does not support this model configuration: "
+            f"{', '.join(unsupported_reasons)}. "
+            f"The original transformers forward uses torch.nonzero() and tensor indexing "
+            f"which are not compatible with torch.export on meta tensors."
+        )
 
     batch_size, sequence_length, hidden_dim = hidden_states.shape
     hidden_states = hidden_states.view(-1, hidden_dim)
diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_llama4_vlm_patch.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_llama4_vlm_patch.py
index a1879ed30a8..af28829e732 100644
--- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_llama4_vlm_patch.py
+++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_llama4_vlm_patch.py
@@ -1,4 +1,3 @@
-import pytest
 import torch
 from _model_test_utils import get_small_model_config
 from build_and_run_ad import ExperimentConfig
@@ -10,10 +9,6 @@
 
 
 def test_build_run_llama4_vlm():
-    pytest.skip(
-        "Skipping test_build_run_llm4_vlm because Llama4 is giving an error on upgrading transformers version to 4.57.1"
-        "https://nvbugspro.nvidia.com/bug/5732942"
-    )
     atol = 1e-3
     rtol = 1e-3
 
diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_build_small_single.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_build_small_single.py
index 4b2c75f29d9..79457fbfca5 100644
--- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_build_small_single.py
+++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_build_small_single.py
@@ -201,19 +201,6 @@ def _check_ad_config(experiment_config: ExperimentConfig, llm_args: LlmArgs):
     ],
 )
 def test_build_ad(model_hub_id: str, llm_extra_args: dict):
-    if (
-        model_hub_id == "mistralai/Mixtral-8x7B-Instruct-v0.1"
-        and llm_extra_args.get("mode") != "transformers"
-    ):
-        pytest.skip(
-            "Mixtral-8x7B-Instruct-v0.1 is giving an error on upgrading transformers version to 4.57.1"
-            "https://nvbugspro.nvidia.com/bug/5732942"
-        )
-    if model_hub_id == "Qwen/Qwen3-30B-A3B" and llm_extra_args.get("mode") != "transformers":
-        pytest.skip(
-            "Qwen3-30B-A3B is giving an error on upgrading transformers version to 4.57.1"
-            "https://nvbugspro.nvidia.com/bug/5732942"
-        )
     experiment_config = get_small_model_config(model_hub_id, **llm_extra_args)
     experiment_config["args"]["runtime"] = "demollm"  # Default runtime set to demollm
     experiment_config["args"]["world_size"] = 0  # Default world_size set to 0
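
Note (sketch, not part of the diff): MetaNonzeroPatch above reduces to a save/set/restore of one config flag. The standalone snippet below mirrors that pattern, assuming a torch build that exposes torch.fx.experimental._config.meta_nonzero_assume_all_nonzero (the flag the patch touches); the placeholder export call is illustrative only.

import torch.fx.experimental._config as fx_config

# Save the prior value (default False if the flag is absent), as _apply_patch does.
original = getattr(fx_config, "meta_nonzero_assume_all_nonzero", False)
fx_config.meta_nonzero_assume_all_nonzero = True  # allow nonzero() on meta tensors during tracing
try:
    ...  # run torch.export / graph capture here
finally:
    fx_config.meta_nonzero_assume_all_nonzero = original  # restore, mirroring _revert_patch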