2 changes: 2 additions & 0 deletions CHANGELOG.rst
@@ -20,12 +20,14 @@ Model Optimizer Changelog (Linux)
**Bug Fixes**

- Fix a bug in FastNAS pruning (computer vision models) where the model parameters were sorted twice, messing up the ordering.
- Fix Q/DQ/Cast node placement around 'FP32 required' tensors in custom ops in the ONNX quantization workflow.

**New Features**

- Add MoE (e.g. Qwen3-30B-A3B, gpt-oss-20b) pruning support for ``num_moe_experts``, ``moe_ffn_hidden_size`` and ``moe_shared_expert_intermediate_size`` parameters in Minitron pruning (``mcore_minitron``).
- Add ``specdec_bench`` example to benchmark speculative decoding performance. See `examples/specdec_bench/README.md <https://github.com/NVIDIA/TensorRT-Model-Optimizer/tree/main/examples/specdec_bench#speculative-decoding-benchmark>`_ for more details.
- Add FP8/NVFP4 KV cache quantization support for Megatron Core models.
- Add the ``trt_plugins_precision`` flag to ONNX autocast to indicate the precision of custom ops, mirroring the flag that already exists in the quantization workflow.

0.39 (2025-11-11)
^^^^^^^^^^^^^^^^^
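The changelog entries above revolve around a per-tensor FP32 block mapping that this PR threads from the quantization entry point down to the autocast converter. Judging from how the diff consumes it (``tensor_map.get("inp", [])`` in ``precisionconverter.py`` and the ``["inp"]`` lookups in ``qdq_utils.py``), the mapping has roughly the shape sketched below; the op name and indices are hypothetical.

```python
# Hypothetical example of the per-tensor FP32 block mapping
# (custom_ops_to_cast_fp32 / tensor_block_dict). Keys are custom op types;
# "inp" lists the input indices that must stay in FP32. An "out" key is suggested
# by the "I/O indices" docstrings, but only "inp" is exercised in this diff.
custom_ops_to_cast_fp32 = {
    "MyTRTPlugin": {     # hypothetical custom op type
        "inp": [0, 2],   # inputs 0 and 2 of every MyTRTPlugin node remain FP32
        "out": [0],      # output 0 remains FP32 (assumed, see note above)
    }
}
```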
4 changes: 3 additions & 1 deletion modelopt/onnx/autocast/convert.py
@@ -194,6 +194,7 @@ def convert_to_f16(
low_precision_type: str = "fp16",
keep_io_types: bool = True,
op_block_list: list[str] = [],
tensor_block_dict: dict[str, dict[str, list[int]]] = {},
trt_plugins: list[str] | None = [],
) -> onnx.ModelProto:
"""Convert model to mixed precision, using PrecisionConverter.
@@ -204,8 +205,8 @@ def convert_to_f16(
model: ONNX model to convert.
low_precision_type: Target precision to reduce to ('fp16' or 'bf16').
keep_io_types: Whether to preserve input/output types.
disable_shape_infer: Whether to disable shape inference.
op_block_list: List of operation types that should remain in FP32.
tensor_block_dict: Dictionary of tensors (operation type and I/O indices) that should remain in FP32.
trt_plugins: List of TensorRT plugin library paths in .so format (compiled shared library).
"""
assert low_precision_type in ["fp16", "bf16"], "low_precision_type must be either fp16 or bf16"
@@ -235,6 +236,7 @@
keep_io_types=keep_io_types,
low_precision_type=low_precision_type,
custom_ops=sanitizer.custom_ops,
tensor_block_dict=tensor_block_dict,
)
high_precision_nodes = [node.name for node in model.graph.node if node.op_type in op_block_list]
low_precision_nodes = [
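A minimal sketch of how the extended ``convert_to_f16`` signature might be called, assuming a locally available ONNX model and the hypothetical ``MyTRTPlugin`` op and plugin library from above:

```python
import onnx

from modelopt.onnx.autocast.convert import convert_to_f16

# Hypothetical model and plugin paths.
model = onnx.load("model_with_custom_op.onnx")

converted = convert_to_f16(
    model,
    low_precision_type="fp16",
    keep_io_types=True,
    op_block_list=["Resize"],                         # whole op type kept in FP32 (illustrative)
    tensor_block_dict={"MyTRTPlugin": {"inp": [0]}},  # only input 0 of MyTRTPlugin kept in FP32
    trt_plugins=["./libmy_trt_plugin.so"],
)
onnx.save(converted, "model_fp16.onnx")
```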
54 changes: 38 additions & 16 deletions modelopt/onnx/autocast/precisionconverter.py
@@ -99,6 +99,7 @@ def __init__(
min_opset: int = 13,
max_ir_version: int | None = None,
trt_plugins: list[str] | None = [],
tensor_block_dict: dict[str, dict[str, list[int]]] = {},
) -> None:
"""Initialize PrecisionConverter.

@@ -112,6 +113,10 @@
init_conversion_max_bytes: Maximum size in bytes for initializer conversion. Larger initializers will be
cast at runtime.
custom_ops: List of custom ops.
min_opset: Minimum opset for conversion.
max_ir_version: Max IR version for conversion.
trt_plugins: List of custom TensorRT plugin library paths in .so format (compiled shared library).
tensor_block_dict: Dictionary of tensors (operation type and I/O indices) that should remain in FP32.
"""
self.model = deepcopy(model)
self.value_info_map = value_info_map
@@ -148,6 +153,9 @@ def __init__(
)
)

# Custom mapping of op types to indices of inputs that should not be converted to low precision
self.skip_inputs_map = self._create_skip_inputs_mapping(tensor_block_dict)

def convert(
self,
high_precision_nodes: list[str],
@@ -211,7 +219,8 @@ def convert(
# For the low precision nodes that take an FP32 input, we don't exclude them from
# casting up, so that the input can be converted to FP32 as expected.
exclude_consumers = list(
set(low_precision_nodes) - {fp32_input_to_low_precision_node[tensor_name].name}
set(low_precision_nodes)
- {n.name for n in fp32_input_to_low_precision_node[tensor_name]}
)
self._add_cast(
tensor_name,
@@ -467,12 +476,14 @@ def _filter_unsupported_op_types(
return high_precision_nodes, low_precision_nodes

def _get_tensors_to_cast(
self, low_precision_nodes: list[str]
) -> tuple[list[str], list[str], dict[str, onnx.NodeProto]]:
self,
low_precision_nodes: list[str],
high_precision_tensors: dict[str, dict[str, list[int]]] = {},
) -> tuple[list[str], list[str], dict[str, list[onnx.NodeProto]]]:
cast_to_fp16 = [] # Tensors to cast down to FP16
cast_to_fp32 = [] # Tensors to cast up to FP32
# Keep track of the low precision nodes that take an FP32 input.
fp32_input_to_low_precision_node = {}
fp32_input_to_low_precision_node = defaultdict(list)

# Get tensors for FP16 nodes
for node in self.model.graph.node:
@@ -481,7 +492,7 @@
for input in node.input:
if self._should_skip_low_precision_input_conversion(node, input):
cast_to_fp32.append(input)
fp32_input_to_low_precision_node[input] = node
fp32_input_to_low_precision_node[input].append(node)
else:
cast_to_fp16.append(input)

@@ -536,7 +547,7 @@ def _convert_initializers(
low_precision_nodes: List of node names that should use low precision initializers.
high_precision_nodes: List of node names that should use high precision initializers.
"""
# 1. Compute a mapping from initiailizers to high precision nodes & low precision nodes that use them.
# 1. Compute a mapping from initializers to high precision nodes & low precision nodes that use them.
low_precision_nodes_set: set[str] = set(low_precision_nodes)
high_precision_nodes_set: set[str] = set(high_precision_nodes)
initializer_to_nodes: dict[str, InitializerConsumerTracker] = defaultdict(
@@ -888,7 +899,7 @@ def _add_cast(
)

if tensor_to_consumers is None:
utils.get_consumer_nodes(self.model, tensor_name)
consumer_nodes = utils.get_consumer_nodes(self.model, tensor_name)
else:
consumer_nodes = tensor_to_consumers.get(tensor_name, [])
consumer_nodes = [n for n in consumer_nodes if n.name not in exclude_consumers]
@@ -1272,13 +1283,9 @@ def _sanitize_model(self):
graph_sanitizer.sanitize()
self.model = graph_sanitizer.model

def _should_skip_low_precision_input_conversion(
self, node: onnx.NodeProto, input_name: str
) -> bool:
"""Check if the input should be skipped for low precision conversion.

This is used for nodes that have inputs that MUST remain in FP32.
"""
def _create_skip_inputs_mapping(self, tensor_block_dict: dict[str, dict[str, list[int]]] = {}):
"""Create mapping of op types to indices of inputs that should not be converted to low precision."""
skip_inputs_map = {}
match self.low_precision_type.str_short:
case "fp16":
skip_inputs_map = SKIP_LOW_PRECISION_MAPPING_FP16
@@ -1287,12 +1294,27 @@
case _:
raise ValueError(f"Unsupported low precision type: {self.low_precision_type}")

if node.op_type in skip_inputs_map:
# Update mapping with user-defined information
for op, tensor_map in tensor_block_dict.items():
high_precision_tensor = tensor_map.get("inp", [])
if high_precision_tensor:
skip_inputs_map.update({op: set(high_precision_tensor)})

return skip_inputs_map

def _should_skip_low_precision_input_conversion(
self, node: onnx.NodeProto, input_name: str
) -> bool:
"""Check if the input should be skipped for low precision conversion.

This is used for nodes that have inputs that MUST remain in FP32.
"""
if node.op_type in self.skip_inputs_map:
# Figure out the index of the input in the node input
inputs_lst = list(node.input)
if input_name not in inputs_lst:
raise ValueError(f"Input {input_name} not found in node {node.name}.")
input_index = inputs_lst.index(input_name)
# Check if we should skip this input for low precision conversion
return input_index in skip_inputs_map[node.op_type]
return input_index in self.skip_inputs_map[node.op_type]
return False
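To make the new ``PrecisionConverter`` behaviour concrete, here is a small standalone sketch of the skip-inputs mapping and lookup added above. The default mapping, op names, and indices are stand-ins for illustration, not the values modelopt actually ships; note that the sketch copies the default mapping before applying user overrides.

```python
from onnx import NodeProto, helper

# Stand-in for SKIP_LOW_PRECISION_MAPPING_FP16: op type -> input indices that stay FP32.
DEFAULT_SKIP_FP16 = {"Resize": {1, 2}}  # illustrative values only


def create_skip_inputs_mapping(tensor_block_dict: dict) -> dict:
    """Merge user-supplied per-op input indices into the default skip mapping."""
    skip_inputs_map = dict(DEFAULT_SKIP_FP16)  # copy so the shared default is not mutated
    for op, tensor_map in tensor_block_dict.items():
        indices = tensor_map.get("inp", [])
        if indices:
            skip_inputs_map[op] = set(indices)
    return skip_inputs_map


def should_skip_low_precision_input(node: NodeProto, input_name: str, skip_inputs_map: dict) -> bool:
    """Return True if this input of the node must stay in FP32."""
    if node.op_type in skip_inputs_map:
        inputs = list(node.input)
        if input_name not in inputs:
            raise ValueError(f"Input {input_name} not found in node {node.name}.")
        return inputs.index(input_name) in skip_inputs_map[node.op_type]
    return False


# Usage: keep input 2 of a hypothetical MyTRTPlugin node in FP32.
skip_map = create_skip_inputs_mapping({"MyTRTPlugin": {"inp": [2]}})
node = helper.make_node("MyTRTPlugin", inputs=["a", "b", "c"], outputs=["y"], name="plugin_0")
assert should_skip_low_precision_input(node, "c", skip_map)
assert not should_skip_low_precision_input(node, "a", skip_map)
```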
2 changes: 2 additions & 0 deletions modelopt/onnx/quantization/fp8.py
@@ -169,6 +169,7 @@ def quantize(
op_types_to_quantize: list[str] | None = None,
op_types_to_exclude: list[str] | None = None,
op_types_to_exclude_fp16: list[str] | None = None,
custom_ops_to_cast_fp32: dict | None = None,
nodes_to_quantize: list[str] | None = None,
nodes_to_exclude: list[str] | None = None,
use_external_data_format: bool = False,
@@ -324,6 +325,7 @@ def quantize(
onnx_model,
keep_io_types=not direct_io_types,
op_block_list=op_types_to_exclude_fp16 or [],
tensor_block_dict=custom_ops_to_cast_fp32 or {},
low_precision_type=high_precision_dtype,
trt_plugins=trt_extra_plugin_lib_paths,
)
2 changes: 2 additions & 0 deletions modelopt/onnx/quantization/int8.py
@@ -120,6 +120,7 @@ def quantize(
op_types_to_quantize: list[str] | None = None,
op_types_to_exclude: list[str] | None = None,
op_types_to_exclude_fp16: list[str] | None = None,
custom_ops_to_cast_fp32: dict | None = None,
nodes_to_quantize: list[str] | None = None,
nodes_to_exclude: list[str] | None = None,
use_external_data_format: bool = False,
@@ -285,6 +286,7 @@ def quantize(
onnx_model,
keep_io_types=not direct_io_types,
op_block_list=op_types_to_exclude_fp16 or [],
tensor_block_dict=custom_ops_to_cast_fp32 or {},
low_precision_type=high_precision_dtype,
trt_plugins=trt_extra_plugin_lib_paths,
)
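Both the FP8 and INT8 paths now expose two separate FP16-exclusion knobs that are forwarded to the same ``convert_to_f16`` call: the pre-existing op-type list and the new per-tensor dictionary. A short sketch of the distinction, with illustrative values:

```python
# Coarse exclusion: every node of these op types stays entirely in FP32.
op_types_to_exclude_fp16 = ["Resize"]

# Fine-grained exclusion: only the listed input indices of these (hypothetical) custom ops
# stay in FP32; everything else around them may still be cast down.
custom_ops_to_cast_fp32 = {"MyTRTPlugin": {"inp": [0, 2]}}

# Normalization mirroring fp8.py / int8.py above before forwarding to convert_to_f16().
op_block_list = op_types_to_exclude_fp16 or []
tensor_block_dict = custom_ops_to_cast_fp32 or {}
```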
43 changes: 29 additions & 14 deletions modelopt/onnx/quantization/qdq_utils.py
@@ -872,22 +872,32 @@ def remove_input_dq_and_output_q(
)

# Only remove DQs from the inputs of custom ops
if consumers[0].op_type not in quantizable_custom_ops:
has_cast = consumers[0].op_type == "Cast"
consumers_2 = tensor_consumers[consumers[0].output[0]] if has_cast else consumers
if consumers_2[0].op_type not in quantizable_custom_ops:
continue

# Rewire graph to connect Q with the node after DQ (skip DQ)
for consumer in consumers:
for cons_idx, cons_inp in enumerate(consumer.input):
if cons_inp == node.output[0]:
# If the input tensor is meant to be quantized, delete DQ. Otherwise, delete both Q/DQ.
if cons_idx in quantizable_custom_ops[consumer.op_type]["inp"]:
consumer.input[cons_idx] = q_node.output[0]
else:
q_node_prev = tensor_producers.get(q_node.input[0], None)
consumer.input[cons_idx] = (
q_node_prev.output[0] if q_node_prev else q_node.input[0]
)
break
if has_cast:
# Assume that this input tensor is not meant to be quantized as there's a Cast node between DQ
# and the custom op. Keep the Cast node and delete both Q/DQ nodes.
q_node_prev = tensor_producers.get(q_node.input[0], None)
consumers[0].input[0] = (
q_node_prev.output[0] if q_node_prev else q_node.input[0]
)
else:
# Rewire graph to connect Q with the node after DQ (skip DQ)
for consumer in consumers:
for cons_idx, cons_inp in enumerate(consumer.input):
if cons_inp == node.output[0]:
# If the input tensor is meant to be quantized, delete DQ. Otherwise, delete both Q/DQ.
if cons_idx in quantizable_custom_ops[consumer.op_type]["inp"]:
consumer.input[cons_idx] = q_node.output[0]
else:
q_node_prev = tensor_producers.get(q_node.input[0], None)
consumer.input[cons_idx] = (
q_node_prev.output[0] if q_node_prev else q_node.input[0]
)
break

# Track DequantizeLinear node indices for cleanup
dq_indices.append(node_idx)
@@ -944,6 +954,11 @@ def remove_input_dq_and_output_q(
f" {len(dq_indices)} DQ node{'' if len(dq_indices) == 1 else 's'}"
)

# Cleanup graph to remove any dangling Q/DQ nodes
graph = gs.import_onnx(onnx_model)
graph.cleanup()
onnx_model = gs.export_onnx(graph)

# TODO: remove manual ir_version change once ORT supports ir_version 11
onnx_model.ir_version = 10

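The ``qdq_utils.py`` change teaches ``remove_input_dq_and_output_q`` to look through a Cast node that sits between a DequantizeLinear node and the custom op: the custom op is still treated as the effective consumer, the Q/DQ pair is dropped, and the Cast is kept. A standalone sketch of that consumer check, with illustrative node names (the real function operates on the full graph and its tensor maps):

```python
from onnx import NodeProto, helper


def effective_consumer_is_custom_op(
    dq_consumers: list[NodeProto],
    tensor_consumers: dict[str, list[NodeProto]],
    quantizable_custom_ops: dict,
) -> tuple[bool, bool]:
    """Return (is_custom_op, has_cast) for the node that effectively consumes a DQ output."""
    first = dq_consumers[0]
    has_cast = first.op_type == "Cast"
    # Look through the Cast to the node that consumes the cast tensor.
    effective = tensor_consumers[first.output[0]][0] if has_cast else first
    return effective.op_type in quantizable_custom_ops, has_cast


# Usage: DQ -> Cast -> MyTRTPlugin (hypothetical op); the plugin is still detected.
cast = helper.make_node("Cast", inputs=["dq_out"], outputs=["cast_out"], name="cast_0", to=1)
plugin = helper.make_node("MyTRTPlugin", inputs=["cast_out"], outputs=["y"], name="plugin_0")
is_custom, has_cast = effective_consumer_is_custom_op(
    [cast], {"cast_out": [plugin]}, {"MyTRTPlugin": {"inp": [0]}}
)
assert is_custom and has_cast
```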
11 changes: 1 addition & 10 deletions modelopt/onnx/quantization/quantize.py
@@ -430,16 +430,6 @@ def quantize(
)
trt_plugins = update_trt_ep_support(calibration_eps, has_dds_op, has_custom_op, trt_plugins) # type: ignore[arg-type]

# Update list with op types to exclude from FP16/BF16 conversion
op_types_to_exclude_fp16 = list(
dict.fromkeys((op_types_to_exclude_fp16 or []) + list(custom_ops_to_cast_fp32.keys()))
)
if high_precision_dtype == "fp32" and op_types_to_exclude_fp16:
logger.warning(
"Nodes were detected for exclusion from FP16/BF16 conversion, but 'high_precision_dtype' is set to FP32. "
"Since the model won't be converted to a lower precision, this flag is void."
)

# Use random scales if calibration data is not supplied
if calibration_data is None:
calibration_data_reader = RandomDataProvider(onnx_path, calibration_shapes)
@@ -485,6 +475,7 @@
op_types_to_quantize=op_types_to_quantize,
op_types_to_exclude=op_types_to_exclude,
op_types_to_exclude_fp16=op_types_to_exclude_fp16,
custom_ops_to_cast_fp32=custom_ops_to_cast_fp32,
nodes_to_quantize=nodes_to_quantize,
nodes_to_exclude=nodes_to_exclude,
use_external_data_format=use_external_data_format,
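Finally, ``quantize.py`` stops folding the custom op types wholesale into the FP16 exclusion list and instead forwards ``custom_ops_to_cast_fp32`` untouched. A quick reproduction of the removed merge idiom, to show what changes (values are illustrative):

```python
# Old behaviour (removed above): custom op types were merged into op_types_to_exclude_fp16,
# so every node of those types stayed entirely in FP32. dict.fromkeys() gave an
# order-preserving de-duplication.
op_types_to_exclude_fp16 = ["Resize", "MyTRTPlugin"]
custom_ops_to_cast_fp32 = {"MyTRTPlugin": {"inp": [0]}, "OtherPlugin": {"inp": [1]}}

merged = list(dict.fromkeys(op_types_to_exclude_fp16 + list(custom_ops_to_cast_fp32)))
print(merged)  # ['Resize', 'MyTRTPlugin', 'OtherPlugin']

# New behaviour: custom_ops_to_cast_fp32 is passed through to fp8/int8 quantize() and on to
# convert_to_f16(tensor_block_dict=...), so only the listed input indices of each custom op
# stay in FP32 instead of the whole op type.
```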