diff --git a/CHANGELOG.rst b/CHANGELOG.rst
index 0d42a8ba2..ed0b9699c 100755
--- a/CHANGELOG.rst
+++ b/CHANGELOG.rst
@@ -20,12 +20,14 @@ Model Optimizer Changelog (Linux)
 **Bug Fixes**

 - Fix a bug in FastNAS pruning (computer vision models) where the model parameters were sorted twice messing up the ordering.
+- Fix Q/DQ/Cast node placement around custom-op tensors that must remain in FP32 in the ONNX quantization workflow.

 **New Features**

 - Add MoE (e.g. Qwen3-30B-A3B, gpt-oss-20b) pruning support for ``num_moe_experts``, ``moe_ffn_hidden_size`` and ``moe_shared_expert_intermediate_size`` parameters in Minitron pruning (``mcore_minitron``).
 - Add ``specdec_bench`` example to benchmark speculative decoding performance. See `examples/specdec_bench/README.md `_ for more details.
 - Add FP8/NVFP4 KV cache quantization support for Megatron Core models.
+- Add ``trt_plugins_precision`` flag to ONNX autocast to indicate the precision of custom ops, mirroring the flag that already exists in the quantization workflow.

 0.39 (2025-11-11)
 ^^^^^^^^^^^^^^^^^
diff --git a/modelopt/onnx/autocast/convert.py b/modelopt/onnx/autocast/convert.py
index b815c3406..4328c9fc2 100644
--- a/modelopt/onnx/autocast/convert.py
+++ b/modelopt/onnx/autocast/convert.py
@@ -194,6 +194,7 @@ def convert_to_f16(
     low_precision_type: str = "fp16",
     keep_io_types: bool = True,
     op_block_list: list[str] = [],
+    tensor_block_dict: dict[str, dict[str, list[int]]] = {},
     trt_plugins: list[str] | None = [],
 ) -> onnx.ModelProto:
     """Convert model to mixed precision, using PrecisionConverter.
@@ -204,8 +205,8 @@ def convert_to_f16(
         model: ONNX model to convert.
         low_precision_type: Target precision to reduce to ('fp16' or 'bf16').
         keep_io_types: Whether to preserve input/output types.
-        disable_shape_infer: Whether to disable shape inference.
         op_block_list: List of operation types that should remain in FP32.
+        tensor_block_dict: Dictionary mapping operation types to the indices of I/O tensors that should remain in FP32.
         trt_plugins: List of TensorRT plugin library paths in .so format (compiled shared library).
     """
     assert low_precision_type in ["fp16", "bf16"], "low_precision_type must be either fp16 or bf16"
@@ -235,6 +236,7 @@ def convert_to_f16(
         keep_io_types=keep_io_types,
         low_precision_type=low_precision_type,
         custom_ops=sanitizer.custom_ops,
+        tensor_block_dict=tensor_block_dict,
     )
     high_precision_nodes = [node.name for node in model.graph.node if node.op_type in op_block_list]
     low_precision_nodes = [
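The new ``tensor_block_dict`` argument maps an operation type to the indices of its tensors that must stay in FP32; the converter currently reads the ``"inp"`` entry. A minimal usage sketch of the autocast entry point; the op type, indices, and file paths below are illustrative placeholders, not part of the change:

    import onnx

    from modelopt.onnx.autocast.convert import convert_to_f16

    model = onnx.load("model_with_plugin.onnx")  # placeholder model path
    converted = convert_to_f16(
        model,
        low_precision_type="fp16",
        keep_io_types=True,
        # Keep inputs 0 and 2 of every "MyCustomPlugin" node in FP32; the op type and
        # indices are illustrative, "inp" is the key the converter reads.
        tensor_block_dict={"MyCustomPlugin": {"inp": [0, 2]}},
        trt_plugins=["./libmy_plugin.so"],  # placeholder plugin library
    )
    onnx.save(converted, "model_fp16.onnx")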
""" self.model = deepcopy(model) self.value_info_map = value_info_map @@ -148,6 +153,9 @@ def __init__( ) ) + # Custom mapping of op types to indices of inputs that should not be converted to low precision + self.skip_inputs_map = self._create_skip_inputs_mapping(tensor_block_dict) + def convert( self, high_precision_nodes: list[str], @@ -211,7 +219,8 @@ def convert( # For the low precision nodes that take a FP32 input, we don't exclude it from # casting up so that the input can be converted to FP32 as expected. exclude_consumers = list( - set(low_precision_nodes) - {fp32_input_to_low_precision_node[tensor_name].name} + set(low_precision_nodes) + - {n.name for n in fp32_input_to_low_precision_node[tensor_name]} ) self._add_cast( tensor_name, @@ -467,12 +476,14 @@ def _filter_unsupported_op_types( return high_precision_nodes, low_precision_nodes def _get_tensors_to_cast( - self, low_precision_nodes: list[str] - ) -> tuple[list[str], list[str], dict[str, onnx.NodeProto]]: + self, + low_precision_nodes: list[str], + high_precision_tensors: dict[str, dict[str, list[int]]] = {}, + ) -> tuple[list[str], list[str], dict[str, list[onnx.NodeProto]]]: cast_to_fp16 = [] # Tensors to cast down to FP16 cast_to_fp32 = [] # Tensors to cast up to FP32 # Keep track of the low precision nodes that take a FP32 input. - fp32_input_to_low_precision_node = {} + fp32_input_to_low_precision_node = defaultdict(list) # Get tensors for FP16 nodes for node in self.model.graph.node: @@ -481,7 +492,7 @@ def _get_tensors_to_cast( for input in node.input: if self._should_skip_low_precision_input_conversion(node, input): cast_to_fp32.append(input) - fp32_input_to_low_precision_node[input] = node + fp32_input_to_low_precision_node[input].append(node) else: cast_to_fp16.append(input) @@ -536,7 +547,7 @@ def _convert_initializers( low_precision_nodes: List of node names that should use low precision initializers. high_precision_nodes: List of node names that should use high precision initializers. """ - # 1. Compute a mapping from initiailizers to high precision nodes & low precision nodes that use them. + # 1. Compute a mapping from initializers to high precision nodes & low precision nodes that use them. low_precision_nodes_set: set[str] = set(low_precision_nodes) high_precision_nodes_set: set[str] = set(high_precision_nodes) initializer_to_nodes: dict[str, InitializerConsumerTracker] = defaultdict( @@ -888,7 +899,7 @@ def _add_cast( ) if tensor_to_consumers is None: - utils.get_consumer_nodes(self.model, tensor_name) + consumer_nodes = utils.get_consumer_nodes(self.model, tensor_name) else: consumer_nodes = tensor_to_consumers.get(tensor_name, []) consumer_nodes = [n for n in consumer_nodes if n.name not in exclude_consumers] @@ -1272,13 +1283,9 @@ def _sanitize_model(self): graph_sanitizer.sanitize() self.model = graph_sanitizer.model - def _should_skip_low_precision_input_conversion( - self, node: onnx.NodeProto, input_name: str - ) -> bool: - """Check if the input should be skipped for low precision conversion. - - This is used for nodes that have inputs that MUST remain in FP32. 
- """ + def _create_skip_inputs_mapping(self, tensor_block_dict: dict[str, dict[str, list[int]]] = {}): + """Create mapping of op types to indices of inputs that should not be converted to low precision.""" + skip_inputs_map = {} match self.low_precision_type.str_short: case "fp16": skip_inputs_map = SKIP_LOW_PRECISION_MAPPING_FP16 @@ -1287,12 +1294,27 @@ def _should_skip_low_precision_input_conversion( case _: raise ValueError(f"Unsupported low precision type: {self.low_precision_type}") - if node.op_type in skip_inputs_map: + # Update mapping with user-defined information + for op, tensor_map in tensor_block_dict.items(): + high_precision_tensor = tensor_map.get("inp", []) + if high_precision_tensor: + skip_inputs_map.update({op: set(high_precision_tensor)}) + + return skip_inputs_map + + def _should_skip_low_precision_input_conversion( + self, node: onnx.NodeProto, input_name: str + ) -> bool: + """Check if the input should be skipped for low precision conversion. + + This is used for nodes that have inputs that MUST remain in FP32. + """ + if node.op_type in self.skip_inputs_map: # Figure out the index of the input in the node input inputs_lst = list(node.input) if input_name not in inputs_lst: raise ValueError(f"Input {input_name} not found in node {node.name}.") input_index = inputs_lst.index(input_name) # Check if we should skip this input for low precision conversion - return input_index in skip_inputs_map[node.op_type] + return input_index in self.skip_inputs_map[node.op_type] return False diff --git a/modelopt/onnx/quantization/fp8.py b/modelopt/onnx/quantization/fp8.py index e1f092c12..bca898b0c 100755 --- a/modelopt/onnx/quantization/fp8.py +++ b/modelopt/onnx/quantization/fp8.py @@ -169,6 +169,7 @@ def quantize( op_types_to_quantize: list[str] | None = None, op_types_to_exclude: list[str] | None = None, op_types_to_exclude_fp16: list[str] | None = None, + custom_ops_to_cast_fp32: dict | None = None, nodes_to_quantize: list[str] | None = None, nodes_to_exclude: list[str] | None = None, use_external_data_format: bool = False, @@ -324,6 +325,7 @@ def quantize( onnx_model, keep_io_types=not direct_io_types, op_block_list=op_types_to_exclude_fp16 or [], + tensor_block_dict=custom_ops_to_cast_fp32 or {}, low_precision_type=high_precision_dtype, trt_plugins=trt_extra_plugin_lib_paths, ) diff --git a/modelopt/onnx/quantization/int8.py b/modelopt/onnx/quantization/int8.py index 01e2cd3da..01929667c 100755 --- a/modelopt/onnx/quantization/int8.py +++ b/modelopt/onnx/quantization/int8.py @@ -120,6 +120,7 @@ def quantize( op_types_to_quantize: list[str] | None = None, op_types_to_exclude: list[str] | None = None, op_types_to_exclude_fp16: list[str] | None = None, + custom_ops_to_cast_fp32: dict | None = None, nodes_to_quantize: list[str] | None = None, nodes_to_exclude: list[str] | None = None, use_external_data_format: bool = False, @@ -285,6 +286,7 @@ def quantize( onnx_model, keep_io_types=not direct_io_types, op_block_list=op_types_to_exclude_fp16 or [], + tensor_block_dict=custom_ops_to_cast_fp32 or {}, low_precision_type=high_precision_dtype, trt_plugins=trt_extra_plugin_lib_paths, ) diff --git a/modelopt/onnx/quantization/qdq_utils.py b/modelopt/onnx/quantization/qdq_utils.py index 58e6b436c..66c613a6c 100644 --- a/modelopt/onnx/quantization/qdq_utils.py +++ b/modelopt/onnx/quantization/qdq_utils.py @@ -872,22 +872,32 @@ def remove_input_dq_and_output_q( ) # Only remove DQs from the inputs of custom ops - if consumers[0].op_type not in quantizable_custom_ops: + has_cast = 
consumers[0].op_type == "Cast" + consumers_2 = tensor_consumers[consumers[0].output[0]] if has_cast else consumers + if consumers_2[0].op_type not in quantizable_custom_ops: continue - # Rewire graph to connect Q with the node after DQ (skip DQ) - for consumer in consumers: - for cons_idx, cons_inp in enumerate(consumer.input): - if cons_inp == node.output[0]: - # If the input tensor is meant to be quantized, delete DQ. Otherwise, delete both Q/DQ. - if cons_idx in quantizable_custom_ops[consumer.op_type]["inp"]: - consumer.input[cons_idx] = q_node.output[0] - else: - q_node_prev = tensor_producers.get(q_node.input[0], None) - consumer.input[cons_idx] = ( - q_node_prev.output[0] if q_node_prev else q_node.input[0] - ) - break + if has_cast: + # Assume that this input tensor is not meant to be quantized as there's a Cast node between DQ + # and the custom op. Keep the Cast node and delete both Q/DQ nodes. + q_node_prev = tensor_producers.get(q_node.input[0], None) + consumers[0].input[0] = ( + q_node_prev.output[0] if q_node_prev else q_node.input[0] + ) + else: + # Rewire graph to connect Q with the node after DQ (skip DQ) + for consumer in consumers: + for cons_idx, cons_inp in enumerate(consumer.input): + if cons_inp == node.output[0]: + # If the input tensor is meant to be quantized, delete DQ. Otherwise, delete both Q/DQ. + if cons_idx in quantizable_custom_ops[consumer.op_type]["inp"]: + consumer.input[cons_idx] = q_node.output[0] + else: + q_node_prev = tensor_producers.get(q_node.input[0], None) + consumer.input[cons_idx] = ( + q_node_prev.output[0] if q_node_prev else q_node.input[0] + ) + break # Track DequantizeLinear node indices for cleanup dq_indices.append(node_idx) @@ -944,6 +954,11 @@ def remove_input_dq_and_output_q( f" {len(dq_indices)} DQ node{'' if len(dq_indices) == 1 else 's'}" ) + # Cleanup graph to remove any dangling Q/DQ nodes + graph = gs.import_onnx(onnx_model) + graph.cleanup() + onnx_model = gs.export_onnx(graph) + # TODO: remove manual ir_version change once ORT supports ir_version 11 onnx_model.ir_version = 10 diff --git a/modelopt/onnx/quantization/quantize.py b/modelopt/onnx/quantization/quantize.py index 6349c128e..800124646 100755 --- a/modelopt/onnx/quantization/quantize.py +++ b/modelopt/onnx/quantization/quantize.py @@ -430,16 +430,6 @@ def quantize( ) trt_plugins = update_trt_ep_support(calibration_eps, has_dds_op, has_custom_op, trt_plugins) # type: ignore[arg-type] - # Update list with op types to exclude from FP16/BF16 conversion - op_types_to_exclude_fp16 = list( - dict.fromkeys((op_types_to_exclude_fp16 or []) + list(custom_ops_to_cast_fp32.keys())) - ) - if high_precision_dtype == "fp32" and op_types_to_exclude_fp16: - logger.warning( - "Nodes were detected for exclusion from FP16/BF16 conversion, but 'high_precision_dtype' is set to FP32. " - "Since the model won't be converted to a lower precision, this flag is void." - ) - # Use random scales if calibration data is not supplied if calibration_data is None: calibration_data_reader = RandomDataProvider(onnx_path, calibration_shapes) @@ -485,6 +475,7 @@ def quantize( op_types_to_quantize=op_types_to_quantize, op_types_to_exclude=op_types_to_exclude, op_types_to_exclude_fp16=op_types_to_exclude_fp16, + custom_ops_to_cast_fp32=custom_ops_to_cast_fp32, nodes_to_quantize=nodes_to_quantize, nodes_to_exclude=nodes_to_exclude, use_external_data_format=use_external_data_format,