Skip to content

Commit bba4d82

Browse files
Add support for quantizing models in onnx export
1 parent d5d0a0a commit bba4d82

File tree

1 file changed

+23
-5
lines changed

1 file changed

+23
-5
lines changed

micro_sam/bioimageio/bioengine_export.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -106,13 +106,14 @@ def export_image_encoder(
106106
def export_onnx_model(
107107
model_type: str,
108108
output_root: Union[str, os.PathLike],
109-
opset: int,
109+
opset: int = 17,
110110
export_name: Optional[str] = None,
111111
checkpoint_path: Optional[Union[str, os.PathLike]] = None,
112112
return_single_mask: bool = True,
113113
gelu_approximate: bool = False,
114114
use_stability_score: bool = False,
115115
return_extra_metrics: bool = False,
116+
quantize_model: bool = False,
116117
) -> None:
117118
"""Export SAM prompt encoder and mask decoder to onnx.
118119
@@ -123,14 +124,16 @@ def export_onnx_model(
123124
Args:
124125
model_type: The SAM model type.
125126
output_root: The output root directory where the exported model is saved.
126-
opset: The ONNX opset version.
127+
opset: The ONNX opset version. The recommended opset version is 17.
127128
export_name: The name of the exported model.
128129
checkpoint_path: Optional checkpoint for loading the SAM model.
129130
return_single_mask: Whether the mask decoder returns a single or multiple masks.
130131
gelu_approximate: Whether to use a GeLU approximation, in case the ONNX backend
131132
does not have an efficient GeLU implementation.
132133
use_stability_score: Whether to use the stability score instead of the predicted score.
133134
return_extra_metrics: Whether to return a larger set of metrics.
135+
quantize_model: Whether to also export a quantized version of the model.
136+
This only works for onnxruntime < 1.17.
134137
"""
135138
if export_name is None:
136139
export_name = model_type
@@ -155,9 +158,7 @@ def export_onnx_model(
155158
if isinstance(m, torch.nn.GELU):
156159
m.approximate = "tanh"
157160

158-
dynamic_axes = {
159-
"point_coords": {1: "num_points"}, "point_labels": {1: "num_points"},
160-
}
161+
dynamic_axes = {"point_coords": {1: "num_points"}, "point_labels": {1: "num_points"}}
161162

162163
embed_dim = sam.prompt_encoder.embed_dim
163164
embed_size = sam.prompt_encoder.image_embedding_size
@@ -202,6 +203,23 @@ def export_onnx_model(
202203
_ = ort_session.run(None, ort_inputs)
203204
print("Model has successfully been run with ONNXRuntime.")
204205

206+
# This requires onnxruntime < 1.17.
207+
# See https://github.com/facebookresearch/segment-anything/issues/699#issuecomment-1984670808
208+
if quantize_model:
209+
assert onnxruntime_exists
210+
from onnxruntime.quantization import QuantType
211+
from onnxruntime.quantization.quantize import quantize_dynamic
212+
213+
quantized_path = os.path.join(weight_output_folder, "model_quantized.onnx")
214+
quantize_dynamic(
215+
model_input=weight_path,
216+
model_output=quantized_path,
217+
# optimize_model=True,
218+
per_channel=False,
219+
reduce_range=False,
220+
weight_type=QuantType.QUInt8,
221+
)
222+
205223
config_output_path = os.path.join(output_folder, "config.pbtxt")
206224
with open(config_output_path, "w") as f:
207225
f.write(DECODER_CONFIG % name)

0 commit comments

Comments (0)