@@ -32,7 +32,7 @@

 logger = logging.getLogger(__name__)

-_IPEX_MINIMUM_VERSION_FOR_PATCHING = "2.4.0"
+_IPEX_MINIMUM_VERSION_FOR_PATCHING = "2.6.0"
 _accelerate_added_attributes = ["to", "xpu"]


@@ -52,92 +52,6 @@
 )


-# TODO: Following XPULinearXXX op classes will be put into ipex after 2.6.0 version
-class XPULinear2SiluMul(torch.nn.Module):
-    def __init__(
-        self,
-        gate_proj: torch.nn.Module,
-        up_proj: torch.nn.Module,
-    ):
-        super().__init__()
-        self.gate_proj_weight = gate_proj.weight.transpose(0, 1).contiguous()
-        self.up_proj_weight = up_proj.weight.transpose(0, 1).contiguous()
-        self.gate_proj_bias = gate_proj.bias
-        self.up_proj_bias = up_proj.bias
-
-    def forward(
-        self,
-        hidden_states,
-    ):
-        up = torch.ops.torch_ipex.mm_silu(hidden_states, self.gate_proj_weight)
-        if self.gate_proj_bias is not None:
-            up += self.gate_proj_bias
-        hidden_states = torch.ops.torch_ipex.mm_resmul(hidden_states, self.up_proj_weight, up)
-        if self.up_proj_bias is not None:
-            hidden_states += self.up_proj_bias
-        return hidden_states
-
-
-class XPULinearGelu(torch.nn.Module):
-    def __init__(self, module: torch.nn.Module):
-        super().__init__()
-        self.weight = module.weight.transpose(0, 1).contiguous()
-        self.bias = module.bias
-
-    def forward(self, x):
-        return torch.ops.torch_ipex.matmul_gelu(x, self.weight, self.bias, 1.0, "tanh")
-
-
-class XPULinearAdd(torch.nn.Module):
-    def __init__(
-        self,
-        module: torch.nn.Module,
-    ):
-        super().__init__()
-        self.weight = module.weight.transpose(0, 1).contiguous()
-        self.bias = module.bias
-
-    def forward(
-        self,
-        hidden_states,
-        residual,
-    ):
-        token_len, _ = hidden_states.size()
-        if residual is None:
-            hidden_states = torch.matmul(hidden_states, self.weight)
-            if self.bias is not None:
-                hidden_states += self.bias
-        else:
-            if self.bias is not None:
-                hidden_states = torch.ops.torch_ipex.mm_bias_resadd(
-                    hidden_states, self.weight, self.bias, 1.0, residual, 1.0
-                )
-            else:
-                hidden_states = torch.addmm(
-                    residual.flatten(0, -2),
-                    hidden_states.flatten(0, -2),
-                    self.weight,
-                    beta=1.0,
-                )
-        hidden_states = hidden_states.view(token_len, -1)
-        return hidden_states
-
-
-class XPUlinearAddAdd(torch.nn.Module):
-    def __init__(self, module: torch.nn.Module):
-        super().__init__()
-        self.weight = module.weight.transpose(0, 1).contiguous()
-        self.bias = module.bias
-
-    def forward(self, x, y, z):
-        if self.bias is not None:
-            x = torch.ops.torch_ipex.mm_bias_resadd(x, self.weight, self.bias, 1.0, y, 1.0)
-            x += z
-        else:
-            x = torch.ops.torch_ipex.mm_bias_resadd(x, self.weight, z, 1.0, y, 1.0)
-        return x
-
-
 # Adapted from https://github.com/huggingface/accelerate/blob/v1.2.1/src/accelerate/hooks.py#L183
 def _remove_hooks_for_ipex(module, recurse):
     if hasattr(module, "_hf_hook"):
@@ -885,11 +799,9 @@ def __init__(self, module, device, config) -> None:
         self.q_slice = self.q_proj.weight.shape[0]
         self.k_slice = self.q_slice + self.k_proj.weight.shape[0]
         self.v_slice = self.k_slice + self.v_proj.weight.shape[0]
+
         if not config.compile and module.o_proj.__class__.__name__ not in ["LinearAllreduce"]:
-            if self.module_device.type == "cpu":
-                self.mha_linear_add = LinearAdd(module.o_proj)
-            elif self.module_device.type == "xpu":
-                self.mha_linear_add = XPULinearAdd(module.o_proj)
+            self.mha_linear_add = LinearAdd(module.o_proj)

     def qkv_gemm(self, hidden_states):
         if hasattr(self, "concat_qkv"):
@@ -940,13 +852,8 @@ def __init__(self, module, device, config) -> None:
             self.c_proj_linear = nn.Linear(self.c_proj.weight.shape[0], self.c_proj.weight.shape[1])
             self.c_proj_linear.weight = nn.Parameter(self.c_proj.weight.t())
             self.c_proj_linear.bias = self.c_proj.bias
-            if self.module_device.type == "cpu":
-                if self.c_proj_linear not in ["LinearAllreduce"]:
-                    self.linear_add = LinearAdd(self.c_proj_linear)
-
-            elif self.module_device.type == "xpu":
-                if self.c_proj_linear not in ["LinearAllreduce"]:
-                    self.linear_add = XPULinearAdd(self.c_proj_linear)
+            if self.c_proj_linear not in ["LinearAllreduce"]:
+                self.linear_add = LinearAdd(self.c_proj_linear)

     def qkv_gemm(self, hidden_states):
         if hasattr(self, "c_attn_linear"):
@@ -977,17 +884,12 @@ def __init__(self, module, device, config) -> None:
         _setattr_from_module(self, module)
         self.config = config
         self.module_device = device
+
         if not config.compile and getattr(config, "quantization_config", None) is None:
-            if self.module_device.type == "cpu":
-                # LinearAllreduce and LinearLayer cannot use fused op LinearAdd
-                if module.down_proj.__class__.__name__ not in ["LinearAllreduce"]:
-                    self.mlp_linear_add = LinearAdd(module.down_proj)
-                self.linear_silu_mul = Linear2SiluMul(module.gate_proj, module.up_proj)
-            elif self.module_device.type == "xpu":
-                # LinearAllreduce and LinearLayer cannot use fused op LinearAdd
-                if module.down_proj.__class__.__name__ not in ["LinearAllreduce"]:
-                    self.mlp_linear_add = XPULinearAdd(module.down_proj)
-                self.linear_silu_mul = XPULinear2SiluMul(module.gate_proj, module.up_proj)
+            # LinearAllreduce and LinearLayer cannot use fused op LinearAdd
+            if module.down_proj.__class__.__name__ not in ["LinearAllreduce"]:
+                self.mlp_linear_add = LinearAdd(module.down_proj)
+            self.linear_silu_mul = Linear2SiluMul(module.gate_proj, module.up_proj)

     def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor = None, **kwargs):
         if hasattr(self, "linear_silu_mul"):
@@ -1012,15 +914,10 @@ def __init__(self, module, device, config) -> None:
         self.module_device = device
         if not config.compile and getattr(config, "quantization_config", None) is None:
             # LinearAllreduce and LinearLayer cannot use fused op LinearAdd
-            if self.module_device.type == "cpu":
-                self.linear_gelu = LinearGelu(module.dense_h_to_4h)
-            elif self.module_device.type == "xpu":
-                self.linear_gelu = XPULinearGelu(module.dense_h_to_4h)
+            self.linear_gelu = LinearGelu(module.dense_h_to_4h)
+
             if module.dense_4h_to_h.__class__.__name__ not in ["LinearAllreduce"]:
-                if self.module_device.type == "cpu":
-                    self.linear_add_add = LinearAddAdd(module.dense_4h_to_h)
-                elif self.module_device.type == "xpu":
-                    self.linear_add_add = XPUlinearAddAdd(module.dense_4h_to_h)
+                self.linear_add_add = LinearAddAdd(module.dense_4h_to_h)

     def forward(
         self,
@@ -1059,11 +956,9 @@ def __init__(self, module, device, config) -> None:
             self.c_proj_linear.bias = self.c_proj.bias
             if self.module_device.type == "cpu":
                 self.linear_new_gelu = LinearNewGelu(self.c_fc_linear)
-                if self.c_proj_linear not in ["LinearAllreduce"]:
-                    self.linear_add = LinearAdd(self.c_proj_linear)
-            elif self.module_device.type == "xpu":
-                if self.c_proj_linear not in ["LinearAllreduce"]:
-                    self.linear_add = XPULinearAdd(self.c_proj_linear)
+
+            if self.c_proj_linear not in ["LinearAllreduce"]:
+                self.linear_add = LinearAdd(self.c_proj_linear)

     def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
         if hasattr(self, "linear_new_gelu"):
@@ -1232,11 +1127,9 @@ def __init__(self, module, device, config):
         super().__init__()
         _setattr_from_module(self, module)
         self.module_device = device
+
         if not config.compile and getattr(config, "quantization_config", None) is None:
-            if self.module_device.type == "cpu":
-                self.linear_gelu = LinearGelu(module.dense)
-            elif self.module_device.type == "xpu":
-                self.linear_gelu = XPULinearGelu(module.dense)
+            self.linear_gelu = LinearGelu(module.dense)

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         if hasattr(self, "linear_gelu"):
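
For reference, below the diff: a minimal sketch (not from this PR) of the linear + residual-add semantics that both the removed `XPULinearAdd` and the retained `LinearAdd` wrappers fuse into a single kernel. The `LinearAddSketch` name and the plain `nn.Linear` forward are illustrative assumptions; the real classes transpose the weight once and call IPEX ops such as `torch.ops.torch_ipex.mm_bias_resadd` instead of composing separate PyTorch ops.

```python
import torch
import torch.nn as nn


class LinearAddSketch(nn.Module):
    """Hypothetical reference module: out = Linear(x) (+ residual), i.e. the math
    the fused mm_bias_resadd / addmm paths compute in one call."""

    def __init__(self, linear: nn.Linear):
        super().__init__()
        self.linear = linear

    def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor = None) -> torch.Tensor:
        out = self.linear(hidden_states)
        return out if residual is None else out + residual


# Usage: wrap an existing projection layer and pass the residual alongside the activations.
proj = nn.Linear(64, 64)
fused = LinearAddSketch(proj)
x, res = torch.randn(4, 64), torch.randn(4, 64)
assert torch.allclose(fused(x, res), proj(x) + res)
```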