
 """MXFP8 quantization exporter."""

+import numpy as np
 import onnx
+from onnx import numpy_helper
+
+from modelopt.onnx.logging_config import logger
+from modelopt.onnx.quantization.graph_utils import get_tensor_producer_nodes
+from modelopt.onnx.quantization.qdq_utils import _cast_fp8, onnx_dtype_map
+from modelopt.onnx.quantization.quant_utils import compute_e8m0, get_amax
+from modelopt.onnx.utils import get_attribute, has_attribute

 from .base_exporter import ONNXQuantExporter

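+# E8M0 scales store a power-of-two factor as a biased 8-bit exponent (bias 127);
+# MXFP8 defaults to 32-element blocks along the last weight axis.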
+E8_M0_BIAS = 127
+DEFAULT_BLOCK_SIZE = 32
+DEFAULT_QUANT_AXIS = -1
+
+
+def _get_weight_dq_nodes(graph: onnx.GraphProto) -> list[onnx.NodeProto]:
+    """Get weight DequantizeLinear nodes from the graph."""
+    return [
+        node
+        for node in graph.node
+        if node.op_type == "TRT_MXFP8DequantizeLinear"
+        and any(".weight" in inp for inp in node.input)
+    ]
+
+
+def _get_quant_params(node: onnx.NodeProto) -> tuple[int, int]:
+    """Extract quantization axis and block size from a node."""
+    if has_attribute(node, "axis"):
+        quant_axis = int(get_attribute(node, "axis"))
+    else:
+        quant_axis = DEFAULT_QUANT_AXIS
+        logger.warning(
+            "axis attribute not found for MXFP8DequantizeLinear node. Setting axis to -1"
+        )
+
+    if has_attribute(node, "block_size"):
+        block_size = int(get_attribute(node, "block_size"))
+    else:
+        block_size = DEFAULT_BLOCK_SIZE
+        logger.warning(
+            "block_size attribute not found for MXFP8DequantizeLinear node. "
+            "Setting block_size to 32"
+        )
+
+    return quant_axis, block_size
+

-# TODO: Implement the MXFP8QuantExporter
 class MXFP8QuantExporter(ONNXQuantExporter):
     """Exporter for MXFP8 quantization."""

     @staticmethod
     def pre_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
         """Pre-processes the ONNX model for MXFP8 quantization."""
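+        # MXFP8 export needs no pre-processing; the model is returned unchanged.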
+        return onnx_model

     @staticmethod
     def compute_scales(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
-        """Computes the scales for the weights in the ONNX model for MXFP8 quantization."""
+        """Computes the e8m0 scales for weights in the ONNX model for MXFP8 quantization."""
+        logger.info("Computing MXFP8 scales for weights")
+        graph = onnx_model.graph
+        initializer_map = {init.name: init for init in graph.initializer}
+        tensor_producer_map = get_tensor_producer_nodes(graph)
+
+        for node in _get_weight_dq_nodes(graph):
+            weight_name = node.input[0]
+            logger.debug(f"Computing MXFP8 scale for weight {weight_name}")
+
+            weight = numpy_helper.to_array(initializer_map[weight_name])
+            quant_axis, block_size = _get_quant_params(node)
+
+            # Compute scales
+            amax = get_amax(weight, quant_axis, block_size)
+            se8m0_fp32 = compute_e8m0(amax, weight.shape, quant_axis, block_size)
+            se8m0 = se8m0_fp32.astype(np.uint8)
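+            # se8m0 holds one biased exponent per block (uint8); the scale the DQ node
+            # applies at runtime is 2**(se8m0 - E8_M0_BIAS).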
+
+            # Remove scale producer if it's a Constant node
+            scale_name = node.input[1]
+            scale_producer = tensor_producer_map[scale_name]
+            if scale_producer.op_type == "Constant":
+                graph.node.remove(scale_producer)
+
+            # Create and add new scale tensor
+            scale_name_new = scale_name.replace("Constant_output_0", "scale")
+            scale_tensor = onnx.numpy_helper.from_array(se8m0, scale_name_new)
+            graph.initializer.append(scale_tensor)
+            node.input[1] = scale_name_new
+
+        return onnx_model

     @staticmethod
     def compress_weights(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
-        """Compresses the weights in the ONNX model for MXFP8 quantization."""
+        """Compresses the weights in the ONNX model to FP8 format for MXFP8 quantization."""
+        logger.info("Compressing weights to MXFP8 format")
+        graph = onnx_model.graph
+        initializer_map = {init.name: init for init in graph.initializer}
+
+        for node in _get_weight_dq_nodes(graph):
+            weight_name = node.input[0]
+            scale_name = node.input[1]
+            logger.debug(f"Compressing weight {weight_name} to MXFP8")
+
+            weight = numpy_helper.to_array(initializer_map[weight_name])
+            quant_axis, block_size = _get_quant_params(node)
+
+            # Get scale and convert back to fp32 for computation
+            se8m0 = numpy_helper.to_array(initializer_map[scale_name])
+            se8m0_fp32 = se8m0.astype(np.float32)
+
+            # Expand block array so that it can be broadcasted with weight
+            se8m0_fp32_expanded = np.repeat(se8m0_fp32, block_size, axis=quant_axis)
+            scaled_weight = weight / np.exp2(se8m0_fp32_expanded - E8_M0_BIAS)
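+            # Dividing by 2**(se8m0 - 127) applies the per-block power-of-two scale,
+            # bringing each block into the representable range of FP8 E4M3.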
+
+            # Create FP8 weight tensor
+            weights_e4m3 = onnx.helper.make_tensor(
+                name=weight_name,
+                data_type=onnx_dtype_map["Float8"],
+                dims=[*scaled_weight.shape],
+                vals=_cast_fp8(scaled_weight).tobytes(),
+                raw=True,
+            )
+            initializer_map[weight_name].CopyFrom(weights_e4m3)
+            logger.debug(f"Converted {weight_name} to MXFP8")
+
+        return onnx_model

     @staticmethod
     def post_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
-        """Post-processes the ONNX model for MXFP8 quantization."""
+        """Post-processes the ONNX model for MXFP8 quantization.
+
+        Sets DQ output type to FP16 and updates GELU nodes to use tanh approximation.
+        """
+        logger.info("Post-processing MXFP8 quantized model")
+        graph = onnx_model.graph
+
+        # Set output type of DQ to FP16
+        for node in graph.node:
+            if node.op_type == "TRT_MXFP8DequantizeLinear":
+                for attr in node.attribute:
+                    if attr.name == "output_dtype":
+                        attr.i = onnx_dtype_map["Half"]
+
+        # Currently only tanh approximation is supported for Gelu
+        for node in graph.node:
+            if node.op_type == "Gelu":
+                for attr in node.attribute:
+                    if attr.name == "approximate":
+                        attr.s = b"tanh"
+                logger.debug(f"Updated GELU node {node.name} to use tanh approximation")
+
+        def is_fp32_cast(node: onnx.NodeProto) -> bool:
+            return node.op_type == "Cast" and any(
+                attr.name == "to" and attr.i == onnx.TensorProto.FLOAT for attr in node.attribute
+            )
+
+        # Remove Cast nodes after specific operators
+        nodes_to_remove = []
+        for node in graph.node:
+            if node.op_type in ["Transpose", "Reshape", "Sqrt", "Add", "Gelu"]:
+                child_nodes = [n for n in graph.node if node.output[0] in n.input]
+                if len(child_nodes) == 1 and is_fp32_cast(child_nodes[0]):
+                    cast_node = child_nodes[0]
+                    node.output.clear()
+                    node.output.extend(cast_node.output)
+                    nodes_to_remove.append(cast_node.name)
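+                    # The producer now writes directly to the Cast's former output name,
+                    # so downstream consumers stay connected once the Cast is removed.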
+
+        # Rebuild the node list without the redundant casts
+        new_nodes = [node for node in graph.node if node.name not in nodes_to_remove]
+        graph.ClearField("node")
+        graph.node.extend(new_nodes)
+
+        return onnx_model
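
The per-block math implemented by compute_scales and compress_weights can be sketched in plain numpy. This is a minimal illustration under the assumption that get_amax, compute_e8m0, and _cast_fp8 follow roughly this scheme; their exact rounding and saturation behavior may differ.

    import numpy as np

    E8M0_BIAS = 127
    E4M3_MAX = 448.0  # largest finite FP8 E4M3 magnitude
    block_size = 32

    w = np.random.randn(64).astype(np.float32)   # toy weight vector
    blocks = w.reshape(-1, block_size)           # one MX block per row

    # Per-block amax and a shared power-of-two exponent that keeps each block in E4M3 range
    amax = np.abs(blocks).max(axis=1, keepdims=True)
    e8m0 = np.clip(np.ceil(np.log2(amax / E4M3_MAX)) + E8M0_BIAS, 0, 255).astype(np.uint8)

    # Quantize: divide by the decoded scale, then cast to FP8 E4M3 (cast omitted here)
    scale = np.exp2(e8m0.astype(np.float32) - E8M0_BIAS)
    q = blocks / scale
    # Dequantize, which is what TRT_MXFP8DequantizeLinear does at runtime
    deq = q * scale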