
Commit ab2133b

compat transformers 4.56 (#5666)
1 parent bee6771 commit ab2133b

2 files changed (+9 -6 lines)


requirements/framework.txt

Lines changed: 1 addition & 1 deletion
@@ -32,7 +32,7 @@ sortedcontainers>=1.5.9
 tensorboard
 tiktoken
 tqdm
-transformers>=4.33,<4.56
+transformers>=4.33,<4.57
 transformers_stream_generator
 trl>=0.15,<0.21
 uvicorn
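The only change here is relaxing the upper bound so that transformers 4.56.x is accepted. A minimal sanity check of an installed environment against the new range, using the same `packaging.version` idiom the template code below relies on (this snippet is an illustrative sketch, not part of the commit):

```python
# Sketch: confirm the installed transformers satisfies the new pin
# from requirements/framework.txt (>=4.33,<4.57).
from packaging import version
import transformers

installed = version.parse(transformers.__version__)
assert version.parse('4.33') <= installed < version.parse('4.57'), (
    f'transformers {transformers.__version__} is outside the supported range >=4.33,<4.57')
```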

swift/llm/template/template/qwen.py

Lines changed: 8 additions & 5 deletions
@@ -311,19 +311,22 @@ def _get_new_tokens(i):
     def forward_context(self, model, inputs):
         if 'real_position_ids' not in inputs:
             return super().forward_context(model, inputs)
-        position_ids = inputs['position_ids']
+        text_position_ids = inputs['position_ids']
         inputs['position_ids'] = inputs.pop('real_position_ids')
-        transformers_ge_453 = version.parse(transformers.__version__) >= version.parse('4.53')
-        if transformers_ge_453:
-            inputs.update(get_packed_seq_params(position_ids))
+        transformers_version = version.parse(transformers.__version__)
+        if transformers_version >= version.parse('4.53'):
+            if transformers_version >= version.parse('4.56'):
+                inputs['position_ids'] = torch.concat([text_position_ids[None], inputs['position_ids']], dim=0)
+            else:
+                inputs.update(get_packed_seq_params(text_position_ids))
             return super().forward_context(model, inputs)
         if self.version == 'v2':
             from transformers.models.qwen2_vl import modeling_qwen2_vl as modeling_module
         elif self.version == 'v2_5':
             from transformers.models.qwen2_5_vl import modeling_qwen2_5_vl as modeling_module
         elif self.version == 'omni':
             from transformers.models.qwen2_5_omni import modeling_qwen2_5_omni as modeling_module
-        return self._patch_flash_attention_forward(modeling_module, position_ids)
+        return self._patch_flash_attention_forward(modeling_module, text_position_ids)

     def _post_encode(self, model, inputs: Dict[str, Any]) -> Dict[str, Any]:
         if not self.is_training:
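The rewritten branch keeps the packed-sequence handling for transformers 4.53–4.55 (`get_packed_seq_params`), but on 4.56+ it instead prepends the text position ids as an extra leading row of the multimodal `position_ids` tensor. A minimal shape sketch of that `torch.concat` call, assuming the mrope `position_ids` for Qwen2-VL-style models has shape `(3, batch, seq_len)` and `text_position_ids` has shape `(batch, seq_len)`; the shapes here are hypothetical illustrations, not taken from the repo:

```python
import torch

# Hypothetical shapes for illustration only (not taken from the commit):
batch, seq_len = 1, 8
text_position_ids = torch.arange(seq_len).repeat(batch, 1)        # (batch, seq_len)
mrope_position_ids = torch.arange(seq_len).repeat(3, batch, 1)    # (3, batch, seq_len)

# On transformers >= 4.56 the text positions are prepended as an extra leading
# row, so a single (4, batch, seq_len) tensor is passed downstream instead of
# the separate packed-seq-params dict used on 4.53-4.55.
combined = torch.concat([text_position_ids[None], mrope_position_ids], dim=0)
print(combined.shape)  # torch.Size([4, 1, 8])
```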
