1 parent 3c9ab0d commit 4bc7ed7
modelopt/torch/_deploy/utils/torch_onnx.py
@@ -393,7 +393,9 @@ def get_onnx_bytes_and_metadata(
     # during inference.
     input_none_names = list(set(tree_spec_input.names) - set(input_names))
 
-    use_torch_autocast = not (is_fp4_quantized(model) or is_mxfp8_quantized(model))
+    use_torch_autocast = not (
+        is_fp4_quantized(model) or is_mxfp8_quantized(model) or weights_dtype == "fp32"
+    )
     autocast = torch.autocast("cuda") if use_torch_autocast else nullcontext()
 
     # Get output once (we export in inference mode - so also using inference mode here!)
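
For context, the pattern in the diff keeps a single call site by substituting contextlib.nullcontext() whenever autocast must stay off. Below is a minimal runnable sketch of that pattern, assuming stand-in implementations for is_fp4_quantized / is_mxfp8_quantized and a toy model; only the use_torch_autocast / autocast lines mirror the actual change.

from contextlib import nullcontext

import torch


def is_fp4_quantized(model: torch.nn.Module) -> bool:
    # Stand-in: the real helper inspects the model's quantizer state.
    return False


def is_mxfp8_quantized(model: torch.nn.Module) -> bool:
    # Stand-in: the real helper inspects the model's quantizer state.
    return False


model = torch.nn.Linear(4, 4)
weights_dtype = "fp32"  # after this commit, fp32 weights also disable autocast

use_torch_autocast = not (
    is_fp4_quantized(model) or is_mxfp8_quantized(model) or weights_dtype == "fp32"
)
# nullcontext() is a no-op context manager, so the export code can use one
# `with autocast:` block regardless of whether autocast is actually enabled.
autocast = torch.autocast("cuda") if use_torch_autocast else nullcontext()

with autocast, torch.inference_mode():
    out = model(torch.randn(1, 4))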