|
3 | 3 | import warnings
|
4 | 4 | from copy import deepcopy
|
5 | 5 | from enum import Enum, auto
|
6 |
| -from typing import Any, Dict, Iterator, Optional, Set, Union |
| 6 | +from typing import Any, Dict, Iterator, Optional, Set, Tuple, Union |
7 | 7 |
|
8 | 8 | import numpy as np
|
9 | 9 | import torch
|
@@ -70,7 +70,9 @@ def __init__(
|
70 | 70 | strict: bool = True,
|
71 | 71 | prefer_deferred_runtime_asserts_over_guards: bool = False,
|
72 | 72 | weight_streaming_budget: Optional[int] = None,
|
73 |
| - enabled_precisions: Optional[Set[Union[torch.dtype, dtype]]] = None, |
| 73 | + enabled_precisions: Union[ |
| 74 | + Set[Union[torch.dtype, dtype]], Tuple[Union[torch.dtype, dtype]] |
| 75 | + ] = _defaults.ENABLED_PRECISIONS, |
74 | 76 | **kwargs: Any,
|
75 | 77 | ) -> None:
|
76 | 78 | """
|
@@ -127,6 +129,10 @@ def __init__(
|
127 | 129 | self.refit_state = RefitState()
|
128 | 130 | self.pytorch_model = _make_refit_change_trigger(pytorch_model, self.refit_state)
|
129 | 131 | self.original_model = pytorch_model
|
| 132 | + if pytorch_model.training: |
| 133 | + logger.warning( |
| 134 | + "The model may be in training mode, which may affect the performance of the compiled model!" |
| 135 | + ) |
130 | 136 | # Process settings
|
131 | 137 | self.gm: Any = None
|
132 | 138 | self.exp_program: Any = None
|
@@ -162,8 +168,6 @@ def __init__(
|
162 | 168 | "Weight stremaing budget is not set. Using auto weight streaming budget"
|
163 | 169 | )
|
164 | 170 | self.enabled_precisions = enabled_precisions
|
165 |
| - if self.enabled_precisions is None: |
166 |
| - self.enabled_precisions = _defaults.ENABLED_PRECISIONS |
167 | 171 |
|
168 | 172 | cls = self.__class__
|
169 | 173 | self.__class__ = type(
|
|
0 commit comments