diff --git a/scripts/export_onnx_model.py b/scripts/export_onnx_model.py index 740210f..959bbae 100644 --- a/scripts/export_onnx_model.py +++ b/scripts/export_onnx_model.py @@ -139,7 +139,7 @@ def run_export( embed_dim = sam.prompt_encoder.embed_dim embed_size = sam.prompt_encoder.image_embedding_size - encoder_embed_dim_dict = {"vit_b":768,"vit_l":1024,"vit_h":1280} + encoder_embed_dim_dict = {"vit_tiny":160,"vit_b":768,"vit_l":1024,"vit_h":1280} encoder_embed_dim = encoder_embed_dim_dict[model_type] mask_input_size = [4 * x for x in embed_size]