Skip to content

Commit 208b2ef

Browse files
Add ignored pattern for Positional Embedding from Segment Anything model (#3700)
### Changes Adds the ignored pattern described in the title. Example model: https://huggingface.co/facebook/sam-vit-base . Reference code: https://github.com/facebookresearch/segment-anything/blob/dca509fe793f601edb92606367a655c15ac00fdf/segment_anything/modeling/prompt_encoder.py#L171 . ### Reason for changes So that the positional embedding weight is automatically ignored during weight compression. <img width="230" height="471" alt="image" src="https://github.com/user-attachments/assets/f0ff351b-a426-4076-b339-d883a04b4451" /> ### Related tickets 175083 ### Tests Added tests/cross_fw/test_templates/template_test_weights_compression.py::test_sam_pe_weight_compression
1 parent c7eac4d commit 208b2ef

File tree

16 files changed

+216
-4
lines changed

16 files changed

+216
-4
lines changed

src/nncf/common/graph/patterns/patterns.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -408,3 +408,4 @@ class IgnoredPatternNames(Enum):
408408
FC_BN_HSWISH_ACTIVATION = PatternDesc("fc_bn_hswish_activation")
409409
EQUAL_LOGICALNOT = PatternDesc("equal_logicalnot")
410410
ROPE = PatternDesc("rope", model_types=[ModelType.TRANSFORMER])
411+
SAM_PE = PatternDesc("sam_pe", model_types=[ModelType.TRANSFORMER])

src/nncf/onnx/quantization/ignored_patterns.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,3 +179,29 @@ def create_rope() -> GraphPattern:
179179
pattern.add_edge(concat_node, cos_node)
180180
pattern.add_edge(concat_node, sin_node)
181181
return pattern
182+
183+
184+
@ONNX_IGNORED_PATTERNS.register(IgnoredPatternNames.SAM_PE)
def create_sam_pe() -> GraphPattern:
    """
    Build the ignored pattern for the Positional Embedding of the Segment Anything Model (SAM).

    The pattern is the subgraph MatMul -> Mul -> {Cos, Sin} -> Concat.

    :return: Graph pattern describing the SAM positional embedding subgraph.
    """
    pattern = GraphPattern()

    def _add(label, metatype):
        # Every node of the pattern is described only by its label and metatype.
        return pattern.add_node(**{GraphPattern.LABEL_ATTR: label, GraphPattern.METATYPE_ATTR: metatype})

    matmul = _add("MATMUL", om.ONNXMatMulMetatype)
    multiply = _add("MULTIPLY", om.ONNXMulLayerMetatype)
    cos = _add("COS", om.ONNXCosMetatype)
    sin = _add("SIN", om.ONNXSinMetatype)
    concat = _add("CONCAT", om.ONNXConcatMetatype)

    # The scaled projection feeds both trigonometric branches, which are then joined.
    for source, target in (
        (matmul, multiply),
        (multiply, cos),
        (multiply, sin),
        (cos, concat),
        (sin, concat),
    ):
        pattern.add_edge(source, target)

    return pattern

src/nncf/openvino/quantization/ignored_patterns.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,3 +186,29 @@ def create_rope() -> GraphPattern:
186186
pattern.add_edge(concat_node, cos_node)
187187
pattern.add_edge(concat_node, sin_node)
188188
return pattern
189+
190+
191+
@OPENVINO_IGNORED_PATTERNS.register(IgnoredPatternNames.SAM_PE)
def create_sam_pe() -> GraphPattern:
    """
    Build the ignored pattern for the Positional Embedding of the Segment Anything Model (SAM).

    The pattern is the subgraph MatMul -> Multiply -> {Cos, Sin} -> Concat.

    :return: Graph pattern describing the SAM positional embedding subgraph.
    """
    pattern = GraphPattern()

    def _add(label, metatype):
        # Every node of the pattern is described only by its label and metatype.
        return pattern.add_node(**{GraphPattern.LABEL_ATTR: label, GraphPattern.METATYPE_ATTR: metatype})

    matmul = _add("MATMUL", om.OVMatMulMetatype)
    multiply = _add("MULTIPLY", om.OVMultiplyMetatype)
    cos = _add("COS", om.OVCosMetatype)
    sin = _add("SIN", om.OVSinMetatype)
    concat = _add("CONCAT", om.OVConcatMetatype)

    # The scaled projection feeds both trigonometric branches, which are then joined.
    for source, target in (
        (matmul, multiply),
        (multiply, cos),
        (multiply, sin),
        (cos, concat),
        (sin, concat),
    ):
        pattern.add_edge(source, target)

    return pattern

src/nncf/quantization/algorithms/weight_compression/onnx_backend.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
from nncf.onnx.graph.transformations.command_creation import ONNXCommandCreator
5050
from nncf.onnx.graph.transformations.commands import ONNXTargetPoint
5151
from nncf.onnx.quantization.ignored_patterns import create_rope
52+
from nncf.onnx.quantization.ignored_patterns import create_sam_pe
5253
from nncf.parameters import CompressionFormat
5354
from nncf.parameters import CompressWeightsMode
5455
from nncf.quantization.advanced_parameters import AdvancedCompressionParameters
@@ -491,7 +492,9 @@ def _replace_matmul_with_matmulnbits(
491492

492493
@staticmethod
def get_ignored_patterns() -> GraphPattern:
    """
    Return the graph patterns that this weight compression backend ignores.

    :return: The RoPE pattern with the SAM positional embedding pattern added as an alternative.
    """
    ignored_patterns = create_rope()
    # Extend the RoPE pattern in place so that either subgraph is matched.
    ignored_patterns.add_pattern_alternative(create_sam_pe())
    return ignored_patterns
495498

496499

497500
class ONNXAWQAlgoAlgoBackend(AWQAlgoBackend, ONNXWeightCompressionAlgoBackend):

src/nncf/quantization/algorithms/weight_compression/openvino_backend.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
from nncf.openvino.optimized_functions import clear_ov_model_cache
4444
from nncf.openvino.optimized_functions.models import OV_MODEL_CACHE
4545
from nncf.openvino.quantization.ignored_patterns import create_rope
46+
from nncf.openvino.quantization.ignored_patterns import create_sam_pe
4647
from nncf.openvino.rt_info import dump_parameters
4748
from nncf.openvino.statistics.collectors import OVMaxVarianceReducer
4849
from nncf.openvino.statistics.collectors import OVMeanAbsMaxReducer
@@ -394,7 +395,9 @@ def filter_func(point: StatisticPoint) -> bool:
394395

395396
@staticmethod
def get_ignored_patterns() -> GraphPattern:
    """
    Return the graph patterns that this weight compression backend ignores.

    :return: The RoPE pattern with the SAM positional embedding pattern added as an alternative.
    """
    ignored_patterns = create_rope()
    # Extend the RoPE pattern in place so that either subgraph is matched.
    ignored_patterns.add_pattern_alternative(create_sam_pe())
    return ignored_patterns
398401

399402

400403
class OVTensorWeightCompressionAlgoBackend(OVWeightCompressionAlgoBackend):

src/nncf/quantization/algorithms/weight_compression/torch_backend.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@
6262
from nncf.torch.model_transformer import PTModelTransformer
6363
from nncf.torch.nncf_network import NNCFNetwork
6464
from nncf.torch.quantization.ignored_patterns import create_rope
65+
from nncf.torch.quantization.ignored_patterns import create_sam_pe
6566
from nncf.torch.quantization.layers import QUANTIZATION_MODULES
6667
from nncf.torch.quantization.layers import INT4AsymmetricWeightsDecompressor
6768
from nncf.torch.quantization.layers import INT4SymmetricWeightsDecompressor
@@ -481,7 +482,9 @@ def transform_model(
481482

482483
@staticmethod
def get_ignored_patterns() -> GraphPattern:
    """
    Return the graph patterns that this weight compression backend ignores.

    :return: The RoPE pattern with the SAM positional embedding pattern added as an alternative.
    """
    ignored_patterns = create_rope()
    # Extend the RoPE pattern in place so that either subgraph is matched.
    ignored_patterns.add_pattern_alternative(create_sam_pe())
    return ignored_patterns
485488

486489

487490
class PTAWQAlgoAlgoBackend(AWQAlgoBackend, PTWeightCompressionAlgoBackend):

src/nncf/quantization/algorithms/weight_compression/torch_fx_backend.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
from nncf.torch.model_graph_manager import get_weight_compression_reduction_axes
5858
from nncf.torch.model_graph_manager import get_weight_tensor_port_ids
5959
from nncf.torch.quantization.ignored_patterns import create_rope
60+
from nncf.torch.quantization.ignored_patterns import create_sam_pe
6061
from nncf.torch.quantization.layers import INT4AsymmetricWeightsDecompressor
6162
from nncf.torch.quantization.layers import INT4SymmetricWeightsDecompressor
6263
from nncf.torch.quantization.layers import INT8AsymmetricWeightsDecompressor
@@ -257,7 +258,9 @@ def transform_model(
257258

258259
@staticmethod
def get_ignored_patterns() -> GraphPattern:
    """
    Return the graph patterns that this weight compression backend ignores.

    :return: The RoPE pattern with the SAM positional embedding pattern added as an alternative.
    """
    ignored_patterns = create_rope()
    # Extend the RoPE pattern in place so that either subgraph is matched.
    ignored_patterns.add_pattern_alternative(create_sam_pe())
    return ignored_patterns
261264

262265

263266
class FXMixedPrecisionAlgoBackend(MixedPrecisionAlgoBackend, FXWeightCompressionAlgoBackend):

src/nncf/torch/quantization/ignored_patterns.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,3 +250,27 @@ def create_rope() -> GraphPattern:
250250
pattern.add_edge(concat_node, cos_node)
251251
pattern.add_edge(concat_node, sin_node)
252252
return pattern
253+
254+
255+
@PT_IGNORED_PATTERNS.register(IgnoredPatternNames.SAM_PE)
def create_sam_pe() -> GraphPattern:
    """
    Build the ignored pattern for the Positional Embedding of the Segment Anything Model (SAM).

    The pattern is the subgraph MatMul -> Mul -> {Cos, Sin} -> Cat.

    :return: Graph pattern describing the SAM positional embedding subgraph.
    """
    pattern = GraphPattern()

    def _add(label, metatype):
        # Every node of the pattern is described only by its label and metatype.
        return pattern.add_node(**{GraphPattern.LABEL_ATTR: label, GraphPattern.METATYPE_ATTR: metatype})

    matmul = _add("MATMUL", om.PTMatMulMetatype)
    multiply = _add("MULTIPLY", om.PTMulMetatype)
    cos = _add("COS", om.PTCosMetatype)
    sin = _add("SIN", om.PTSinMetatype)
    concat = _add("CONCAT", om.PTCatMetatype)

    # The scaled projection feeds both trigonometric branches, which are then joined.
    for source, target in (
        (matmul, multiply),
        (multiply, cos),
        (multiply, sin),
        (cos, concat),
        (sin, concat),
    ):
        pattern.add_edge(source, target)

    return pattern

tests/cross_fw/test_templates/helpers.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -499,3 +499,24 @@ def forward(self, x):
499499
x1 = x.sin()
500500
x2 = x.cos()
501501
return x1, x2
502+
503+
504+
class SAMPEModel(nn.Module):
    """
    Positional Embedding from Segment Anything Model (SAM).

    The input is projected by a (2, 128) weight matrix, scaled by 2 * pi, and the
    sine and cosine of the result are concatenated along the last dimension.
    """

    # Shape of the example input; the last dimension matches the first dimension of the weight.
    INPUT_SIZE = [1, 2, 3, 2]

    def __init__(self):
        super().__init__()
        with set_torch_seed():
            # torch.empty() returns uninitialized memory, so the seed context would have no
            # effect and the weight could hold arbitrary values (including NaN/Inf).
            # Sampling from the RNG gives an initialized weight; this is reproducible
            # provided set_torch_seed() fixes the RNG state (NOTE(review): confirm).
            self.weight = nn.Parameter(torch.randn((2, 128)))

    def forward(self, x):
        """
        Compute the sin/cos positional encoding of the input.

        :param x: Input tensor whose last dimension has size 2.
        :return: Tensor with sin and cos features concatenated along the last dimension.
        """
        x = torch.matmul(x, self.weight)
        x = x * (2 * torch.pi)
        x1 = x.sin()
        x2 = x.cos()
        x = torch.cat([x1, x2], dim=-1)
        return x

tests/cross_fw/test_templates/template_test_weights_compression.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,11 @@ def get_matmul_model() -> TModel:
119119
def get_RoPE_model() -> TModel:
120120
"""Returns a backend model for test_rope_weight_compression."""
121121

122+
@staticmethod
@abstractmethod
def get_SAM_PE_model() -> TModel:
    """
    Provide the backend-specific model used by test_sam_pe_weight_compression.

    :return: A backend model for test_sam_pe_weight_compression.
    """
126+
122127
@pytest.mark.parametrize(
123128
("mode", "ref_act_score", "ref_score"),
124129
(
@@ -400,6 +405,26 @@ def test_rope_weight_compression(self):
400405
int4_num_nodes = self.get_num_int4_nodes(compressed_model)
401406
assert int4_num_nodes == int4_ref_num_compressed
402407

408+
def test_sam_pe_weight_compression(self):
    """
    Check that compressing the SAM positional embedding model yields no INT4 nodes.
    """
    model = self.get_SAM_PE_model()

    # Single all-ones calibration sample of shape [1, 2, 3, 2].
    calibration_sample = self.to_tensor(np.ones([1, 2, 3, 2], dtype=np.float32))
    dataset = Dataset([calibration_sample], self.get_transform_func())

    compression_kwargs = {
        "mode": CompressWeightsMode.INT4_SYM,
        "ratio": 1.0,
        "group_size": -1,
        "dataset": dataset,
        "all_layers": True,
    }
    compressed_model = compress_weights(model, **compression_kwargs)

    # The only weight belongs to the ignored pattern, so nothing may be compressed to INT4.
    expected_int4_nodes = 0
    actual_int4_nodes = self.get_num_int4_nodes(compressed_model)
    assert actual_int4_nodes == expected_int4_nodes
427+
403428
@staticmethod
404429
@abstractmethod
405430
def get_reference_for_test_awq_scale_reference() -> dict[str, Tensor]:

0 commit comments

Comments
 (0)