@@ -229,7 +229,7 @@ def _llama_model_forward(
     input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)
     seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
     query_len_tensor = torch.arange(seq_len_tensor.shape[0], device=device).int()
-    max_input_lens = input_lens.max().item()
+    max_input_lens = input_lens.max()

     if past_key_values_length == 0 and past_key_values is not None:
         # first token, remove the padding from hidden_states, varlen do not accept attention mask
@@ -357,7 +357,7 @@ def _falcon_model_forward(
     input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)
     seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
     query_len_tensor = torch.arange(seq_len_tensor.shape[0], device=device).int()
-    max_input_lens = input_lens.max().item()
+    max_input_lens = input_lens.max()

     if past_key_values_length == 0 and past_key_values is not None:
         # first token, remove the padding from hidden_states, varlen do not accept attention mask
@@ -499,7 +499,7 @@ def _gpt2_model_forward(
     input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)
     seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
     query_len_tensor = torch.arange(seq_len_tensor.shape[0], device=device).int()
-    max_input_lens = input_lens.max().item()
+    max_input_lens = input_lens.max()

     if past_length == 0 and past_key_values is not None:
         # first token, remove the padding from hidden_states, varlen do not accept attention mask
@@ -635,7 +635,7 @@ def _qwen2_model_forward(
     input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)
     seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
     query_len_tensor = torch.arange(seq_len_tensor.shape[0], device=device).int()
-    max_input_lens = input_lens.max().item()
+    max_input_lens = input_lens.max()

     if past_key_values_length == 0 and past_key_values is not None:
         # first token, remove the padding from hidden_states, varlen do not accept attention mask
@@ -754,11 +754,11 @@ def attention_interface(
     if past_key_value is None:
         n_rep = query.shape[1] // key.shape[1]
         attn_output = torch.nn.functional.scaled_dot_product_attention(
-            query.reshape(input_lens.shape[0], input_lens.max().item(), -1, query.shape[-1]).transpose(1, 2),
-            key.reshape(input_lens.shape[0], input_lens.max().item(), -1, key.shape[-1])
+            query.reshape(input_lens.shape[0], input_lens.max(), -1, query.shape[-1]).transpose(1, 2),
+            key.reshape(input_lens.shape[0], input_lens.max(), -1, key.shape[-1])
             .transpose(1, 2)
             .repeat_interleave(n_rep, 1),
-            value.reshape(input_lens.shape[0], input_lens.max().item(), -1, value.shape[-1])
+            value.reshape(input_lens.shape[0], input_lens.max(), -1, value.shape[-1])
             .transpose(1, 2)
             .repeat_interleave(n_rep, 1),
             attn_mask=attention_mask,
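The hunks above all make the same change: `input_lens.max()` is kept as a 0-dim tensor instead of being converted to a Python int with `.item()`, which (on non-CPU devices) forces a device-to-host sync and, under `torch.compile`, is a typical source of graph breaks. A minimal standalone sketch of the difference, not part of the patch and using made-up tensor values:

```python
import torch

# Toy padded batch: two sequences of lengths 3 and 4.
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 1, 1]])
input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)

max_as_int = input_lens.max().item()  # Python int: requires reading the value back to the host
max_as_tensor = input_lens.max()      # 0-dim tensor: stays on the device

# 0-dim integer tensors implement __index__, so they are accepted as sizes,
# which is why the reshape calls in attention_interface still work:
query = torch.randn(2, 4, 8)
reshaped = query.reshape(input_lens.shape[0], max_as_tensor, -1, query.shape[-1])
print(max_as_int, max_as_tensor, reshaped.shape)
```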
@@ -885,13 +885,11 @@ def __init__(self, module, device, config) -> None:
         self.q_slice = self.q_proj.weight.shape[0]
         self.k_slice = self.q_slice + self.k_proj.weight.shape[0]
         self.v_slice = self.k_slice + self.v_proj.weight.shape[0]
-        if self.module_device.type == "cpu":
-            if module.o_proj.__class__.__name__ not in ["LinearAllreduce"]:
+        if not config.compile and module.o_proj.__class__.__name__ not in ["LinearAllreduce"]:
+            if self.module_device.type == "cpu":
                 self.mha_linear_add = LinearAdd(module.o_proj)
-
             elif self.module_device.type == "xpu":
-                if module.o_proj.__class__.__name__ not in ["LinearAllreduce"]:
-                    self.mha_linear_add = XPULinearAdd(module.o_proj)
+                self.mha_linear_add = XPULinearAdd(module.o_proj)

     def qkv_gemm(self, hidden_states):
         if hasattr(self, "concat_qkv"):
@@ -935,7 +933,7 @@ class _IPEXGPT2Attention(_IPEXAttention):
     def __init__(self, module, device, config) -> None:
         super().__init__(module, device, config)
         _setattr_from_module(self, module)
-        if getattr(config, "quantization_config", None) is None:
+        if not config.compile and getattr(config, "quantization_config", None) is None:
             self.c_attn_linear = nn.Linear(self.c_attn.weight.shape[0], self.c_attn.weight.shape[1])
             self.c_attn_linear.weight = nn.Parameter(self.c_attn.weight.t())
             self.c_attn_linear.bias = self.c_attn.bias
@@ -979,7 +977,7 @@ def __init__(self, module, device, config) -> None:
         _setattr_from_module(self, module)
         self.config = config
         self.module_device = device
-        if getattr(config, "quantization_config", None) is None:
+        if not config.compile and getattr(config, "quantization_config", None) is None:
             if self.module_device.type == "cpu":
                 # LinearAllreduce and LinearLayer cannot use fused op LinearAdd
                 if module.down_proj.__class__.__name__ not in ["LinearAllreduce"]:
@@ -1012,7 +1010,7 @@ def __init__(self, module, device, config) -> None:
         _setattr_from_module(self, module)
         self.config = config
         self.module_device = device
-        if getattr(config, "quantization_config", None) is None:
+        if not config.compile and getattr(config, "quantization_config", None) is None:
             # LinearAllreduce and LinearLayer cannot use fused op LinearAdd
             if self.module_device.type == "cpu":
                 self.linear_gelu = LinearGelu(module.dense_h_to_4h)
@@ -1052,7 +1050,7 @@ def __init__(self, module, device, config) -> None:
         self.config = config
         self.module_device = device

-        if getattr(config, "quantization_config", None) is None:
+        if not config.compile and getattr(config, "quantization_config", None) is None:
             self.c_fc_linear = nn.Linear(self.c_fc.weight.shape[0], self.c_fc.weight.shape[1])
             self.c_fc_linear.weight = nn.Parameter(self.c_fc.weight.t())
             self.c_fc_linear.bias = self.c_fc.bias
@@ -1061,11 +1059,8 @@ def __init__(self, module, device, config) -> None:
             self.c_proj_linear.bias = self.c_proj.bias
             if self.module_device.type == "cpu":
                 self.linear_new_gelu = LinearNewGelu(self.c_fc_linear)
-
-            if self.module_device.type == "cpu":
                 if self.c_proj_linear not in ["LinearAllreduce"]:
                     self.linear_add = LinearAdd(self.c_proj_linear)
-
             elif self.module_device.type == "xpu":
                 if self.c_proj_linear not in ["LinearAllreduce"]:
                     self.linear_add = XPULinearAdd(self.c_proj_linear)
@@ -1237,7 +1232,7 @@ def __init__(self, module, device, config):
         super().__init__()
         _setattr_from_module(self, module)
         self.module_device = device
-        if getattr(config, "quantization_config", None) is None:
+        if not config.compile and getattr(config, "quantization_config", None) is None:
             if self.module_device.type == "cpu":
                 self.linear_gelu = LinearGelu(module.dense)
             elif self.module_device.type == "xpu":
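The remaining hunks all add the same `not config.compile` guard in front of the fused-module construction, so the IPEX-specific fused ops (LinearAdd, XPULinearAdd, LinearGelu, LinearNewGelu) are only built in eager mode and the stock linear path is kept when compilation is requested. A minimal sketch of that conditional-construction pattern, with a hypothetical module name standing in for the real fused ops and `use_compile` standing in for `config.compile`:

```python
import torch
from torch import nn


class GatedFusedMLP(nn.Module):
    """Hypothetical stand-in for the patched __init__ methods: build a custom
    fused module only when torch.compile is not requested, otherwise keep the
    plain layers and let the compiler do its own fusion."""

    def __init__(self, dense: nn.Linear, use_compile: bool):
        super().__init__()
        self.dense = dense
        if not use_compile:
            # Plays the role of LinearGelu / LinearAdd in the real code.
            self.fused_dense_gelu = nn.Sequential(dense, nn.GELU())

    def forward(self, hidden_states):
        # Mirrors the hasattr(...) dispatch visible in qkv_gemm above:
        # use the fused op if it was created, otherwise fall back.
        if hasattr(self, "fused_dense_gelu"):
            return self.fused_dense_gelu(hidden_states)
        return nn.functional.gelu(self.dense(hidden_states))


mlp = GatedFusedMLP(nn.Linear(16, 16), use_compile=True)
print(mlp(torch.randn(2, 16)).shape)
```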