diff --git a/modelopt/onnx/quantization/qdq_utils.py b/modelopt/onnx/quantization/qdq_utils.py index e31b2e481..356769514 100644 --- a/modelopt/onnx/quantization/qdq_utils.py +++ b/modelopt/onnx/quantization/qdq_utils.py @@ -790,8 +790,10 @@ def remove_input_dq_and_output_q( if cons_idx in quantizable_custom_ops[consumer.op_type]["inp"]: consumer.input[cons_idx] = q_node.output[0] else: - q_node_prev = tensor_producers[q_node.input[0]] - consumer.input[cons_idx] = q_node_prev.output[0] + q_node_prev = tensor_producers.get(q_node.input[0], None) + consumer.input[cons_idx] = ( + q_node_prev.output[0] if q_node_prev else q_node.input[0] + ) break # Track DequantizeLinear node indices for cleanup @@ -828,8 +830,11 @@ def remove_input_dq_and_output_q( if quantizable_custom_ops[producer.op_type]["out"]: dq_node[0].input[0] = producer.output[0] else: - dq_node_next = tensor_consumers[dq_node[0].output[0]] - dq_node_next[0].input[0] = producer.output[0] + dq_node_next = tensor_consumers.get(dq_node[0].output[0], None) + if dq_node_next: + dq_node_next[0].input[0] = producer.output[0] + else: + dq_node[0].input[0] = producer.output[0] # Track QuantizeLinear node indices for cleanup q_indices.append(node_idx) diff --git a/modelopt/onnx/trt_utils.py b/modelopt/onnx/trt_utils.py index 48a4a1618..85312ecd4 100644 --- a/modelopt/onnx/trt_utils.py +++ b/modelopt/onnx/trt_utils.py @@ -416,9 +416,10 @@ def interpret_trt_plugins_precision_flag( # Will add Q/DQ nodes in the requested I/O indices inp_precision_quant = [i for i, p in enumerate(inp_precision) if p in ["int8", "fp8"]] out_precision_quant = [i for i, p in enumerate(out_precision) if p in ["int8", "fp8"]] - custom_ops_to_quantize[op_type] = { - "inp": inp_precision_quant, - "out": out_precision_quant, - } + if inp_precision_quant or out_precision_quant: + custom_ops_to_quantize[op_type] = { + "inp": inp_precision_quant, + "out": out_precision_quant, + } return custom_ops_to_cast, custom_ops_to_quantize