
Commit f485773

nikita-savelyevv authored and mvafin committed
[OV] Add quantization for SegmentAnything model (#1384)
* Constrain datasets package version
* Add support for datasets v4.0
* Update setup.py
* Trigger tests
* WIP
* WIP
* Add calibration dataset collection
* Simplify dataset collection logic
* Add tests
* Collect only for vision encoder
* Add documentation
* Removed helper validation script
* Fix test
* Remove irrelevant commands
* Fix docs
* Fix docs
* Fix docs
* Tweak model names
* Update version
* Trigger tests
1 parent cbc536d commit f485773

File tree

8 files changed: +289 −19 lines changed

docs/source/openvino/optimization.mdx

Lines changed: 53 additions & 0 deletions
@@ -769,6 +769,59 @@ Click on a ✅ to copy the command/code for the corresponding optimization case.
 </button>
 </td>
 </tr>
+<tr>
+<td style="text-align: center; vertical-align: middle;">feature-extraction<br>(OVSamModel)</td>
+<td style="text-align: center; vertical-align: middle;">–</td>
+<td style="text-align: center; vertical-align: middle;">
+<button onclick="
+navigator.clipboard.writeText('OVSamModel.from_pretrained(\'facebook/sam-vit-base\', quantization_config=OVPipelineQuantizationConfig(quantization_configs=dict(vision_encoder=OVWeightQuantizationConfig(bits=8)))).save_pretrained(\'save_dir\')');
+let m=document.getElementById('copyMsg');
+m.style.display='block';
+clearTimeout(window._copyTimeout);
+window._copyTimeout=setTimeout(()=>m.style.display='none', 2000);
+">
+✅
+</button>
+</td>
+<td style="text-align: center; vertical-align: middle;">–</td>
+<td style="text-align: center; vertical-align: middle;">
+<button onclick="
+navigator.clipboard.writeText('OVSamModel.from_pretrained(\'facebook/sam-vit-base\', quantization_config=OVPipelineQuantizationConfig(quantization_configs=dict(vision_encoder=OVWeightQuantizationConfig(bits=4, dataset=\'coco\')))).save_pretrained(\'save_dir\')');
+let m=document.getElementById('copyMsg');
+m.style.display='block';
+clearTimeout(window._copyTimeout);
+window._copyTimeout=setTimeout(()=>m.style.display='none', 2000);
+">
+✅
+</button>
+</td>
+<td style="text-align: center; vertical-align: middle;">–</td>
+<td style="text-align: center; vertical-align: middle;">-</td>
+<td style="text-align: center; vertical-align: middle;">
+<button onclick="
+navigator.clipboard.writeText('optimum-cli export openvino -m facebook/sam-vit-base --quant-mode int8 --dataset coco ./save_dir');
+let m=document.getElementById('copyMsg');
+m.style.display='block';
+clearTimeout(window._copyTimeout);
+window._copyTimeout=setTimeout(()=>m.style.display='none', 2000);
+">
+✅
+</button>
+</td>
+<td style="text-align: center; vertical-align: middle;">
+<button onclick="
+navigator.clipboard.writeText('OVSamModel.from_pretrained(\'facebook/sam-vit-base\', quantization_config=OVPipelineQuantizationConfig(quantization_configs=dict(vision_encoder=OVQuantizationConfig(bits=8, dataset=\'coco\')))).save_pretrained(\'save_dir\')');
+let m=document.getElementById('copyMsg');
+m.style.display='block';
+clearTimeout(window._copyTimeout);
+window._copyTimeout=setTimeout(()=>m.style.display='none', 2000);
+">
+✅
+</button>
+</td>
+<td style="text-align: center; vertical-align: middle;">–</td>
+<td style="text-align: center; vertical-align: middle;">–</td>
+</tr>
 <tr>
 <td style="text-align: center; vertical-align: middle;">text-to-audio<br>(OVModelForTextToSpeechSeq2Seq)</td>
 <td style="text-align: center; vertical-align: middle;">✅</td>
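For readability, here are the Python one-liners embedded in the copy buttons above, unescaped into a plain script. A minimal sketch, assuming `OVSamModel` and the quantization config classes are importable from the `optimum.intel` top level as elsewhere in these docs:

```python
from optimum.intel import (
    OVPipelineQuantizationConfig,
    OVSamModel,
    OVWeightQuantizationConfig,
)

# 8-bit weight-only quantization of the SAM vision encoder (first button above).
# Only the vision_encoder submodel gets a config; the prompt encoder / mask
# decoder is left in full precision.
model = OVSamModel.from_pretrained(
    "facebook/sam-vit-base",
    quantization_config=OVPipelineQuantizationConfig(
        quantization_configs=dict(vision_encoder=OVWeightQuantizationConfig(bits=8))
    ),
)
model.save_pretrained("save_dir")

# The 4-bit variant (second button) additionally needs a calibration dataset:
# OVWeightQuantizationConfig(bits=4, dataset="coco")
```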

optimum/intel/openvino/configuration.py

Lines changed: 37 additions & 0 deletions
@@ -31,6 +31,7 @@
 from .utils import (
     PREDEFINED_CAUSAL_LANGUAGE_DATASETS,
     PREDEFINED_LANGUAGE_DATASETS,
+    PREDEFINED_SAM_DATASETS,
     PREDEFINED_SD_DATASETS,
     PREDEFINED_SPEECH_TO_TEXT_DATASETS,
     PREDEFINED_VISUAL_LM_DATASETS,
@@ -361,6 +362,36 @@ class OVQuantizationMethod(str, Enum):
 
 # Default configs for int8 full quantization
 _DEFAULT_INT8_FQ_CONFIGS = {
+    "facebook/sam-vit-base": {
+        "quantization_configs": {
+            "vision_encoder": {
+                "dtype": "int8",
+                "dataset": "coco",
+                "num_samples": 128,
+                "weight_only": False,
+            },
+        }
+    },
+    "facebook/sam-vit-large": {
+        "quantization_configs": {
+            "vision_encoder": {
+                "dtype": "int8",
+                "dataset": "coco",
+                "num_samples": 128,
+                "weight_only": False,
+            },
+        }
+    },
+    "facebook/sam-vit-huge": {
+        "quantization_configs": {
+            "vision_encoder": {
+                "dtype": "int8",
+                "dataset": "coco",
+                "num_samples": 128,
+                "weight_only": False,
+            },
+        }
+    },
     "google-t5/t5-small": {
         "dtype": "int8",
         "dataset": "wikitext2",
@@ -723,16 +754,19 @@ def post_init(self):
             visual_lm_datasets = set(PREDEFINED_VISUAL_LM_DATASETS.keys())
             stable_diffusion_datasets = set(PREDEFINED_SD_DATASETS.keys())
             language_datasets = set(PREDEFINED_LANGUAGE_DATASETS.keys())
+            sam_datasets = set(PREDEFINED_SAM_DATASETS.keys())
             if (
                 self.dataset
                 not in PREDEFINED_CAUSAL_LANGUAGE_DATASETS
                 | language_datasets
                 | visual_lm_datasets
                 | stable_diffusion_datasets
+                | sam_datasets
             ):
                 raise ValueError(
                     "You have entered a string value for dataset. You can only choose between "
                     f"{language_datasets} for text feature extraction models, "
+                    f"{sam_datasets} for SegmentAnything models, "
                     f"{PREDEFINED_CAUSAL_LANGUAGE_DATASETS} for LLMs, {visual_lm_datasets} for visual LLMs or "
                     f"{stable_diffusion_datasets} for diffusion models, but we found {self.dataset}."
                 )
@@ -992,19 +1026,22 @@ def post_init(self):
             visual_lm_datasets = set(PREDEFINED_VISUAL_LM_DATASETS.keys())
             stable_diffusion_datasets = set(PREDEFINED_SD_DATASETS.keys())
             language_datasets = set(PREDEFINED_LANGUAGE_DATASETS.keys())
+            sam_datasets = set(PREDEFINED_SAM_DATASETS.keys())
             if (
                 self.dataset
                 not in PREDEFINED_CAUSAL_LANGUAGE_DATASETS
                 | language_datasets
                 | speech_to_text_datasets
                 | stable_diffusion_datasets
                 | visual_lm_datasets
+                | sam_datasets
             ):
                 raise ValueError(
                     "You can only choose between the following datasets:"
                     f"{language_datasets} for text feature extraction models, "
                     f"{PREDEFINED_CAUSAL_LANGUAGE_DATASETS} for LLMs, "
                     f"{speech_to_text_datasets} for speech-to-text models, "
+                    f"{sam_datasets} for SegmentAnything models, "
                     f"{visual_lm_datasets} for visual LLMs or "
                     f"{stable_diffusion_datasets} for diffusion models, but we found {self.dataset}."
                 )
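The `_DEFAULT_INT8_FQ_CONFIGS` entries above mean that requesting plain int8 full quantization for a SAM checkpoint quantizes only the vision encoder, calibrated on 128 samples from the predefined `coco` dataset. A sketch of the explicit counterpart, assuming the dict keys map onto `OVQuantizationConfig` arguments and that `weight_only=False` is what selects full quantization over weight-only compression:

```python
from optimum.intel import OVPipelineQuantizationConfig, OVQuantizationConfig

# Hypothetical explicit equivalent of the "facebook/sam-vit-base" default above.
# weight_only=False in the dict form presumably selects full (weight + activation)
# quantization, i.e. OVQuantizationConfig rather than OVWeightQuantizationConfig.
config = OVPipelineQuantizationConfig(
    quantization_configs=dict(
        vision_encoder=OVQuantizationConfig(
            dtype="int8",     # full int8 quantization
            dataset="coco",   # one of PREDEFINED_SAM_DATASETS
            num_samples=128,  # calibration subset size
        )
    )
)
```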

optimum/intel/openvino/modeling_sam.py

Lines changed: 57 additions & 16 deletions
@@ -14,6 +14,8 @@
 from transformers.models.sam.modeling_sam import SamImageSegmentationOutput, SamPositionalEmbedding
 
 from ...exporters.openvino.utils import save_config
+from .. import OVConfig
+from .configuration import OVQuantizationConfigBase
 from .modeling_base import OVBaseModel, OVModelPart
 from .utils import (
     ONNX_PROMPT_ENCODER_MASK_DECODER_MODEL_NAME,
@@ -83,19 +85,23 @@ def __init__(
         dynamic_shapes: bool = True,
         ov_config: Optional[Dict[str, str]] = None,
         model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
+        quantization_config: Union[OVQuantizationConfigBase, Dict] = None,
         **kwargs,
     ):
         self.config = config
         self._model_save_dir = model_save_dir
         self._device = device.upper()
         self.ov_config = {} if ov_config is None else {**ov_config}
         self.preprocessors = kwargs.get("preprocessors", [])
-        self.vision_encoder_model = vision_encoder_model
-        self.prompt_encoder_mask_decoder_model = prompt_encoder_mask_decoder_model
         self._compile_only = kwargs.get("compile_only", False)
+
+        self._openvino_config = None
+        if quantization_config:
+            self._openvino_config = OVConfig(quantization_config=quantization_config)
+
         enable_compilation = kwargs.get("compile", True)
-        self.vision_encoder = OVSamVisionEncoder(self.vision_encoder_model, self)
-        self.prompt_encoder_mask_decoder = OVSamPromptEncoder(self.prompt_encoder_mask_decoder_model, self)
+        self.vision_encoder = OVSamVisionEncoder(vision_encoder_model, self)
+        self.prompt_encoder_mask_decoder = OVSamPromptEncoder(prompt_encoder_mask_decoder_model, self)
 
         if dynamic_shapes and not self.is_dynamic and not self._compile_only:
             self.reshape()
@@ -117,9 +123,8 @@ def clear_requests(self):
             raise ValueError(
                 "`clear_requests()` is not supported with `compile_only` mode, please initialize model without this option"
             )
-
-        for _, component in self.components.items():
-            component.clear_requests()
+        self.vision_encoder.clear_requests()
+        self.prompt_encoder_mask_decoder.clear_requests()
 
     def compile(self):
         self.vision_encoder._compile()
@@ -143,15 +148,16 @@ def _save_pretrained(self, save_directory: Union[str, Path]):
         """
         src_models = self.ov_submodels
         dst_file_names = {
-            "vision_encoder_model": OV_VISION_ENCODER_MODEL_NAME,
-            "prompt_encoder_mask_decoder_model": OV_PROMPT_ENCODER_MASK_DECODER_MODEL_NAME,
+            "vision_encoder": OV_VISION_ENCODER_MODEL_NAME,
+            "prompt_encoder_mask_decoder": OV_PROMPT_ENCODER_MASK_DECODER_MODEL_NAME,
         }
 
         for name in self._ov_submodel_names:
             model = src_models[name]
             dst_file_name = dst_file_names[name]
             dst_path = os.path.join(save_directory, dst_file_name)
             ov.save_model(model, dst_path, compress_to_fp16=False)
+        self._save_openvino_config(save_directory)
 
     @classmethod
     def _from_pretrained(
@@ -166,6 +172,8 @@ def _from_pretrained(
         vision_encoder_file_name: Optional[str] = None,
         prompt_encoder_mask_decoder_file_name: Optional[str] = None,
         local_files_only: bool = False,
+        load_in_8bit: bool = False,
+        quantization_config: Union[OVQuantizationConfigBase, Dict] = None,
         **kwargs,
     ):
         """
@@ -198,6 +206,10 @@ def _from_pretrained(
                 openvino_prompt_encoder_mask_decoder.xml, allowing to load the decoder model with a different name.
             local_files_only(`bool`, *optional*, defaults to `False`):
                 Whether or not to only look at local files (i.e., do not try to download the model).
+            load_in_8bit(`bool`, *optional*, defaults to `False`):
+                Whether or not to apply 8-bit weight quantization.
+            quantization_config(`Union[OVQuantizationConfigBase, Dict]`, *optional*, defaults to `None`):
+                Quantization configuration to apply to the model.
         """
         if use_auth_token is not None:
             warnings.warn(
@@ -272,21 +284,50 @@ def _from_pretrained(
             model_save_dir,
         )
 
+        quantization_config = cls._prepare_quantization_config(quantization_config, load_in_8bit)
         model = cls(
             vision_encoder_model=vision_encoder_model,
             prompt_encoder_mask_decoder_model=prompt_encoder_model,
             config=config,
             model_save_dir=model_save_dir,
+            quantization_config=quantization_config,
             **kwargs,
         )
 
+        if quantization_config is not None:
+            from optimum.intel import OVQuantizer
+
+            quantizer = OVQuantizer(model)
+            quantization_config_copy = quantization_config.clone()
+            quantization_config_copy.tokenizer = quantization_config.tokenizer or model_id
+            quantization_config_copy.processor = quantization_config.processor or model_id
+            quantizer.quantize(ov_config=OVConfig(quantization_config=quantization_config_copy))
+
         return model
 
     @property
     def _ov_submodel_names(self):
-        model_names = ["vision_encoder_model", "prompt_encoder_mask_decoder_model"]
+        model_names = ["vision_encoder", "prompt_encoder_mask_decoder"]
         return model_names
 
+    @property
+    def ov_submodels(self) -> Dict[str, ov.Model]:
+        return {component_name: getattr(self, component_name).model for component_name in self._ov_submodel_names}
+
+    @property
+    def vision_encoder_model(self) -> ov.Model:
+        logger.warning(
+            "Access to the `vision_encoder_model` attribute is deprecated and will be removed in optimum-intel v1.26, please use `vision_encoder.model` instead"
+        )
+        return self.vision_encoder.model
+
+    @property
+    def prompt_encoder_mask_decoder_model(self) -> ov.Model:
+        logger.warning(
+            "Access to the `prompt_encoder_mask_decoder_model` attribute is deprecated and will be removed in optimum-intel v1.26, please use `prompt_encoder_mask_decoder.model` instead"
        )
+        return self.prompt_encoder_mask_decoder.model
+
     def reshape(self, batch_size: int = -1, point_batch_size: int = -1, num_points_per_image: int = -1):
         """
         Propagates the given input shapes on the model's layers, fixing the inputs shapes of the model.
@@ -304,19 +345,19 @@ def reshape(self, batch_size: int = -1, point_batch_size: int = -1, num_points_per_image: int = -1):
                 "`reshape()` is not supported with `compile_only` mode, please initialize model without this option"
             )
         vision_encoder_shapes = {}
-        for inputs in self.vision_encoder_model.inputs:
+        for inputs in self.vision_encoder.model.inputs:
             vision_encoder_shapes[inputs] = inputs.get_partial_shape()
             vision_encoder_shapes[inputs][0] = batch_size
-        self.vision_encoder_model.reshape(vision_encoder_shapes)
+        self.vision_encoder.model.reshape(vision_encoder_shapes)
         self.vision_encoder.request = None
         mask_decoder_shapes = {}
-        for inputs in self.prompt_encoder_mask_decoder_model.inputs:
+        for inputs in self.prompt_encoder_mask_decoder.model.inputs:
             mask_decoder_shapes[inputs] = inputs.get_partial_shape()
             mask_decoder_shapes[inputs][0] = batch_size
             if inputs.get_any_name() in ["input_points", "input_labels"]:
                 mask_decoder_shapes[inputs][1] = point_batch_size
                 mask_decoder_shapes[inputs][2] = num_points_per_image
-        self.prompt_encoder_mask_decoder_model.reshape(mask_decoder_shapes)
+        self.prompt_encoder_mask_decoder.model.reshape(mask_decoder_shapes)
         self.prompt_encoder_mask_decoder.request = None
         return self
 
@@ -398,6 +439,6 @@ def get_image_features(self, pixel_values, *args, **kwargs):
 
     @property
     def is_dynamic(self):
-        return model_has_dynamic_inputs(self.vision_encoder_model) or model_has_dynamic_inputs(
-            self.prompt_encoder_mask_decoder_model
+        return model_has_dynamic_inputs(self.vision_encoder.model) or model_has_dynamic_inputs(
+            self.prompt_encoder_mask_decoder.model
         )
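Taken together, the `_from_pretrained` changes mean a quantization config passed at load time is applied on the fly through `OVQuantizer`, and the OpenVINO submodels are now reached through the `vision_encoder` / `prompt_encoder_mask_decoder` wrappers. A short usage sketch based on the snippets in the docs table above:

```python
from optimum.intel import OVPipelineQuantizationConfig, OVQuantizationConfig, OVSamModel

# Static int8 quantization of the vision encoder, calibrated on the predefined
# "coco" dataset; _from_pretrained builds the model, then runs OVQuantizer on it.
model = OVSamModel.from_pretrained(
    "facebook/sam-vit-base",
    quantization_config=OVPipelineQuantizationConfig(
        quantization_configs=dict(
            vision_encoder=OVQuantizationConfig(bits=8, dataset="coco", num_samples=128)
        )
    ),
)
model.save_pretrained("save_dir")  # also writes the OVConfig via _save_openvino_config

# Submodel access after this commit:
vision_model = model.vision_encoder.model  # preferred
legacy = model.vision_encoder_model        # still works, logs a deprecation warning
```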
