2323import  modelopt .onnx .autocast .utils  as  utils 
2424import  modelopt .onnx .utils  as  onnx_utils 
2525from  modelopt .onnx .autocast .logging_config  import  logger 
26+ from  modelopt .onnx .quantization .graph_utils  import  cast_custom_ops 
27+ from  modelopt .onnx .trt_utils  import  interpret_trt_plugins_precision_flag 
2628
2729
2830class  GraphSanitizer :
@@ -34,6 +36,7 @@ def __init__(
3436        min_opset : int  =  13 ,
3537        max_ir_version : int  |  None  =  None ,
3638        trt_plugins : list [str ] |  None  =  [],
39+         trt_plugins_precision : list [str ] |  None  =  [],
3740    ) ->  None :
3841        """Initialize GraphSanitizer. 
3942
@@ -48,7 +51,9 @@ def __init__(
4851        self .max_ir_version  =  max_ir_version 
4952        self .standard_ops  =  {schema .name  for  schema  in  onnx .defs .get_all_schemas ()}
5053        self .custom_ops  =  None 
54+         self .custom_ops_low_precision_nodes  =  []
5155        self .trt_plugins  =  trt_plugins 
56+         self .trt_plugins_precision  =  trt_plugins_precision  or  []
5257
5358    def  sanitize (self ) ->  None :
5459        """Sanitize the model graph. 
@@ -67,6 +72,7 @@ def sanitize(self) -> None:
6772        self .cleanup_model ()
6873        self .set_ir_version (self .max_ir_version )
6974        self .convert_fp64_to_fp32 ()
75+         self .ensure_custom_ops_precision ()
7076
7177    def  convert_fp64_to_fp32 (self ) ->  None :
7278        """Convert FP64 initializers, I/O types, and specific nodes to FP32.""" 
@@ -88,6 +94,19 @@ def convert_fp64_to_fp32(self) -> None:
8894            logger .info ("Converted FP64 initializers, I/O types, and nodes to FP32" )
8995            self .model  =  onnx_utils .infer_shapes (self .model , strict_mode = True )
9096
97+     def  ensure_custom_ops_precision (self ) ->  None :
98+         """Ensure that custom ops run in the requested precision.""" 
99+         custom_ops_to_cast , _  =  interpret_trt_plugins_precision_flag (
100+             self .model ,
101+             self .trt_plugins_precision ,
102+         )
103+         if  custom_ops_to_cast .get ("fp16" , {}):
104+             self .model  =  cast_custom_ops (self .model , custom_ops_to_cast ["fp16" ])
105+             self .custom_ops_low_precision_nodes  =  [
106+                 n .name  for  n  in  self .model .graph .node  if  n .op_type  in  custom_ops_to_cast ["fp16" ]
107+             ]
108+             logger .info ("Ensured custom ops precision" )
109+ 
91110    def  find_custom_nodes (self ) ->  None :
92111        """Find custom nodes in the model. 
93112
0 commit comments