|
7 | 7 | from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import TorchTensorRTModule |
8 | 8 |
|
9 | 9 |
|
def is_blackwell(device: int = 0):
    """
    Check whether the given CUDA device is NVIDIA Blackwell architecture.

    Blackwell GPUs report CUDA compute capability 10.x (B100/B200) or
    12.x (RTX 50 series). Note that sm_90 / CC 9.0 is Hopper (H100/H200),
    NOT Blackwell, so the threshold here is 100 (major * 10 + minor),
    not 90 — a >= 90 check would wrongly classify Hopper as Blackwell.

    Blackwell architecture adds input/output reformat layers in TensorRT
    engines, which changes the expected layer counts in these tests.

    Args:
        device: CUDA device index to query. Defaults to 0.

    Returns:
        bool: True if the device is Blackwell (compute capability >= 10.0),
        False otherwise (including when CUDA is unavailable).
    """
    if not torch.cuda.is_available():
        # No GPU visible at all -> cannot be Blackwell.
        return False

    device_properties = torch.cuda.get_device_properties(device)
    # Encode CC as major*10 + minor (e.g. 9.0 -> 90, 10.0 -> 100, 12.0 -> 120).
    compute_capability = device_properties.major * 10 + device_properties.minor

    # Blackwell starts at CC 10.0; Hopper (CC 9.0) must be excluded.
    return compute_capability >= 100
10 | 29 | @unittest.skipIf( |
11 | 30 | not torchtrt.ENABLED_FEATURES.torchscript_frontend, |
12 | 31 | "TorchScript Frontend is not available", |
@@ -332,13 +351,22 @@ def test_get_layer_info(self): |
332 | 351 |
|
333 | 352 | import json |
334 | 353 |
|
| 354 | + if is_blackwell(): |
| 355 | + # blackwell has additional layers- |
| 356 | + # Layer 0: __mye88_myl0_0 ← Input reformat layer |
| 357 | + # Layer 1: aten__matmul(...) fc1 ← First matmul (fc1) |
| 358 | + # Layer 2: aten__matmul(...) fc2 ← Second matmul (fc2) |
| 359 | + # Layer 3: __mye90_myl0_3 ← Output reformat layer |
| 360 | + num_layers = 4 |
| 361 | + else: |
| 362 | + num_layers = 2 |
335 | 363 | for trt_mod in ( |
336 | 364 | TestTorchTensorRTModule._get_trt_mod(), |
337 | 365 | TestTorchTensorRTModule._get_trt_mod(via_ts=True), |
338 | 366 | ): |
339 | 367 | trt_json = json.loads(trt_mod.get_layer_info()) |
340 | 368 | [self.assertTrue(k in trt_json.keys()) for k in ["Layers", "Bindings"]] |
341 | | - self.assertTrue(len(trt_json["Layers"]) == 2) |
| 369 | + self.assertTrue(len(trt_json["Layers"]) == num_layers) |
342 | 370 | self.assertTrue(len(trt_json["Bindings"]) == 2) |
343 | 371 |
|
344 | 372 |
|
|
0 commit comments