You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
feat!: Changing the default behavior for selecting the input type
BREAKING CHANGE: This commit changes the default behavior of
the compiler where if the user does not specify an input data
type explicity instead of using the enabled precision, now
the compiler will inspect the model provided to infer the
data type for the input that will not cause an error if
the model was run in torch. In practice this means
- If the weights are in FP32 for the first tensor calculation
then default input type is FP32
- If the weights are in FP16 for the first tensor calculation
then default input type is FP16
- etc.
If the data type cannot be determined the compiler will
default to FP32.
This calculation is done per input tensor so if one input
is inferred to use FP32 and another INT32 then the expected
types will be the same (FP32, INT32)
As was the same before if the user defines the data type
explicitly or provides an example tensor the data type
specified there will be respected
Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
Copy file name to clipboardExpand all lines: py/trtorch/Input.py
+36-8Lines changed: 36 additions & 8 deletions
Original file line number
Diff line number
Diff line change
@@ -30,7 +30,7 @@ class _ShapeMode(Enum):
30
30
31
31
shape_mode=None#: (trtorch.Input._ShapeMode): Is input statically or dynamically shaped
32
32
shape=None#: (Tuple or Dict): Either a single Tuple or a dict of tuples defining the input shape. Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form ``{ "min_shape": Tuple, "opt_shape": Tuple, "max_shape": Tuple }``
33
-
dtype=_types.dtype.float32#: The expected data type of the input tensor (default: trtorch.dtype.float32)
33
+
dtype=_types.dtype.unknown#: The expected data type of the input tensor (default: trtorch.dtype.float32)
34
34
_explicit_set_dtype=False
35
35
format=_types.TensorFormat.contiguous#: The expected format of the input tensor (default: trtorch.TensorFormat.NCHW)
0 commit comments