Skip to content

Commit 9574fd0

Browse files
committed
Do not use autocast for mxfp8
Signed-off-by: ajrasane <[email protected]>
1 parent 6db6ffa commit 9574fd0

File tree

4 files changed

+33
-3
lines changed

4 files changed

+33
-3
lines changed

.github/workflows/gpu_tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ jobs:
6262
runs-on: linux-amd64-gpu-l4-latest-1
6363
timeout-minutes: 120
6464
container: &gpu_container
65-
image: nvcr.io/nvidia/pytorch:25.08-py3
65+
image: nvcr.io/nvidia/pytorch:25.06-py3
6666
env:
6767
GIT_DEPTH: 1000 # For correct version for tests/gpu/torch/quantization/plugins/test_megatron.py
6868
PIP_CONSTRAINT: "" # Disable pip constraint for upgrading packages

.gitlab/tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ example-trtllm:
5454

5555
example-onnx:
5656
extends: example-torch
57-
image: nvcr.io/nvidia/tensorrt:25.08-py3
57+
image: nvcr.io/nvidia/tensorrt:25.06-py3
5858
tags: [docker, linux, 2-gpu, sm>=89]
5959
parallel:
6060
matrix:

modelopt/onnx/export/mxfp8_exporter.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,4 +166,34 @@ def post_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
166166
attr.s = b"tanh"
167167
logger.debug(f"Updated GELU node {node.name} to use tanh approximation")
168168

169+
# Insert cast to fp16 after Sqrt nodes
170+
cast_nodes_to_insert = []
171+
for idx, node in enumerate(graph.node):
172+
if node.op_type == "Sqrt":
173+
sqrt_output = node.output[0]
174+
cast_output = f"{sqrt_output}_cast_fp16"
175+
176+
# Create Cast node
177+
cast_node = onnx.helper.make_node(
178+
"Cast",
179+
inputs=[sqrt_output],
180+
outputs=[cast_output],
181+
to=onnx_dtype_map["Half"],
182+
name=f"{node.name}_cast_fp16",
183+
)
184+
cast_nodes_to_insert.append((idx + 1, cast_node))
185+
186+
# Update consumers to use cast output
187+
for consumer in graph.node:
188+
if consumer == node:
189+
continue
190+
for i, inp in enumerate(consumer.input):
191+
if inp == sqrt_output:
192+
consumer.input[i] = cast_output
193+
194+
# Insert Cast nodes in reverse order to preserve indices
195+
for offset, (pos, cast_node) in enumerate(cast_nodes_to_insert):
196+
graph.node.insert(pos + offset, cast_node)
197+
logger.debug(f"Inserted Cast to FP16 after {cast_node.input[0]}")
198+
169199
return onnx_model

modelopt/torch/_deploy/utils/torch_onnx.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -580,7 +580,7 @@ def get_onnx_bytes_and_metadata(
580580
except StopIteration:
581581
param_dtype = torch.float32
582582
if weights_dtype in ["fp16", "bf16"] and param_dtype == torch.float32:
583-
if is_int4_quantized(model):
583+
if is_int4_quantized(model) or is_mxfp8_quantized(model):
584584
assert weights_dtype == "fp16", "BF16 + MXFP8/INT4 mixed precision is not supported yet"
585585
onnx_opt_graph = convert_float_to_float16(
586586
onnx_opt_graph,

0 commit comments

Comments (0)