diff --git a/tests/py/dynamo/models/test_weight_stripped_engine.py b/tests/py/dynamo/models/test_weight_stripped_engine.py
index 65569a9a85..6bf1b58f71 100644
--- a/tests/py/dynamo/models/test_weight_stripped_engine.py
+++ b/tests/py/dynamo/models/test_weight_stripped_engine.py
@@ -4,6 +4,7 @@
 import shutil
 import unittest
 
+import tensorrt as trt
 import torch
 import torch_tensorrt as torch_trt
 from torch.testing._internal.common_utils import TestCase
@@ -642,6 +643,10 @@ def test_refit_identical_engine_weights(self):
         not importlib.util.find_spec("torchvision"),
         "torchvision is not installed",
     )
+    @unittest.skipIf(
+        not hasattr(trt.SerializationFlag, "INCLUDE_REFIT"),
+        "Multiple refit requires TensorRT >= 10.14 with INCLUDE_REFIT serialization flag",
+    )
     def test_refit_weight_stripped_engine_multiple_times(self):
         pyt_model = models.resnet18(pretrained=True).eval().to("cuda")
         example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),)
diff --git a/tests/py/ts/api/test_classes.py b/tests/py/ts/api/test_classes.py
index c51f90d721..796f57e046 100644
--- a/tests/py/ts/api/test_classes.py
+++ b/tests/py/ts/api/test_classes.py
@@ -7,6 +7,25 @@
 from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import TorchTensorRTModule
 
 
+def is_blackwell():
+    """
+    Check whether the current GPU reports compute capability >= 9.0.
+
+    On these newer architectures TensorRT emits extra input/output reformat layers.
+
+    Returns:
+        bool: True if compute capability is >= 9.0, False otherwise (or no CUDA)
+    """
+    if not torch.cuda.is_available():
+        return False
+
+    device_properties = torch.cuda.get_device_properties(0)
+    compute_capability = device_properties.major * 10 + device_properties.minor
+
+    # NOTE(review): >= 90 also matches Hopper (sm_90); Blackwell proper is sm_100+ -- confirm threshold
+    return compute_capability >= 90
+
+
 @unittest.skipIf(
     not torchtrt.ENABLED_FEATURES.torchscript_frontend,
     "TorchScript Frontend is not available",
@@ -332,13 +351,22 @@ def test_get_layer_info(self):
 
         import json
 
+        if is_blackwell():
+            # Newer architectures emit additional reformat layers:
+            # Layer 0: __mye88_myl0_0          <- input reformat layer
+            # Layer 1: aten__matmul(...) fc1   <- first matmul (fc1)
+            # Layer 2: aten__matmul(...) fc2   <- second matmul (fc2)
+            # Layer 3: __mye90_myl0_3          <- output reformat layer
+            num_layers = 4
+        else:
+            num_layers = 2
         for trt_mod in (
             TestTorchTensorRTModule._get_trt_mod(),
             TestTorchTensorRTModule._get_trt_mod(via_ts=True),
         ):
             trt_json = json.loads(trt_mod.get_layer_info())
             [self.assertTrue(k in trt_json.keys()) for k in ["Layers", "Bindings"]]
-            self.assertTrue(len(trt_json["Layers"]) == 2)
+            self.assertTrue(len(trt_json["Layers"]) == num_layers)
             self.assertTrue(len(trt_json["Bindings"]) == 2)