Skip to content

Commit 5f34cfd

Browse files
committed
feat: Add support for Groot N1.5 model
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent a93266a commit 5f34cfd

File tree

4 files changed

+22
-117
lines changed

4 files changed

+22
-117
lines changed

py/torch_tensorrt/dynamo/conversion/impl/matmul.py

Lines changed: 20 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -48,17 +48,25 @@ def matrix_multiply(
4848
input, other = broadcast(
4949
ctx, input, other, f"{name}_input", f"{name}_other", preset_diff
5050
)
51-
if ctx.net.get_flag(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED):
52-
promoted_type = _enums.dtype._from(
53-
torch.promote_types(
54-
_enums.dtype._from(input.dtype).to(torch.dtype),
55-
_enums.dtype._from(other.dtype).to(torch.dtype),
56-
)
51+
if (
52+
ctx.net.get_flag(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED)
53+
and ctx.compilation_settings.use_fp32_acc
54+
):
55+
input = cast_trt_tensor(ctx, input, torch.float32, f"{name}_input_casted")
56+
other = cast_trt_tensor(ctx, other, torch.float32, f"{name}_other_casted")
57+
58+
matmul_layer = ctx.net.add_matrix_multiply(
59+
input, input_matrix_op, other, other_matrix_op
60+
)
61+
matmul_output = matmul_layer.get_output(0)
62+
63+
if (
64+
ctx.net.get_flag(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED)
65+
and ctx.compilation_settings.use_fp32_acc
66+
):
67+
matmul_output = cast_trt_tensor(
68+
ctx, matmul_output, torch.float16, f"{name}_output_casted"
5769
)
58-
trt_promoted_type = promoted_type.to(trt.DataType)
59-
input = cast_trt_tensor(ctx, input, trt_promoted_type, f"{name}_input_casted")
60-
other = cast_trt_tensor(ctx, other, trt_promoted_type, f"{name}_other_casted")
6170

62-
layer = ctx.net.add_matrix_multiply(input, input_matrix_op, other, other_matrix_op)
63-
set_layer_name(layer, target, name, source_ir)
64-
return layer.get_output(0)
71+
set_layer_name(matmul_layer, target, name, source_ir)
72+
return matmul_output

py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,6 @@
55
from torch_tensorrt.dynamo._settings import CompilationSettings
66
from torch_tensorrt.dynamo.utils import is_tegra_platform
77

8-
from .accumulate_fp32_matmul import accumulate_fp32_matmul
98
from .complex_graph_rewrite import complex_graph_detection
109
from .constant_folding import constant_fold
1110
from .fuse_distributed_ops import fuse_distributed_ops
@@ -25,7 +24,6 @@
2524
fuse_prims_broadcast,
2625
replace_max_pool_with_indices,
2726
remove_assert_nodes,
28-
accumulate_fp32_matmul,
2927
remove_num_users_is_0_nodes,
3028
complex_graph_detection,
3129
]

py/torch_tensorrt/dynamo/lowering/passes/accumulate_fp32_matmul.py

Lines changed: 0 additions & 102 deletions
This file was deleted.

py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -368,7 +368,8 @@ def compile(self) -> None:
368368
enabled_precisions=self.enabled_precisions,
369369
**self.additional_settings,
370370
)
371-
deallocate_module(self.original_model, delete_module=False)
371+
if self.additional_settings.get("offload_module_to_cpu", False):
372+
deallocate_module(self.original_model, delete_module=False)
372373
if self.enable_weight_streaming:
373374
self.set_weight_streaming_ctx(self.weight_streaming_budget)
374375

0 commit comments

Comments (0)