@@ -144,19 +144,24 @@ def _llama_model_forward(
     seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
     query_len_tensor = torch.arange(seq_len_tensor.shape[0], device=device).int()
     max_input_lens = input_lens.max()
+    cos = position_embeddings[0]
+    sin = position_embeddings[1]

     if past_key_values_length == 0 and past_key_values is not None:
         # first token, remove the padding from hidden_states, varlen do not accept attention mask
         hidden_states_copy = hidden_states
         index = attention_mask.view(-1) != 0
         hidden_states = (hidden_states.view(-1, hidden_states.shape[-1]))[index]
-        cos = position_embeddings[0]
-        sin = position_embeddings[1]
         cos = (cos.reshape(-1, cos.shape[-1]))[index]
         sin = (sin.reshape(-1, sin.shape[-1]))[index]
         position_embeddings = (cos.unsqueeze(1), sin.unsqueeze(1))
     else:
         hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
+        # TODO: remove this WA after IPEX 2.7
+        if device.type == "xpu":
+            cos = cos.reshape(-1, cos.shape[-1])
+            sin = sin.reshape(-1, sin.shape[-1])
+            position_embeddings = (cos.unsqueeze(1), sin.unsqueeze(1))

     if past_key_values is None:
         attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
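The same change is repeated for the Falcon and Qwen2 forwards below: `cos` and `sin` are hoisted out of the first-token branch so the decode path can also see them, and on XPU the decode path applies the same flatten-and-unsqueeze reshape as a stop-gap until IPEX 2.7. A minimal standalone sketch of that shape handling (not the library code; the helper name, flags, and the assumed `[batch, seq_len, head_dim]` layout are illustrative):

```python
import torch

def reshape_rope_embeddings(position_embeddings, attention_mask, is_prefill, is_xpu):
    # Hypothetical helper mirroring the diff above; assumes cos/sin are [batch, seq_len, head_dim].
    cos, sin = position_embeddings
    if is_prefill:
        # First token: varlen attention takes packed tokens, so drop padded positions.
        index = attention_mask.view(-1) != 0
        cos = cos.reshape(-1, cos.shape[-1])[index]
        sin = sin.reshape(-1, sin.shape[-1])[index]
        return (cos.unsqueeze(1), sin.unsqueeze(1))  # [num_tokens, 1, head_dim]
    if is_xpu:
        # Decode path: the "TODO: remove this WA after IPEX 2.7" reshape, no padding to strip.
        cos = cos.reshape(-1, cos.shape[-1])
        sin = sin.reshape(-1, sin.shape[-1])
        return (cos.unsqueeze(1), sin.unsqueeze(1))
    return position_embeddings
```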
@@ -272,19 +277,24 @@ def _falcon_model_forward(
     seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
     query_len_tensor = torch.arange(seq_len_tensor.shape[0], device=device).int()
     max_input_lens = input_lens.max()
+    cos = position_embeddings[0]
+    sin = position_embeddings[1]

     if past_key_values_length == 0 and past_key_values is not None:
         # first token, remove the padding from hidden_states, varlen do not accept attention mask
         hidden_states_copy = hidden_states
         index = attention_mask.view(-1) != 0
         hidden_states = (hidden_states.view(-1, hidden_states.shape[-1]))[index]
-        cos = position_embeddings[0]
-        sin = position_embeddings[1]
         cos = (cos.reshape(-1, cos.shape[-1]))[index]
         sin = (sin.reshape(-1, sin.shape[-1]))[index]
         position_embeddings = (cos.unsqueeze(1), sin.unsqueeze(1))
     else:
         hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
+        # TODO: remove this WA after IPEX 2.7
+        if device.type == "xpu":
+            cos = cos.reshape(-1, cos.shape[-1])
+            sin = sin.reshape(-1, sin.shape[-1])
+            position_embeddings = (cos.unsqueeze(1), sin.unsqueeze(1))

     if past_key_values is None:
         attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
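For context, the unchanged lines at the top of each hunk build the cumulative sequence offsets that the varlen attention path uses once padding has been stripped. A small sketch with made-up lengths (dtype handling follows the context lines; the printed values are only for illustration):

```python
import torch

input_lens = torch.tensor([3, 5, 2], dtype=torch.int32)  # tokens per sequence (dummy values)
seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
query_len_tensor = torch.arange(seq_len_tensor.shape[0]).int()  # consumer is outside these hunks
max_input_lens = input_lens.max()

print(seq_len_tensor)  # tensor([ 0,  3,  8, 10], dtype=torch.int32) -> per-sequence start offsets
print(max_input_lens)  # tensor(5, dtype=torch.int32)
```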
@@ -550,19 +560,24 @@ def _qwen2_model_forward(
     seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
     query_len_tensor = torch.arange(seq_len_tensor.shape[0], device=device).int()
     max_input_lens = input_lens.max()
+    cos = position_embeddings[0]
+    sin = position_embeddings[1]

     if past_key_values_length == 0 and past_key_values is not None:
         # first token, remove the padding from hidden_states, varlen do not accept attention mask
         hidden_states_copy = hidden_states
         index = attention_mask.view(-1) != 0
         hidden_states = (hidden_states.view(-1, hidden_states.shape[-1]))[index]
-        cos = position_embeddings[0]
-        sin = position_embeddings[1]
         cos = (cos.reshape(-1, cos.shape[-1]))[index]
         sin = (sin.reshape(-1, sin.shape[-1]))[index]
         position_embeddings = (cos.unsqueeze(1), sin.unsqueeze(1))
     else:
         hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
+        # TODO: remove this WA after IPEX 2.7
+        if device.type == "xpu":
+            cos = cos.reshape(-1, cos.shape[-1])
+            sin = sin.reshape(-1, sin.shape[-1])
+            position_embeddings = (cos.unsqueeze(1), sin.unsqueeze(1))

     if past_key_values is None:
         attention_mask = causal_mask
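As a quick sanity check of the prefill unpadding kept inside the first-token branch, the index trick can be exercised with dummy tensors (all dimensions here are made up):

```python
import torch

batch, seq_len, head_dim = 2, 4, 64           # made-up dimensions
cos = torch.randn(batch, seq_len, head_dim)
sin = torch.randn(batch, seq_len, head_dim)
attention_mask = torch.tensor([[1, 1, 1, 1],  # first sequence: 4 real tokens
                               [1, 1, 0, 0]]) # second sequence: 2 real + 2 padding

index = attention_mask.view(-1) != 0
cos_packed = cos.reshape(-1, cos.shape[-1])[index].unsqueeze(1)
sin_packed = sin.reshape(-1, sin.shape[-1])[index].unsqueeze(1)
print(cos_packed.shape)  # torch.Size([6, 1, 64]) -> padding removed, singleton dim restored
```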