@@ -106,13 +106,14 @@ def export_image_encoder(
 def export_onnx_model(
     model_type: str,
     output_root: Union[str, os.PathLike],
-    opset: int,
+    opset: int = 17,
     export_name: Optional[str] = None,
     checkpoint_path: Optional[Union[str, os.PathLike]] = None,
     return_single_mask: bool = True,
     gelu_approximate: bool = False,
     use_stability_score: bool = False,
     return_extra_metrics: bool = False,
+    quantize_model: bool = False,
 ) -> None:
     """Export SAM prompt encoder and mask decoder to onnx.
 
@@ -123,14 +124,16 @@ def export_onnx_model(
     Args:
         model_type: The SAM model type.
         output_root: The output root directory where the exported model is saved.
-        opset: The ONNX opset version.
+        opset: The ONNX opset version. The recommended opset version is 17.
         export_name: The name of the exported model.
         checkpoint_path: Optional checkpoint for loading the SAM model.
         return_single_mask: Whether the mask decoder returns a single or multiple masks.
         gelu_approximate: Whether to use a GeLU approximation, in case the ONNX backend
             does not have an efficient GeLU implementation.
         use_stability_score: Whether to use the stability score instead of the predicted score.
         return_extra_metrics: Whether to return a larger set of metrics.
+        quantize_model: Whether to also export a quantized version of the model.
+            This only works for onnxruntime < 1.17.
     """
     if export_name is None:
         export_name = model_type
@@ -155,9 +158,7 @@ def export_onnx_model(
         if isinstance(m, torch.nn.GELU):
             m.approximate = "tanh"
 
-    dynamic_axes = {
-        "point_coords": {1: "num_points"}, "point_labels": {1: "num_points"},
-    }
+    dynamic_axes = {"point_coords": {1: "num_points"}, "point_labels": {1: "num_points"}}
 
     embed_dim = sam.prompt_encoder.embed_dim
     embed_size = sam.prompt_encoder.image_embedding_size
@@ -202,6 +203,23 @@ def export_onnx_model(
     _ = ort_session.run(None, ort_inputs)
     print("Model has successfully been run with ONNXRuntime.")
 
+    # This requires onnxruntime < 1.17.
+    # See https://github.com/facebookresearch/segment-anything/issues/699#issuecomment-1984670808
+    if quantize_model:
+        assert onnxruntime_exists
+        from onnxruntime.quantization import QuantType
+        from onnxruntime.quantization.quantize import quantize_dynamic
+
+        quantized_path = os.path.join(weight_output_folder, "model_quantized.onnx")
+        quantize_dynamic(
+            model_input=weight_path,
+            model_output=quantized_path,
+            # optimize_model=True,
+            per_channel=False,
+            reduce_range=False,
+            weight_type=QuantType.QUInt8,
+        )
+
     config_output_path = os.path.join(output_folder, "config.pbtxt")
     with open(config_output_path, "w") as f:
         f.write(DECODER_CONFIG % name)