diff --git a/modelopt/torch/_deploy/utils/torch_onnx.py b/modelopt/torch/_deploy/utils/torch_onnx.py
index ac617f724..86267b33a 100644
--- a/modelopt/torch/_deploy/utils/torch_onnx.py
+++ b/modelopt/torch/_deploy/utils/torch_onnx.py
@@ -393,7 +393,9 @@ def get_onnx_bytes_and_metadata(
     # during inference.
     input_none_names = list(set(tree_spec_input.names) - set(input_names))
 
-    use_torch_autocast = not (is_fp4_quantized(model) or is_mxfp8_quantized(model))
+    use_torch_autocast = not (
+        is_fp4_quantized(model) or is_mxfp8_quantized(model) or weights_dtype == "fp32"
+    )
     autocast = torch.autocast("cuda") if use_torch_autocast else nullcontext()
 
     # Get output once (we export in inference mode - so also using inference mode here!)
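
For context, here is a minimal, self-contained sketch of the gating pattern this hunk changes: `torch.autocast("cuda")` is used around the export-time forward pass unless the model is FP4/MXFP8-quantized or (with this change) the requested weights dtype is `"fp32"`, in which case a `nullcontext()` keeps the computation at full precision. The helper name `export_autocast_context` and the stubbed predicates are illustrative assumptions, not the modelopt API; in modelopt the real `is_fp4_quantized` / `is_mxfp8_quantized` checks inspect the model's quantizer modules.

```python
from contextlib import nullcontext

import torch


# Stubbed predicates: assumptions standing in for modelopt's real
# is_fp4_quantized / is_mxfp8_quantized checks.
def is_fp4_quantized(model: torch.nn.Module) -> bool:
    return False


def is_mxfp8_quantized(model: torch.nn.Module) -> bool:
    return False


def export_autocast_context(model: torch.nn.Module, weights_dtype: str):
    """Pick the context manager wrapped around the export-time forward pass.

    Autocast is skipped for FP4/MXFP8-quantized models, and (with this
    change) also when weights_dtype == "fp32", so a full-precision export
    is not silently downcast by autocast.
    """
    use_torch_autocast = not (
        is_fp4_quantized(model) or is_mxfp8_quantized(model) or weights_dtype == "fp32"
    )
    return torch.autocast("cuda") if use_torch_autocast else nullcontext()


if __name__ == "__main__":
    model = torch.nn.Linear(4, 4)
    # An fp32 export now takes the nullcontext branch instead of autocast.
    with export_autocast_context(model, weights_dtype="fp32"), torch.inference_mode():
        _ = model(torch.zeros(1, 4))
```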