Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion modelopt/torch/_deploy/utils/torch_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,9 @@ def get_onnx_bytes_and_metadata(
# during inference.
input_none_names = list(set(tree_spec_input.names) - set(input_names))

-use_torch_autocast = not (is_fp4_quantized(model) or is_mxfp8_quantized(model))
+use_torch_autocast = not (
+is_fp4_quantized(model) or is_mxfp8_quantized(model) or weights_dtype == "fp32"
+)
autocast = torch.autocast("cuda") if use_torch_autocast else nullcontext()

# Get output once (we export in inference mode - so also using inference mode here!)
Expand Down
Loading