Skip to content

Commit f713c0d

Browse files
committed
skipping test_refit_weight_stripped_engine_multiple_times if include_refit flag not there
1 parent d8c2cd0 commit f713c0d

File tree

2 files changed

+17
-9
lines changed

2 files changed

+17
-9
lines changed

tests/py/dynamo/models/test_weight_stripped_engine.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import shutil
55
import unittest
66

7+
import tensorrt as trt
78
import torch
89
import torch_tensorrt as torch_trt
910
from torch.testing._internal.common_utils import TestCase
@@ -642,6 +643,10 @@ def test_refit_identical_engine_weights(self):
642643
not importlib.util.find_spec("torchvision"),
643644
"torchvision is not installed",
644645
)
646+
@unittest.skipIf(
647+
not hasattr(trt.SerializationFlag, "INCLUDE_REFIT"),
648+
"Multiple refit requires TensorRT >= 10.14 with INCLUDE_REFIT serialization flag",
649+
)
645650
def test_refit_weight_stripped_engine_multiple_times(self):
646651
pyt_model = models.resnet18(pretrained=True).eval().to("cuda")
647652
example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),)

tests/py/ts/api/test_classes.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,20 +10,22 @@
1010
def is_blackwell():
1111
"""
1212
Check if running on NVIDIA Blackwell architecture (sm_90+).
13-
13+
1414
Blackwell architecture adds input/output reformat layers in TensorRT engines.
15-
15+
1616
Returns:
1717
bool: True if running on Blackwell (sm_90+), False otherwise
1818
"""
1919
if not torch.cuda.is_available():
2020
return False
21-
21+
2222
device_properties = torch.cuda.get_device_properties(0)
2323
compute_capability = device_properties.major * 10 + device_properties.minor
24-
24+
2525
# Blackwell is sm_90 and above
2626
return compute_capability >= 90
27+
28+
2729
@unittest.skipIf(
2830
not torchtrt.ENABLED_FEATURES.torchscript_frontend,
2931
"TorchScript Frontend is not available",
@@ -348,12 +350,13 @@ def test_get_layer_info(self):
348350
"""
349351

350352
import json
353+
351354
if is_blackwell():
352-
# blackwell has additional layers-
353-
#Layer 0: __mye88_myl0_0 ← Input reformat layer
354-
#Layer 1: aten__matmul(...) fc1 ← First matmul (fc1)
355-
#Layer 2: aten__matmul(...) fc2 ← Second matmul (fc2)
356-
#Layer 3: __mye90_myl0_3 ← Output reformat layer
355+
# blackwell has additional layers-
356+
# Layer 0: __mye88_myl0_0 ← Input reformat layer
357+
# Layer 1: aten__matmul(...) fc1 ← First matmul (fc1)
358+
# Layer 2: aten__matmul(...) fc2 ← Second matmul (fc2)
359+
# Layer 3: __mye90_myl0_3 ← Output reformat layer
357360
num_layers = 4
358361
else:
359362
num_layers = 2

0 commit comments

Comments
 (0)