We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 468a6e6 commit 3d65e2cCopy full SHA for 3d65e2c
py/torch_tensorrt/dynamo/conversion/converter_utils.py
@@ -516,7 +516,8 @@ def get_trt_tensor(
516
# If the input is 64-bit, cast it to 32-bit for TRT freezing
517
if isinstance(input_val, torch.Tensor) and ctx.compilation_settings.truncate_double:
518
if input_val.dtype == torch.float64:
519
- input_val = input_val.to(torch.float32)
+ with unset_fake_temporarily():
520
+ input_val = input_val.to(torch.float32)
521
elif isinstance(input_val, np.ndarray) and ctx.compilation_settings.truncate_double:
522
if input_val.dtype == np.float64:
523
input_val = input_val.astype(np.float32)
0 commit comments