2323import modelopt .onnx .autocast .utils as utils
2424import modelopt .onnx .utils as onnx_utils
2525from modelopt .onnx .autocast .logging_config import logger
26+ from modelopt .onnx .quantization .graph_utils import cast_custom_ops
27+ from modelopt .onnx .trt_utils import interpret_trt_plugins_precision_flag
2628
2729
2830class GraphSanitizer :
@@ -34,6 +36,7 @@ def __init__(
3436 min_opset : int = 13 ,
3537 max_ir_version : int | None = None ,
3638 trt_plugins : list [str ] | None = [],
39+ trt_plugins_precision : list [str ] | None = [],
3740 ) -> None :
3841 """Initialize GraphSanitizer.
3942
@@ -48,7 +51,9 @@ def __init__(
4851 self .max_ir_version = max_ir_version
4952 self .standard_ops = {schema .name for schema in onnx .defs .get_all_schemas ()}
5053 self .custom_ops = None
54+ self .custom_ops_low_precision_nodes = []
5155 self .trt_plugins = trt_plugins
56+ self .trt_plugins_precision = trt_plugins_precision or []
5257
5358 def sanitize (self ) -> None :
5459 """Sanitize the model graph.
@@ -67,6 +72,7 @@ def sanitize(self) -> None:
6772 self .cleanup_model ()
6873 self .set_ir_version (self .max_ir_version )
6974 self .convert_fp64_to_fp32 ()
75+ self .ensure_custom_ops_precision ()
7076
7177 def convert_fp64_to_fp32 (self ) -> None :
7278 """Convert FP64 initializers, I/O types, and specific nodes to FP32."""
@@ -88,6 +94,19 @@ def convert_fp64_to_fp32(self) -> None:
8894 logger .info ("Converted FP64 initializers, I/O types, and nodes to FP32" )
8995 self .model = onnx_utils .infer_shapes (self .model , strict_mode = True )
9096
97+ def ensure_custom_ops_precision (self ) -> None :
98+ """Ensure that custom ops run in the requested precision."""
99+ custom_ops_to_cast , _ = interpret_trt_plugins_precision_flag (
100+ self .model ,
101+ self .trt_plugins_precision ,
102+ )
103+ if custom_ops_to_cast .get ("fp16" , {}):
104+ self .model = cast_custom_ops (self .model , custom_ops_to_cast ["fp16" ])
105+ self .custom_ops_low_precision_nodes = [
106+ n .name for n in self .model .graph .node if n .op_type in custom_ops_to_cast ["fp16" ]
107+ ]
108+ logger .info ("Ensured custom ops precision" )
109+
91110 def find_custom_nodes (self ) -> None :
92111 """Find custom nodes in the model.
93112
0 commit comments