
Commit 4716131

Avoid autocast at onnx export if fp32 model is desired (#304)
Signed-off-by: Riyad Islam <[email protected]>
1 parent: 85b309f

File tree

1 file changed: +3 −1 lines


modelopt/torch/_deploy/utils/torch_onnx.py (3 additions, 1 deletion)

```diff
@@ -393,7 +393,9 @@ def get_onnx_bytes_and_metadata(
     # during inference.
     input_none_names = list(set(tree_spec_input.names) - set(input_names))
 
-    use_torch_autocast = not (is_fp4_quantized(model) or is_mxfp8_quantized(model))
+    use_torch_autocast = not (
+        is_fp4_quantized(model) or is_mxfp8_quantized(model) or weights_dtype == "fp32"
+    )
     autocast = torch.autocast("cuda") if use_torch_autocast else nullcontext()
 
     # Get output once (we export in inference mode - so also using inference mode here!)
```
