Skip to content

Commit 61bbd69

Browse files
authored
DLFW 26.01 changes to main (#4004)
1 parent ab59e43 commit 61bbd69

File tree

2 files changed

+34
-1
lines changed

2 files changed

+34
-1
lines changed

tests/py/dynamo/models/test_weight_stripped_engine.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import shutil
55
import unittest
66

7+
import tensorrt as trt
78
import torch
89
import torch_tensorrt as torch_trt
910
from torch.testing._internal.common_utils import TestCase
@@ -642,6 +643,10 @@ def test_refit_identical_engine_weights(self):
642643
not importlib.util.find_spec("torchvision"),
643644
"torchvision is not installed",
644645
)
646+
@unittest.skipIf(
647+
not hasattr(trt.SerializationFlag, "INCLUDE_REFIT"),
648+
"Multiple refit requires TensorRT >= 10.14 with INCLUDE_REFIT serialization flag",
649+
)
645650
def test_refit_weight_stripped_engine_multiple_times(self):
646651
pyt_model = models.resnet18(pretrained=True).eval().to("cuda")
647652
example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),)

tests/py/ts/api/test_classes.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,25 @@
77
from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import TorchTensorRTModule
88

99

10+
def is_blackwell():
11+
"""
12+
Check if running on NVIDIA Blackwell architecture (sm_90+).
13+
14+
Blackwell architecture adds input/output reformat layers in TensorRT engines.
15+
16+
Returns:
17+
bool: True if running on Blackwell (sm_90+), False otherwise
18+
"""
19+
if not torch.cuda.is_available():
20+
return False
21+
22+
device_properties = torch.cuda.get_device_properties(0)
23+
compute_capability = device_properties.major * 10 + device_properties.minor
24+
25+
# Blackwell is sm_90 and above
26+
return compute_capability >= 90
27+
28+
1029
@unittest.skipIf(
1130
not torchtrt.ENABLED_FEATURES.torchscript_frontend,
1231
"TorchScript Frontend is not available",
@@ -332,13 +351,22 @@ def test_get_layer_info(self):
332351

333352
import json
334353

354+
if is_blackwell():
355+
# blackwell has additional layers-
356+
# Layer 0: __mye88_myl0_0 ← Input reformat layer
357+
# Layer 1: aten__matmul(...) fc1 ← First matmul (fc1)
358+
# Layer 2: aten__matmul(...) fc2 ← Second matmul (fc2)
359+
# Layer 3: __mye90_myl0_3 ← Output reformat layer
360+
num_layers = 4
361+
else:
362+
num_layers = 2
335363
for trt_mod in (
336364
TestTorchTensorRTModule._get_trt_mod(),
337365
TestTorchTensorRTModule._get_trt_mod(via_ts=True),
338366
):
339367
trt_json = json.loads(trt_mod.get_layer_info())
340368
[self.assertTrue(k in trt_json.keys()) for k in ["Layers", "Bindings"]]
341-
self.assertTrue(len(trt_json["Layers"]) == 2)
369+
self.assertTrue(len(trt_json["Layers"]) == num_layers)
342370
self.assertTrue(len(trt_json["Bindings"]) == 2)
343371

344372

0 commit comments

Comments
 (0)