@@ -14,7 +14,7 @@
 
 from transformers.models.bert.modeling_bert import BertIntermediate
 from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconModel
-from transformers.models.gpt2.modeling_gpt2 import GPT2MLP, GPT2Attention, GPT2Block, GPT2Model
+from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2Model
 from transformers.models.llama.modeling_llama import (
     LlamaDecoderLayer,
     LlamaModel,
@@ -32,13 +32,11 @@
 
 from .modeling_utils import (
     _IPEX_MINIMUM_VERSION_FOR_PATCHING,
-    _IPEXGPT2MLP,
     _falcon_model_forward,
-    _gpt2_block_forward,
     _gpt2_model_forward,
     _ipex_rms_layer_norm_forward,
     _IPEXFalconDecoderLayer,
-    _IPEXGPT2Attention,
+    _IPEXGPT2Block,
     _IPEXIntermediate,
     _IPEXLlamaDecoderLayer,
     _IPEXQwen2DecoderLayer,
@@ -66,12 +64,12 @@ def convert_functions(m, target_m, new_function_name, new_function):
         convert_functions(sub_m, target_m, new_function_name, new_function)
 
 
-def convert_class(m, target_m, new_class, config=None):
+def convert_class(m, target_m, new_class, device, config):
     for name, sub_m in m.named_children():
         if isinstance(sub_m, target_m):
-            new_m = new_class(sub_m, config)
+            new_m = new_class(sub_m, device, config)
             setattr(m, name, new_m)
-        convert_class(sub_m, target_m, new_class, config)
+        convert_class(sub_m, target_m, new_class, device, config)
 
 
 def patch_op(m, target_m, new_op_name, new_op):
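For reference, a minimal self-contained sketch of the recursive replacement pattern the reworked convert_class implements; the _Wrapped class and the toy model below are hypothetical stand-ins for illustration, not optimum-intel code:

import torch
from torch import nn


def convert_class(m, target_m, new_class, device, config):
    # Replace each child that is an instance of target_m with new_class,
    # which (after this change) also receives the model's device, then recurse.
    for name, sub_m in m.named_children():
        if isinstance(sub_m, target_m):
            setattr(m, name, new_class(sub_m, device, config))
        convert_class(sub_m, target_m, new_class, device, config)


class _Wrapped(nn.Module):
    # Hypothetical wrapper: keeps the original module, moved to `device`.
    def __init__(self, module, device, config):
        super().__init__()
        self.inner = module.to(device)

    def forward(self, x):
        return self.inner(x)


model = nn.Sequential(nn.Linear(4, 4), nn.Sequential(nn.Linear(4, 2)))
convert_class(model, nn.Linear, _Wrapped, torch.device("cpu"), config=None)
print(model)  # both nested nn.Linear children are now _Wrapped modules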
@@ -89,7 +87,7 @@ def _patch_llama_model(model):
     """
     convert_functions(model, LlamaModel, "forward", _llama_model_forward)
     convert_functions(model, LlamaRMSNorm, "forward", _ipex_rms_layer_norm_forward)
-    convert_class(model, LlamaDecoderLayer, _IPEXLlamaDecoderLayer, model.config)
+    convert_class(model, LlamaDecoderLayer, _IPEXLlamaDecoderLayer, model.device, model.config)
     return model
 
 
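The convert_functions calls in these hunks swap a module's forward in place rather than replacing its class. A sketch of that rebinding pattern, assuming the helper binds the new function as a method on each match (the real helper is defined earlier in this file; _traced_forward is a hypothetical demo function):

import torch
from torch import nn


def convert_functions(m, target_m, new_function_name, new_function):
    # Bind new_function as a method on every submodule of type target_m.
    for _, sub_m in m.named_children():
        if isinstance(sub_m, target_m):
            setattr(sub_m, new_function_name, new_function.__get__(sub_m))
        convert_functions(sub_m, target_m, new_function_name, new_function)


def _traced_forward(self, x):
    # Hypothetical replacement forward, for demonstration only.
    print(f"patched forward on {type(self).__name__}")
    return nn.Linear.forward(self, x)


model = nn.Sequential(nn.Linear(3, 3))
convert_functions(model, nn.Linear, "forward", _traced_forward)
model(torch.randn(1, 3))  # prints, then runs the original Linear forward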
@@ -105,21 +103,20 @@ def _patch_falcon_model(model):
     setattr(model.config, "num_key_value_heads", num_key_value_heads)
     convert_functions(model, FalconModel, "forward", _falcon_model_forward)
     replace_customized_linear_with_linear(model)
-    convert_class(model, FalconDecoderLayer, _IPEXFalconDecoderLayer, model.config)
+    convert_class(model, FalconDecoderLayer, _IPEXFalconDecoderLayer, model.device, model.config)
     return model
 
 
 def _patch_gpt2_model(model):
     """
     Patch gpt2 model:
     1. Use IPEX paged attention
+    2. Linear fusion with (Linear + Add)
     """
     num_key_value_heads = model.config.num_attention_heads
     setattr(model.config, "num_key_value_heads", num_key_value_heads)
     convert_functions(model, GPT2Model, "forward", _gpt2_model_forward)
-    convert_functions(model, GPT2Block, "forward", _gpt2_block_forward)
-    convert_class(model, GPT2Attention, _IPEXGPT2Attention, model.config)
-    convert_class(model, GPT2MLP, _IPEXGPT2MLP, model.config)
+    convert_class(model, GPT2Block, _IPEXGPT2Block, model.device, model.config)
    return model
 
 
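The new docstring item 2 refers to fusing a projection Linear with the residual Add, which is what replacing the separate GPT2Attention/GPT2MLP wrappers with a single block-level _IPEXGPT2Block makes straightforward: the block sees the residual stream directly. A pure-PyTorch illustration of what such a (Linear + Add) module computes; the actual fused kernel comes from intel_extension_for_pytorch, so this eager LinearAdd is only a stand-in:

import torch
from torch import nn


class LinearAdd(nn.Module):
    # Eager stand-in for a fused (Linear + residual Add): exposing both
    # steps behind one module boundary lets a backend emit a single kernel.
    def __init__(self, linear: nn.Linear):
        super().__init__()
        self.linear = linear

    def forward(self, x: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
        return self.linear(x) + residual


proj = LinearAdd(nn.Linear(8, 8))
hidden = torch.randn(2, 8)
out = proj(hidden, hidden)  # projection and residual add in one call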
@@ -131,7 +128,7 @@ def _patch_qwen2_model(model):
     """
     convert_functions(model, Qwen2Model, "forward", _qwen2_model_forward)
     convert_functions(model, Qwen2RMSNorm, "forward", _ipex_rms_layer_norm_forward)
-    convert_class(model, Qwen2DecoderLayer, _IPEXQwen2DecoderLayer, model.config)
+    convert_class(model, Qwen2DecoderLayer, _IPEXQwen2DecoderLayer, model.device, model.config)
     return model
 
 
@@ -140,7 +137,7 @@ def _patch_bert_model(model):
     Patch bert model:
     1. Linear fusion with Linear + Gelu
     """
-    convert_class(model, BertIntermediate, _IPEXIntermediate)
+    convert_class(model, BertIntermediate, _IPEXIntermediate, model.device, model.config)
     return model
 
 
@@ -149,7 +146,7 @@ def _patch_vit_model(model):
     Patch vit model:
     1. Linear fusion with Linear + Gelu
     """
-    convert_class(model, ViTIntermediate, _IPEXIntermediate)
+    convert_class(model, ViTIntermediate, _IPEXIntermediate, model.device, model.config)
     return model
 
 
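The last two hunks route BertIntermediate and ViTIntermediate through the same _IPEXIntermediate wrapper, whose purpose is a fused Linear + Gelu. An eager-mode sketch of the equivalent computation; _ToyIntermediate is a hypothetical stand-in, and the fused op in the real wrapper comes from intel_extension_for_pytorch:

import torch
from torch import nn


class _ToyIntermediate(nn.Module):
    # Hypothetical stand-in for _IPEXIntermediate: a plain-PyTorch version of
    # what a fused Linear + GELU kernel computes. Both BertIntermediate and
    # ViTIntermediate expose their projection Linear as `.dense`.
    def __init__(self, module, device, config):
        super().__init__()
        self.dense = module.dense.to(device)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return nn.functional.gelu(self.dense(hidden_states))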