
Commit 918d081

Integrate autocast for mxfp8

Signed-off-by: ajrasane <[email protected]>

1 parent 17d59a4

File tree (3 files changed, +2 −21 lines):

modelopt/onnx/export/mxfp8_exporter.py
modelopt/onnx/trt_utils.py
modelopt/torch/_deploy/utils/torch_onnx.py


modelopt/onnx/export/mxfp8_exporter.py

Lines changed: 0 additions & 20 deletions
@@ -166,24 +166,4 @@ def post_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
                     attr.s = b"tanh"
             logger.debug(f"Updated GELU node {node.name} to use tanh approximation")
 
-    def is_fp32_cast(node: onnx.NodeProto) -> bool:
-        return node.op_type == "Cast" and any(
-            attr.name == "to" and attr.i == onnx.TensorProto.FLOAT for attr in node.attribute
-        )
-
-    # Remove Cast nodes after specific operators
-    nodes_to_remove = []
-    for node in graph.node:
-        if node.op_type in ["Transpose", "Reshape", "Sqrt", "Add", "Gelu"]:
-            child_nodes = [n for n in graph.node if node.output[0] in n.input]
-            if len(child_nodes) == 1 and is_fp32_cast(child_nodes[0]):
-                cast_node = child_nodes[0]
-                node.output.clear()
-                node.output.extend(cast_node.output)
-                nodes_to_remove.append(cast_node.name)
-
-    # Remove unnecessary casts
-    new_nodes = [node for node in graph.node if node.name not in nodes_to_remove]
-    graph.node.extend(new_nodes)
-
     return onnx_model
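With autocast now handling precision conversions for MXFP8, the manual cast-folding pass above is removed. For reference, here is a standalone sketch of what that pass did: fold an FP32 Cast that is the sole consumer of certain ops back into its producer. The name fold_fp32_casts is ours, not the repo's, and the node-list rebuild below replaces the list in place (as shown in the diff, the removed code called graph.node.extend(new_nodes) without clearing first, which would have duplicated every node).

    import onnx


    def fold_fp32_casts(graph: onnx.GraphProto) -> None:
        """Sketch of the removed cleanup: drop FP32 Casts that directly follow
        Transpose/Reshape/Sqrt/Add/Gelu nodes, rewiring outputs around them."""

        def is_fp32_cast(node: onnx.NodeProto) -> bool:
            return node.op_type == "Cast" and any(
                attr.name == "to" and attr.i == onnx.TensorProto.FLOAT for attr in node.attribute
            )

        nodes_to_remove = []
        for node in graph.node:
            if node.op_type in ["Transpose", "Reshape", "Sqrt", "Add", "Gelu"]:
                child_nodes = [n for n in graph.node if node.output[0] in n.input]
                if len(child_nodes) == 1 and is_fp32_cast(child_nodes[0]):
                    cast_node = child_nodes[0]
                    # The producer takes over the Cast's output name, so
                    # downstream consumers keep resolving their inputs.
                    del node.output[:]
                    node.output.extend(cast_node.output)
                    nodes_to_remove.append(cast_node.name)

        # Rebuild the node list without the folded Casts (replace, don't extend).
        kept = [node for node in graph.node if node.name not in nodes_to_remove]
        del graph.node[:]
        graph.node.extend(kept)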

modelopt/onnx/trt_utils.py

Lines changed: 1 addition & 0 deletions
@@ -140,6 +140,7 @@ def _map_trt_to_onnx_type(trt_type: trt.DataType):
         trt.bool: onnx.TensorProto.BOOL,
         trt.fp8: onnx.TensorProto.FLOAT8E4M3FN,
         trt.fp4: onnx.TensorProto.FLOAT4E2M1,
+        trt.e8m0: onnx.TensorProto.UINT8,
     }
     try:
         return trt_to_onnx_dtype_mapping[trt_type]
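The one-line addition maps TensorRT's E8M0 type to UINT8 on the ONNX side; a plausible reading is that E8M0 block scales are pure 8-bit exponents, so their payload travels as unsigned bytes. A minimal sketch of the mapping as it stands after this commit (the standalone wrapper name and error handling are our assumptions, mirroring the visible try/lookup):

    import onnx
    import tensorrt as trt

    # Mapping excerpt after this commit; trt.e8m0 scales are carried as UINT8
    # since no native E8M0 tensor type is targeted on the ONNX side here.
    trt_to_onnx_dtype_mapping = {
        trt.bool: onnx.TensorProto.BOOL,
        trt.fp8: onnx.TensorProto.FLOAT8E4M3FN,
        trt.fp4: onnx.TensorProto.FLOAT4E2M1,
        trt.e8m0: onnx.TensorProto.UINT8,
    }


    def map_trt_to_onnx_type(trt_type: trt.DataType) -> int:
        # Hypothetical standalone version of _map_trt_to_onnx_type's lookup.
        try:
            return trt_to_onnx_dtype_mapping[trt_type]
        except KeyError as err:
            raise ValueError(f"Unsupported TensorRT type: {trt_type}") from err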

modelopt/torch/_deploy/utils/torch_onnx.py

Lines changed: 1 addition & 1 deletion
@@ -576,7 +576,7 @@ def get_onnx_bytes_and_metadata(
     except StopIteration:
         param_dtype = torch.float32
     if weights_dtype in ["fp16", "bf16"] and param_dtype == torch.float32:
-        if is_mxfp8_quantized(model) or is_int4_quantized(model):
+        if is_int4_quantized(model):
             assert weights_dtype == "fp16", "BF16 + MXFP8/INT4 mixed precision is not supported yet"
             onnx_opt_graph = convert_float_to_float16(
                 onnx_opt_graph,
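The effect of this change: MXFP8-quantized models no longer take the manual convert_float_to_float16 path, since autocast now owns that conversion; only INT4 models do. A hedged distillation of the revised condition (the function name is a stand-in; the predicate argument replaces the is_int4_quantized(model) helper visible in the diff):

    import torch


    def takes_manual_fp16_path(weights_dtype: str, param_dtype: torch.dtype,
                               int4_quantized: bool) -> bool:
        """Stand-in for the branch above: after this commit, only INT4 models
        are manually converted to float16; MXFP8 models rely on autocast."""
        if weights_dtype in ["fp16", "bf16"] and param_dtype == torch.float32:
            if int4_quantized:
                # Mirrors the assert in the diff: BF16 + INT4 is unsupported.
                assert weights_dtype == "fp16", "BF16 + INT4 mixed precision is not supported yet"
                return True
        return False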
