9 | 9 | import transformers
10 | 10 | from packaging import version
11 | 11 | from torch import nn
 | 12 | +from transformers.integrations import is_deepspeed_zero3_enabled
12 | 13 |
13 | 14 | from swift.llm import get_packed_seq_params, to_device, to_float_dtype
14 | 15 | from swift.utils import get_env_args, is_deepspeed_enabled
@@ -592,7 +593,80 @@ def _get_new_tokens(i):
592 | 593 |         return encoded
593 | 594 |
594 | 595 |     def _post_encode(self, model, inputs: Dict[str, Any]) -> Dict[str, Any]:
595 |  | -        return Template._post_encode(self, model, inputs)
 | 596 | +        if not self.is_training:
 | 597 | +            return inputs
 | 598 | +
 | 599 | +        input_ids = inputs['input_ids']
 | 600 | +        pixel_values = inputs.get('pixel_values')
 | 601 | +        pixel_values_videos = inputs.get('pixel_values_videos')
 | 602 | +        image_grid_thw = inputs.get('image_grid_thw')
 | 603 | +        video_grid_thw = inputs.get('video_grid_thw')
 | 604 | +        input_features = inputs.get('input_features')
 | 605 | +        feature_attention_mask = inputs.get('feature_attention_mask')
 | 606 | +
 | 607 | +        base_model = self.get_base_model(model)
 | 608 | +        inputs_embeds = base_model.thinker.model.embed_tokens(input_ids)
 | 609 | +        visual = model.thinker.visual
 | 610 | +        dtype = visual.dtype
 | 611 | +        thinker_config = model.config.thinker_config
 | 612 | +        if pixel_values is None and pixel_values_videos is None:  # plain-text
 | 613 | +            if is_deepspeed_enabled():
 | 614 | +                from PIL import Image
 | 615 | +                images = [Image.new('RGB', (32, 32), (0, 0, 0))]
 | 616 | +                media_inputs = self.processor.image_processor(images=images, return_tensors='pt')
 | 617 | +                device = input_ids.device
 | 618 | +                media_inputs = to_device(media_inputs, device)
 | 619 | +                pixel_values = media_inputs['pixel_values'].type(dtype)
 | 620 | +                image_embeds = visual(pixel_values, grid_thw=media_inputs['image_grid_thw'])
 | 621 | +                inputs_embeds = inputs_embeds + image_embeds.mean() * 0.
 | 622 | +        else:
 | 623 | +            if pixel_values is None:
 | 624 | +                pixel_values_mixed = pixel_values_videos
 | 625 | +                grid_thw = video_grid_thw
 | 626 | +            elif pixel_values_videos is None:
 | 627 | +                pixel_values_mixed = pixel_values
 | 628 | +                grid_thw = image_grid_thw
 | 629 | +            else:
 | 630 | +                pixel_values_mixed = torch.concat([pixel_values, pixel_values_videos], dim=0)
 | 631 | +                grid_thw = torch.concat([image_grid_thw, video_grid_thw], dim=0)
 | 632 | +            pixel_values_mixed = pixel_values_mixed.type(dtype)
 | 633 | +            mixed_embeds = visual(pixel_values_mixed, grid_thw=grid_thw)
 | 634 | +            if pixel_values is None:
 | 635 | +                image_embeds = None
 | 636 | +                video_embeds = mixed_embeds
 | 637 | +            elif pixel_values_videos is None:
 | 638 | +                image_embeds = mixed_embeds
 | 639 | +                video_embeds = None
 | 640 | +            else:
 | 641 | +                merge_length = self.processor.image_processor.merge_size**2
 | 642 | +                image_tokens = (image_grid_thw.prod(dim=-1) // merge_length).sum()
 | 643 | +                image_embeds = mixed_embeds[:image_tokens]
 | 644 | +                video_embeds = mixed_embeds[image_tokens:]
 | 645 | +
 | 646 | +            if image_embeds is not None:
 | 647 | +                image_mask = (input_ids == thinker_config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
 | 648 | +                image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
 | 649 | +                inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
 | 650 | +
 | 651 | +            if video_embeds is not None:
 | 652 | +                video_mask = (input_ids == thinker_config.video_token_index).unsqueeze(-1).expand_as(inputs_embeds)
 | 653 | +                video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
 | 654 | +                inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
 | 655 | +
 | 656 | +        if input_features is None:
 | 657 | +            if is_deepspeed_enabled() and not is_deepspeed_zero3_enabled():
 | 658 | +                # Note: ZeRO-3 still results in hangs; for audio training, please use ZeRO-2.
 | 659 | +                input_features = input_ids.new_zeros([1, 128, 128], dtype=dtype)
 | 660 | +                feature_attention_mask = input_ids.new_ones([1, 128], dtype=torch.bool)
 | 661 | +                audio_embeds = model.thinker.get_audio_features(input_features, feature_attention_mask)
 | 662 | +                inputs_embeds = inputs_embeds + audio_embeds.mean() * 0.
 | 663 | +        else:
 | 664 | +            audio_embeds = model.thinker.get_audio_features(input_features, feature_attention_mask)
 | 665 | +            audio_mask = (input_ids == thinker_config.audio_token_index).unsqueeze(-1).expand_as(inputs_embeds)
 | 666 | +            audio_embeds = audio_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
 | 667 | +            inputs_embeds = inputs_embeds.masked_scatter(audio_mask, audio_embeds)
 | 668 | +
 | 669 | +        return {'inputs_embeds': inputs_embeds}
596 | 670 |
597 | 671 |     def _get_position_ids(self, inputs: Dict[str, Any]):
598 | 672 |         feature_attention_mask = inputs.get('feature_attention_mask')
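Note on the splicing step: the masked_scatter calls in the added _post_encode overwrite the embedding rows at the image/video/audio placeholder positions with the corresponding modality embeddings. Below is a minimal, self-contained sketch of that pattern; the token id, shapes, and values are illustrative only and are not taken from this diff.

    import torch

    image_token_id = 151655  # hypothetical placeholder id; the real one comes from thinker_config
    input_ids = torch.tensor([[101, image_token_id, image_token_id, 102]])
    inputs_embeds = torch.zeros(1, 4, 8)   # [batch, seq_len, hidden]
    image_embeds = torch.randn(2, 8)       # one row per placeholder token, in order

    # Boolean mask that is True at placeholder positions, broadcast over the hidden dim.
    image_mask = (input_ids == image_token_id).unsqueeze(-1).expand_as(inputs_embeds)

    # masked_scatter fills the True positions, element by element, from image_embeds,
    # so the number of placeholder tokens must match image_embeds.shape[0].
    inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds.to(inputs_embeds.dtype))

    print(inputs_embeds[0, 1])  # now equals image_embeds[0]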
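The `... .mean() * 0.` additions are a zero-weight dummy forward: when a batch contains no images (or no audio), each rank still pushes a dummy input through the corresponding tower so its parameters stay in the autograd graph and DeepSpeed's collective gradient ops do not get out of step across ranks. A minimal sketch of the trick with toy modules (embed_tokens and vision_tower are stand-ins, not the real model):

    import torch
    from torch import nn

    embed_tokens = nn.Embedding(10, 4)   # toy text embedding
    vision_tower = nn.Linear(4, 4)       # toy stand-in for the visual encoder

    input_ids = torch.tensor([[1, 2, 3]])
    inputs_embeds = embed_tokens(input_ids)

    # No image in this batch: run the tower on a dummy input anyway and add the
    # result with weight 0. The values are unchanged, but the tower's parameters
    # enter the autograd graph, so this rank still produces (zero) gradients for
    # them and stays in sync with ranks whose batches did contain images.
    dummy = torch.zeros(1, 4)
    inputs_embeds = inputs_embeds + vision_tower(dummy).mean() * 0.

    inputs_embeds.sum().backward()
    print(vision_tower.weight.grad.abs().sum())  # tensor(0.): gradient exists but is zero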