Skip to content

Commit 91bb090

Browse files
authored
feat: Add support for Groot N1.5 model (#3736)
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent 40d4d41 commit 91bb090

File tree

8 files changed

+70
-217
lines changed

8 files changed

+70
-217
lines changed

docsrc/user_guide/mixed_precision.rst

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,18 @@
11
.. _mixed_precision:
22

33
Compile Mixed Precision models with Torch-TensorRT
4-
====================================
4+
===================================================
55
.. currentmodule:: torch_tensorrt.dynamo
66

77
.. automodule:: torch_tensorrt.dynamo
88
:members:
99
:undoc-members:
1010
:show-inheritance:
1111

12-
Consider the following Pytorch model which explicitly casts intermediate layer to run in FP16.
12+
Explicit Typing
13+
---------------
14+
15+
Consider the following PyTorch model, which explicitly casts an intermediate layer to run in FP16.
1316

1417
.. code-block:: python
1518
@@ -54,6 +57,7 @@ the compilation setting ``use_explicit_typing=True``. Compiling with this option
5457

5558
.. note:: If you enable ``use_explicit_typing=True``, only ``torch.float32`` is supported in ``enabled_precisions``.
5659

60+
5761
.. code-block:: python
5862
5963
inputs = [torch.randn((1, 10), dtype=torch.float32).cuda()]
@@ -62,7 +66,7 @@ the compilation setting ``use_explicit_typing=True``. Compiling with this option
6266
with torch_tensorrt.logging.debug():
6367
trt_gm = torch_tensorrt.dynamo.compile(ep,
6468
inputs=inputs,
65-
use_explicit_typing=True
69+
use_explicit_typing=True,
6670
debug=True)
6771
6872
# Debug log info
@@ -71,4 +75,36 @@ the compilation setting ``use_explicit_typing=True``. Compiling with this option
7175
# Name: __myl_ResMulSumAddCas_myl0_1, LayerType: kgen, Inputs: [ { Name: __mye127_dconst, Dimensions: [10,30], Format/Datatype: Half }, { Name: linear2/addmm_1_constant_0 _ linear2/addmm_1_add_broadcast_to_same_shape_lhs_broadcast_constantHalf, Dimensions: [1,30], Format/Datatype: Half }, { Name: __myln_k_arg__bb1_2, Dimensions: [1,10], Format/Datatype: Half }], Outputs: [ { Name: __myln_k_arg__bb1_3, Dimensions: [1,30], Format/Datatype: Float }], TacticName: __myl_ResMulSumAddCas_0x5a3b318b5a1c97b7d5110c0291481337, StreamId: 0, Metadata:
7276
# Name: __myl_ResMulSumAdd_myl0_2, LayerType: kgen, Inputs: [ { Name: __mye142_dconst, Dimensions: [30,40], Format/Datatype: Float }, { Name: linear3/addmm_2_constant_0 _ linear3/addmm_2_add_broadcast_to_same_shape_lhs_broadcast_constantFloat, Dimensions: [1,40], Format/Datatype: Float }, { Name: __myln_k_arg__bb1_3, Dimensions: [1,30], Format/Datatype: Float }], Outputs: [ { Name: output0, Dimensions: [1,40], Format/Datatype: Float }], TacticName: __myl_ResMulSumAdd_0x3fad91127c640fd6db771aa9cde67db0, StreamId: 0, Metadata:
7377
74-
Now the ``linear2`` layer runs in FP16 as shown in the above logs.
78+
Now the ``linear2`` layer runs in FP16 as shown in the above logs.
79+
80+
81+
82+
FP32 Accumulation
83+
-----------------
84+
85+
When ``use_fp32_acc=True`` is set, Torch-TensorRT will attempt to use FP32 accumulation for matmul layers, even if the input and output tensors are in FP16. This is particularly useful for models that are sensitive to numerical errors introduced by lower-precision accumulation.
86+
87+
.. important::
88+
89+
When enabling ``use_fp32_acc=True``, **explicit typing must be enabled** by setting ``use_explicit_typing=True``. Without ``use_explicit_typing=True``, the accumulation type may not be properly respected, and you may not see the intended numerical benefits.
90+
91+
.. code-block:: python
92+
93+
inputs = [torch.randn((1, 10), dtype=torch.float16).cuda()]
94+
mod = MyModule().eval().cuda()
95+
ep = torch.export.export(mod, tuple(inputs))
96+
with torch_tensorrt.logging.debug():
97+
trt_gm = torch_tensorrt.dynamo.compile(
98+
ep,
99+
inputs=inputs,
100+
use_fp32_acc=True,
101+
use_explicit_typing=True, # Explicit typing must be enabled
102+
debug=True
103+
)
104+
105+
# Debug log info
106+
# Layers:
107+
# Name: __myl_MulSumAddCas_myl0_0, LayerType: kgen, Inputs: [ ... ], Outputs: [ ... ], Format/Datatype: Half, Accumulation: Float
108+
# ...
109+
110+
For more information on these settings, see the explicit typing examples above.

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -420,7 +420,9 @@ def index_dtype_validator(
420420

421421

422422
@dynamo_tensorrt_converter(
423-
torch.ops.aten.index.Tensor, capability_validator=index_dtype_validator
423+
torch.ops.aten.index.Tensor,
424+
capability_validator=index_dtype_validator,
425+
supports_dynamic_shapes=True,
424426
)
425427
@enforce_tensor_types(
426428
{

py/torch_tensorrt/dynamo/conversion/converter_utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,15 @@
2424
from torch.fx.node import Argument, Target
2525
from torch.fx.passes.shape_prop import TensorMetadata
2626
from torch_tensorrt import _enums
27+
from torch_tensorrt._utils import is_tensorrt_version_supported
2728
from torch_tensorrt.dynamo._settings import CompilationSettings
2829
from torch_tensorrt.dynamo._SourceIR import SourceIR
2930
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
3031
from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
3132
ConverterRegistry,
3233
DynamoConverterImplSignature,
3334
)
34-
from torch_tensorrt._utils import is_tensorrt_version_supported
35+
3536
from ..types import Shape, TRTDataType, TRTLayer, TRTTensor
3637

3738
_LOGGER: logging.Logger = logging.getLogger(__name__)
@@ -72,6 +73,9 @@ def format_tensor_metadata(metadata: Union[Any, Sequence[Any]]) -> str:
7273
# If the provided data is a scalar, return it as is
7374
elif isinstance(metadata, (int, float, bool)):
7475
return f"{metadata}@Python-{type(metadata)}"
76+
# If the provided data is a SymInt, return it as is
77+
elif isinstance(metadata, (torch.SymInt)):
78+
return f"{metadata}@SymInt"
7579
# If the provided data is a sequence, recursively parse it
7680
elif isinstance(metadata, collections.abc.Sequence):
7781
formatted_str = "("

py/torch_tensorrt/dynamo/conversion/impl/matmul.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -48,17 +48,25 @@ def matrix_multiply(
4848
input, other = broadcast(
4949
ctx, input, other, f"{name}_input", f"{name}_other", preset_diff
5050
)
51-
if ctx.net.get_flag(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED):
52-
promoted_type = _enums.dtype._from(
53-
torch.promote_types(
54-
_enums.dtype._from(input.dtype).to(torch.dtype),
55-
_enums.dtype._from(other.dtype).to(torch.dtype),
56-
)
51+
if (
52+
ctx.net.get_flag(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED)
53+
and ctx.compilation_settings.use_fp32_acc
54+
):
55+
input = cast_trt_tensor(ctx, input, torch.float32, f"{name}_input_casted")
56+
other = cast_trt_tensor(ctx, other, torch.float32, f"{name}_other_casted")
57+
58+
matmul_layer = ctx.net.add_matrix_multiply(
59+
input, input_matrix_op, other, other_matrix_op
60+
)
61+
matmul_output = matmul_layer.get_output(0)
62+
63+
if (
64+
ctx.net.get_flag(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED)
65+
and ctx.compilation_settings.use_fp32_acc
66+
):
67+
matmul_output = cast_trt_tensor(
68+
ctx, matmul_output, torch.float16, f"{name}_output_casted"
5769
)
58-
trt_promoted_type = promoted_type.to(trt.DataType)
59-
input = cast_trt_tensor(ctx, input, trt_promoted_type, f"{name}_input_casted")
60-
other = cast_trt_tensor(ctx, other, trt_promoted_type, f"{name}_other_casted")
6170

62-
layer = ctx.net.add_matrix_multiply(input, input_matrix_op, other, other_matrix_op)
63-
set_layer_name(layer, target, name, source_ir)
64-
return layer.get_output(0)
71+
set_layer_name(matmul_layer, target, name, source_ir)
72+
return matmul_output

py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from torch_tensorrt.dynamo._settings import CompilationSettings
66
from torch_tensorrt.dynamo.utils import is_tegra_platform
77

8-
from .accumulate_fp32_matmul import accumulate_fp32_matmul
98
from .complex_graph_rewrite import complex_graph_detection
109
from .constant_folding import constant_fold
1110
from .fuse_prims_broadcast import fuse_prims_broadcast
@@ -24,7 +23,6 @@
2423
fuse_prims_broadcast,
2524
replace_max_pool_with_indices,
2625
remove_assert_nodes,
27-
accumulate_fp32_matmul,
2826
remove_num_users_is_0_nodes,
2927
complex_graph_detection,
3028
]

py/torch_tensorrt/dynamo/lowering/passes/accumulate_fp32_matmul.py

Lines changed: 0 additions & 115 deletions
This file was deleted.

py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -367,7 +367,8 @@ def compile(self) -> None:
367367
enabled_precisions=self.enabled_precisions,
368368
**self.additional_settings,
369369
)
370-
deallocate_module(self.original_model, delete_module=False)
370+
if self.additional_settings.get("offload_module_to_cpu", False):
371+
deallocate_module(self.original_model, delete_module=False)
371372
if self.enable_weight_streaming:
372373
self.set_weight_streaming_ctx(self.weight_streaming_budget)
373374

tests/py/dynamo/lowering/test_aten_lowering_passes.py

Lines changed: 0 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -156,87 +156,6 @@ def forward(self, x):
156156
torch._dynamo.reset()
157157

158158

159-
class TestFP32Accumulation(TestCase):
160-
def test_fp32_acc(self):
161-
class FP32Acc(torch.nn.Module):
162-
def forward(self, input, weight):
163-
out = torch.ops.aten.mm.default(input, weight)
164-
return out
165-
166-
inputs = [
167-
torch.rand((3, 4)).cuda(),
168-
torch.rand((4, 5)).cuda(),
169-
]
170-
171-
fx_graph = torch.fx.symbolic_trace(FP32Acc())
172-
expected_ops = {torch.ops.aten._to_copy.default, torch.ops.aten.mm.default}
173-
unexpected_ops = {}
174-
175-
unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
176-
fx_graph,
177-
inputs,
178-
expected_ops=expected_ops,
179-
unexpected_ops=unexpected_ops,
180-
min_block_size=1,
181-
use_fp32_acc=True,
182-
)
183-
184-
self.assertEqual(
185-
len(unexpected_ops_seen),
186-
0,
187-
f"The following unexpected ops were encountered: {unexpected_ops_seen}",
188-
)
189-
190-
self.assertEqual(
191-
len(expected_ops_unseen),
192-
0,
193-
f"The following expected ops were not encountered: {expected_ops_unseen}",
194-
)
195-
torch._dynamo.reset()
196-
197-
def test_fp32_acc_for_addmm(self):
198-
class FP32Acc(torch.nn.Module):
199-
def forward(self, input, mat1, mat2):
200-
out = torch.ops.aten.addmm.default(input, mat1, mat2, beta=20, alpha=2)
201-
return out
202-
203-
inputs = [
204-
torch.rand((3, 5)).cuda(),
205-
torch.rand((3, 4)).cuda(),
206-
torch.rand((4, 5)).cuda(),
207-
]
208-
209-
fx_graph = torch.fx.symbolic_trace(FP32Acc())
210-
expected_ops = {
211-
torch.ops.aten._to_copy.default,
212-
torch.ops.aten.mm.default,
213-
torch.ops.aten.add.Tensor,
214-
}
215-
unexpected_ops = {}
216-
217-
unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
218-
fx_graph,
219-
inputs,
220-
expected_ops=expected_ops,
221-
unexpected_ops=unexpected_ops,
222-
min_block_size=1,
223-
use_fp32_acc=True,
224-
)
225-
226-
self.assertEqual(
227-
len(unexpected_ops_seen),
228-
0,
229-
f"The following unexpected ops were encountered: {unexpected_ops_seen}",
230-
)
231-
232-
self.assertEqual(
233-
len(expected_ops_unseen),
234-
0,
235-
f"The following expected ops were not encountered: {expected_ops_unseen}",
236-
)
237-
torch._dynamo.reset()
238-
239-
240159
class TestComplexSubgraph(TestCase):
241160
def test_complex_subgraph(self):
242161
BATCH = 1

0 commit comments

Comments
 (0)