Skip to content

Commit f1e8c01

Browse files
cehongwangChen Fu
andauthored
Added warnings if the model is in training mode (#3676)
Co-authored-by: Chen Fu <[email protected]>
1 parent 85b8e51 commit f1e8c01

File tree

2 files changed

+12
-4
lines changed

2 files changed

+12
-4
lines changed

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -732,6 +732,10 @@ def compile_module(
732732
Returns:
733733
Compiled FX GraphModule
734734
"""
735+
if any(v.requires_grad for v in gm.state_dict().values()):
736+
logger.warning(
737+
"The model may be in training mode, which may affect the performance of the compiled model!"
738+
)
735739
dryrun_tracker = DryRunTracker()
736740
if sample_kwarg_inputs is None:
737741
sample_kwarg_inputs = {}

py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import warnings
44
from copy import deepcopy
55
from enum import Enum, auto
6-
from typing import Any, Dict, Iterator, Optional, Set, Union
6+
from typing import Any, Dict, Iterator, Optional, Set, Tuple, Union
77

88
import numpy as np
99
import torch
@@ -70,7 +70,9 @@ def __init__(
7070
strict: bool = True,
7171
prefer_deferred_runtime_asserts_over_guards: bool = False,
7272
weight_streaming_budget: Optional[int] = None,
73-
enabled_precisions: Optional[Set[Union[torch.dtype, dtype]]] = None,
73+
enabled_precisions: Union[
74+
Set[Union[torch.dtype, dtype]], Tuple[Union[torch.dtype, dtype]]
75+
] = _defaults.ENABLED_PRECISIONS,
7476
**kwargs: Any,
7577
) -> None:
7678
"""
@@ -127,6 +129,10 @@ def __init__(
127129
self.refit_state = RefitState()
128130
self.pytorch_model = _make_refit_change_trigger(pytorch_model, self.refit_state)
129131
self.original_model = pytorch_model
132+
if pytorch_model.training:
133+
logger.warning(
134+
"The model may be in training mode, which may affect the performance of the compiled model!"
135+
)
130136
# Process settings
131137
self.gm: Any = None
132138
self.exp_program: Any = None
@@ -162,8 +168,6 @@ def __init__(
162168
"Weight stremaing budget is not set. Using auto weight streaming budget"
163169
)
164170
self.enabled_precisions = enabled_precisions
165-
if self.enabled_precisions is None:
166-
self.enabled_precisions = _defaults.ENABLED_PRECISIONS
167171

168172
cls = self.__class__
169173
self.__class__ = type(

0 commit comments

Comments
 (0)