Commit 869cb14

Enable vision encoder small op fusions (#3377)
1 parent 8d233ea commit 869cb14

File tree

3 files changed: 153 additions, 4 deletions
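The fusions enabled here are small elementwise epilogues attached to a Linear layer. As a rough orientation only (the helper names below are illustrative, not IPEX APIs), each fused module is expected to compute the same thing as the plain PyTorch composition:

```python
# Illustrative sketch, not IPEX code: plain-PyTorch equivalents of the three
# "small op" fusions this commit wires into the Mllama vision encoder layer.
import torch
import torch.nn.functional as F


def linear_silu(x: torch.Tensor, linear: torch.nn.Linear) -> torch.Tensor:
    # _IPEXlinearSilu*: Linear followed by SiLU, executed as one fused op.
    return F.silu(linear(x))


def linear_mul(x: torch.Tensor, other: torch.Tensor, linear: torch.nn.Linear) -> torch.Tensor:
    # _IPEXlinearMul*: Linear followed by an elementwise multiply (e.g. a gate).
    return linear(x) * other


def linear_add(x: torch.Tensor, other: torch.Tensor, linear: torch.nn.Linear) -> torch.Tensor:
    # _IPEXlinearAdd*: Linear followed by an elementwise add (e.g. a residual).
    return linear(x) + other
```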

intel_extension_for_pytorch/transformers/models/cpu/modules/decoder.py

Lines changed: 30 additions & 0 deletions
@@ -6,6 +6,7 @@
     _IPEXlinearReluCPU,
     _IPEXlinearGeluCPU,
     _IPEXlinearMulCPU,
+    _IPEXlinearSiluCPU,
     _IPEXlinearSiluMulCPU,
 )

@@ -117,3 +118,32 @@ def __init__(self, module, config, tpp=False, woq=False):
             )
         else:
             AssertionError(False, "Do not support the optimization of your model yet")
+
+
+class _IPEXEncoderLayerCPU(nn.Module):
+    def __init__(self, module, config, tpp=False, woq=False):
+        super().__init__()
+        for k, v in module.__dict__.items():
+            setattr(self, k, v)
+        for k, v in module.__class__.__dict__.items():
+            if k.startswith("__"):
+                continue
+            setattr(self.__class__, k, getattr(module.__class__, k))
+        if self.model_backbone in [
+            "MllamaForConditionalGeneration",
+        ]:
+            if not self.distributed:
+                if hasattr(module, "mlp_linear_add"):
+                    self.mlp_linear_add = _IPEXlinearAddCPU(
+                        module.mlp_linear_add.linear, tpp=tpp, woq=woq
+                    )
+                if hasattr(module, "mlp_linear_mul"):
+                    self.mlp_linear_mul = _IPEXlinearMulCPU(
+                        module.mlp_linear_mul.linear, tpp=tpp, woq=woq
+                    )
+                if hasattr(module, "linear_silu"):
+                    self.linear_silu = _IPEXlinearSiluCPU(
+                        module.linear_silu.linear, tpp=tpp, woq=woq
+                    )
+        else:
+            AssertionError(False, "Do not support the optimization of your model yet")
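Both new wrapper classes (this CPU one and the reference one in the next file) reuse the same adoption pattern: copy the wrapped layer's instance attributes, then swap selected fusion hooks for fused kernels. A minimal sketch of that pattern; `make_fused_silu` is a hypothetical stand-in for the `_IPEXlinearSiluCPU` constructor used above:

```python
# Minimal sketch of the wrapper pattern used by _IPEXEncoderLayerCPU.
# `make_fused_silu` is hypothetical; the real code passes the fusion hook's
# inner linear to an _IPEXlinear*CPU constructor.
import torch.nn as nn


class AdoptingWrapper(nn.Module):
    def __init__(self, module: nn.Module, make_fused_silu=None):
        super().__init__()
        # Take over the original layer's state (parameters, buffers, submodules,
        # config flags) so the wrapper is a drop-in replacement.
        for k, v in module.__dict__.items():
            setattr(self, k, v)
        # Only replace fusion hooks the reference conversion actually created.
        if make_fused_silu is not None and hasattr(module, "linear_silu"):
            self.linear_silu = make_fused_silu(module.linear_silu.linear)
```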

intel_extension_for_pytorch/transformers/models/reference/modules/decoder.py

Lines changed: 92 additions & 0 deletions
@@ -8,8 +8,10 @@
     _IPEXlinearReluRef,
     _IPEXlinearGeluRef,
     _IPEXlinearMulRef,
+    _IPEXlinearSiluRef,
     _IPEXlinearSiluMulRef,
 )
+from .....llm.functional.fusions import add_layer_norm
 from torch.nn import functional as F
 from .....utils._logger import logger, WarningType

@@ -65,6 +67,54 @@ def LlamaDecoderLayer_forward(
     return outputs


+def MllamaVisionEncoderLayer_forward(
+    self,
+    hidden_state: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    output_attentions: bool = None,
+):
+    # Self Attention
+    residual = hidden_state
+    hidden_state = self.input_layernorm(hidden_state)
+    hidden_state, attn_weights = self.self_attn(
+        hidden_state, attention_mask=attention_mask
+    )
+    if self.is_gated:
+        hidden_state = self.gate_attn.tanh() * hidden_state
+
+    hidden_state = add_layer_norm(
+        residual,
+        hidden_state,
+        self.post_attention_layernorm.weight,
+        self.post_attention_layernorm.bias,
+        self.post_attention_layernorm.eps,
+        True,
+    )
+
+    hidden_state = self.linear_silu(hidden_state)
+
+    if self.is_gated:
+        if self.distributed:
+            hidden_state = self.mlp.fc2(hidden_state)
+            hidden_state = self.gate_ffn.tanh() * hidden_state
+        else:
+            hidden_state = self.mlp_linear_mul(hidden_state, self.gate_ffn.tanh())
+        hidden_state = residual + hidden_state
+    else:
+        if self.distributed:
+            hidden_state = self.mlp.fc2(hidden_state)
+            hidden_state = residual + hidden_state
+        else:
+            hidden_state = self.mlp_linear_add(hidden_state, residual)
+
+    outputs = (hidden_state,)
+
+    if output_attentions:
+        outputs += (attn_weights,)
+
+    return outputs
+
+
 def OPTDecoderLayer_forward(
     self,
     hidden_states: torch.Tensor,
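The fused forward above folds the post-attention residual add and LayerNorm into one `add_layer_norm` call. Read against the argument order used here, its intended behavior is roughly the reference below; this is an assumption for orientation, not the IPEX kernel, and the final `True` (add-back) flag is what keeps `residual` holding the post-attention state for the later residual adds:

```python
# Assumed functional equivalent of add_layer_norm(residual, x, weight, bias, eps, True):
# add the residual into x, optionally write the sum back into `residual`,
# then LayerNorm the sum. Not the IPEX implementation, just a readable reference.
import torch
import torch.nn.functional as F


def add_layer_norm_ref(residual, x, weight, bias, eps, add_back=True):
    summed = residual + x
    if add_back:
        residual.copy_(summed)  # later residual adds then use the post-attention state
    return F.layer_norm(summed, summed.shape[-1:], weight, bias, eps)
```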
@@ -2091,3 +2141,45 @@ def forward(
             )
         else:
             AssertionError(False, "Do not support the optimization of your model yet")
+
+
+class _IPEXEncoderLayerRef(nn.Module):
+    def __init__(self, module, config, distributed=False):
+        super().__init__()
+        for k, v in module.__dict__.items():
+            setattr(self, k, v)
+        for k, v in module.__class__.__dict__.items():
+            if k.startswith("__") or k.startswith("forward"):
+                continue
+            setattr(self.__class__, k, getattr(module.__class__, k))
+        self.distributed = distributed
+        self.model_backbone = config.architectures[0]
+        if self.model_backbone in [
+            "MllamaForConditionalGeneration",
+        ]:
+            if not self.distributed:
+                if self.is_gated:
+                    self.mlp_linear_mul = _IPEXlinearMulRef(module.mlp.fc2)
+                else:
+                    self.mlp_linear_add = _IPEXlinearAddRef(module.mlp.fc2)
+                del self.__dict__["_modules"]["mlp"].fc2
+            self.linear_silu = _IPEXlinearSiluRef(module.mlp.fc1)
+            del self.__dict__["_modules"]["mlp"].fc1
+        else:
+            AssertionError(False, "Do not support the optimization of your model yet")
+
+    def forward(
+        self,
+        hidden_state: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = None,
+    ):
+        if self.model_backbone == "MllamaForConditionalGeneration":
+            return MllamaVisionEncoderLayer_forward(
+                self,
+                hidden_state,
+                attention_mask,
+                output_attentions,
+            )
+        else:
+            AssertionError(False, "Do not support the optimization of your model yet")
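On the module side, `_IPEXEncoderLayerRef` repartitions the vision MLP: `fc1` plus its activation becomes `linear_silu`, while `fc2` is absorbed into a linear+mul fusion when the layer is gated (so the `gate_ffn.tanh()` multiply rides along) or a linear+add fusion otherwise (so the residual add rides along). A plain-PyTorch restatement of the math, illustrative only:

```python
# Illustrative only: the unfused feed-forward math the fusions above reproduce.
#   gated layer:   out = residual + gate_ffn.tanh() * fc2(silu(fc1(x)))
#   ungated layer: out = residual + fc2(silu(fc1(x)))
import torch
import torch.nn.functional as F


def ffn_reference(x, residual, fc1, fc2, gate_ffn=None):
    h = F.silu(fc1(x))                  # fused as linear_silu
    if gate_ffn is not None:            # is_gated path
        h = fc2(h) * gate_ffn.tanh()    # fused as mlp_linear_mul
        return residual + h
    return fc2(h) + residual            # fused as mlp_linear_add
```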

intel_extension_for_pytorch/transformers/optimize.py

Lines changed: 31 additions & 4 deletions
@@ -150,7 +150,10 @@ def model_convert_reference(_model):
     )

     # model wise optimization for Feedforward and Decoder layer modules
-    from .models.reference.modules.decoder import _IPEXDecoderLayerRef
+    from .models.reference.modules.decoder import (
+        _IPEXDecoderLayerRef,
+        _IPEXEncoderLayerRef,
+    )

     # generation length or model forward order
     from .models.reference.models import (
@@ -572,6 +575,16 @@ def model_convert_reference(_model):
             _model.config,
             distributed=distributed,
         )
+    for supported_encoder_class in [
+        transformers.models.mllama.modeling_mllama.MllamaVisionEncoderLayer
+    ]:
+        convert_class(
+            _model,
+            supported_encoder_class,
+            _IPEXEncoderLayerRef,
+            _model.config,
+            distributed=distributed,
+        )
     # special list that has not official transformers design
     if _model.config.architectures[0] == "BloomForCausalLM":
         convert_function(
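`convert_class` itself is untouched by this commit, so its body is not shown here; the hunk above only adds `MllamaVisionEncoderLayer` to the classes it rewrites. As an assumption about what such class-based conversion typically does, a sketch:

```python
# Assumed shape of class-based conversion (convert_class's real body is not
# part of this diff): recursively walk the model and rebuild every instance
# of the target class with the IPEX reference wrapper.
import torch.nn as nn


def convert_class_sketch(model: nn.Module, target_cls, new_cls, config, distributed=False):
    for name, child in model.named_children():
        if isinstance(child, target_cls):
            setattr(model, name, new_cls(child, config, distributed=distributed))
        else:
            convert_class_sketch(child, target_cls, new_cls, config, distributed=distributed)
    return model
```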
@@ -1374,12 +1387,18 @@ def model_convert_lowering(
     cache_weight_for_large_batch=False,
 ):
     from .models.reference.modules.attentions import _IPEXAttentionRef
-    from .models.reference.modules.decoder import _IPEXDecoderLayerRef
+    from .models.reference.modules.decoder import (
+        _IPEXDecoderLayerRef,
+        _IPEXEncoderLayerRef,
+    )

     if device == "cpu":
         from .models.cpu.modules.attentions import _IPEXAttentionCPU
         from .models.cpu.fusions.mha_fusion import _IPEXRMSNormCPU
-        from .models.cpu.modules.decoder import _IPEXDecoderLayerCPU
+        from .models.cpu.modules.decoder import (
+            _IPEXDecoderLayerCPU,
+            _IPEXEncoderLayerCPU,
+        )

         _disable_tpp()
         if not is_quantization:
@@ -1479,7 +1498,15 @@ def model_convert_lowering(
                     tpp=True if _using_tpp() else False,
                     woq=woq,
                 )
-
+            for supported_mlp_class in [_IPEXEncoderLayerRef]:
+                lowering_class_cpu(
+                    _model,
+                    supported_mlp_class,
+                    _IPEXEncoderLayerCPU,
+                    _model.config,
+                    tpp=True if _using_tpp() else False,
+                    woq=woq,
+                )
             for supported_mha_class in [_IPEXAttentionRef]:
                 lowering_class_cpu(
                     _model,
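End to end, these conversions are reached through the usual `ipex.llm.optimize` entry point: `model_convert_reference` swaps `MllamaVisionEncoderLayer` for `_IPEXEncoderLayerRef`, and `model_convert_lowering` then lowers it to `_IPEXEncoderLayerCPU`. A typical invocation might look like the sketch below; the checkpoint id and dtype are placeholders, and Mllama availability depends on the installed transformers version:

```python
# Illustrative usage only: load an Mllama checkpoint and let ipex.llm.optimize
# apply the reference conversion and CPU lowering, including the new vision
# encoder fusions. The checkpoint id is a placeholder.
import torch
import intel_extension_for_pytorch as ipex
from transformers import MllamaForConditionalGeneration

model = MllamaForConditionalGeneration.from_pretrained(
    "meta-llama/Llama-3.2-11B-Vision-Instruct",  # placeholder checkpoint id
    torch_dtype=torch.bfloat16,
)
model = model.eval()
model = ipex.llm.optimize(model, dtype=torch.bfloat16, inplace=True)
```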
