
Commit 9f67221

[template] support glm4_5v packing mixed_data (#5674)
1 parent 0e0fc6a commit 9f67221

11 files changed, +171 -318 lines changed


swift/llm/model/model/glm.py

Lines changed: 4 additions & 1 deletion
@@ -443,7 +443,10 @@ def get_model_tokenizer_glm_edge_v(model_dir: str, *args, **kwargs):
 def get_model_tokenizer_glm4_5v(*args, **kwargs):
     from transformers import Glm4vMoeForConditionalGeneration
     kwargs['automodel_class'] = kwargs['automodel_class'] or Glm4vMoeForConditionalGeneration
-    return get_model_tokenizer_multimodal(*args, **kwargs)
+    model, processor = get_model_tokenizer_multimodal(*args, **kwargs)
+    if model is not None:
+        patch_get_input_embeddings(model.visual, 'patch_embed')
+    return model, processor
 
 
 register_model(
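Note: patch_get_input_embeddings itself is not part of this diff. A minimal sketch of what such a patch typically does, assuming its role is to expose the vision tower's 'patch_embed' as that module's input-embedding layer (for example so input gradients can be enabled on the visual branch during gradient checkpointing); the real implementation in swift/llm/model/patcher.py may differ:

import types

import torch.nn as nn


def patch_get_input_embeddings_sketch(module: nn.Module, attr_name: str) -> None:
    # Hypothetical stand-in for swift's patch_get_input_embeddings: attach a
    # get_input_embeddings() method that returns the named submodule (here the
    # vision tower's 'patch_embed'), so generic utilities that look up input
    # embeddings can find it on the visual module as well.
    def get_input_embeddings(self):
        return getattr(self, attr_name)

    module.get_input_embeddings = types.MethodType(get_input_embeddings, module)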

swift/llm/model/model/qwen.py

Lines changed: 1 addition & 7 deletions
@@ -11,7 +11,7 @@
 from swift.utils import get_device_count, get_dist_setting, get_env_args, get_logger
 from ..constant import LLMModelType, MLLMModelType, RMModelType
 from ..model_arch import ModelArch
-from ..patcher import patch_fixed_device, patch_get_input_embeddings, patch_output_clone, patch_output_to_input_device
+from ..patcher import patch_fixed_device, patch_get_input_embeddings, patch_output_clone
 from ..register import (Model, ModelGroup, ModelMeta, get_model_tokenizer_multimodal, get_model_tokenizer_reward_model,
                         get_model_tokenizer_with_flash_attn, register_model)
 from ..utils import AttnImpl, ModelInfo, use_submodel_func
@@ -654,12 +654,6 @@ def get_model_tokenizer_qwen2_vl(*args, **kwargs):
     model, tokenizer = get_model_tokenizer_multimodal(*args, **kwargs)
     if model is not None:
         base_model = model.model if 'AWQ' in model.__class__.__name__ else model
-        if hasattr(base_model.model, 'embed_tokens'):
-            embed_tokens = base_model.model.embed_tokens
-        else:
-            embed_tokens = base_model.model.language_model.embed_tokens
-        patch_output_clone(embed_tokens)
-        patch_output_to_input_device(embed_tokens)
         patch_get_input_embeddings(base_model.visual, 'patch_embed')
 
     from qwen_vl_utils import vision_process

swift/llm/template/base.py

Lines changed: 70 additions & 18 deletions
@@ -21,6 +21,7 @@
 from transformers.integrations import is_deepspeed_zero3_enabled
 from transformers.utils import strtobool
 
+from swift.llm import to_device
 from swift.utils import get_env_args, get_logger
 from ..utils import Processor, ProcessorMixin
 from .template_inputs import InferRequest, StdTemplateInputs, TemplateInputs
@@ -1349,13 +1350,12 @@ def post_process_generate_response(self, response: str, inputs: StdTemplateInput
         return response
 
     def pre_forward_hook(self, model: nn.Module, args, kwargs):
-        from swift.llm import to_device
         old_kwargs = to_device(kwargs, model.device)
         kwargs = to_device(self._post_encode(model, old_kwargs), model.device)
         for k, v in old_kwargs.items():
             if k in {
                     'input_ids', 'attention_mask', 'labels', 'position_ids', 'output_hidden_states', 'logits_to_keep',
-                    'cumulative_seqlens_q', 'cumulative_seqlens_k', 'max_length_q', 'max_length_k'
+                    'max_length_q', 'max_length_k', 'cu_seq_lens_q', 'cu_seq_lens_k'
             } and k not in kwargs:
                 kwargs[k] = v
         if 'inputs_embeds' in kwargs:
@@ -1629,7 +1629,7 @@ def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[in
         res = {}
         if self.padding_free:
             assert len(batch) == 1, f'batch: {batch}'
-            for k in ['input_ids', 'labels', 'position_ids', 'loss_scale', 'channel', 'real_position_ids']:
+            for k in ['input_ids', 'labels', 'position_ids', 'loss_scale', 'channel']:
                 v = batch[0].get(k)
                 if v is not None:
                     res[k] = v if k == 'channel' else [v]
@@ -1651,10 +1651,15 @@ def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[in
             res[key] = val
 
         keys = [
-            'input_ids', 'inputs_embeds', 'attention_mask', 'labels', 'loss_scale', 'position_ids', 'token_type_ids',
-            'real_position_ids'
+            'input_ids',
+            'inputs_embeds',
+            'attention_mask',
+            'labels',
+            'loss_scale',
+            'position_ids',
+            'token_type_ids',
         ]
-        pad_values = [self.tokenizer.pad_token_id, 0., 0, -100, 0., 0., 0, 0.]
+        pad_values = [self.tokenizer.pad_token_id, 0., 0, -100, 0., 0., 0]
         # Convert to tensor and remove unnecessary dimensions.
         seq_lens = None
         for key in keys:
@@ -1681,16 +1686,13 @@ def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[in
         if self.padding_free:
             cp_size = self.sequence_parallel_size
             if cp_size > 1:
-                for key in ['position_ids', 'real_position_ids']:
-                    if key not in res:
-                        continue
-                    padding_len = padding_to - seq_lens[0]
-                    position_ids = res[key][0]
-                    extended_position_ids = torch.arange(cp_size * 2).repeat(padding_len // (cp_size * 2))
-                    if position_ids.ndim == 3:  # compat mrope
-                        extended_position_ids = extended_position_ids[None,
-                                                                      None, :].expand(position_ids.shape[0], 1, -1)
-                    res[key] = [torch.concat([position_ids, extended_position_ids], dim=-1)]
+                padding_len = padding_to - seq_lens[0]
+                position_ids = res['position_ids'][0]
+                extended_position_ids = torch.arange(cp_size * 2).repeat(padding_len // (cp_size * 2))
+                if position_ids.ndim == 3:  # compat mrope
+                    extended_position_ids = extended_position_ids[None,
+                                                                  None, :].expand(position_ids.shape[0], 1, -1)
+                res['position_ids'] = [torch.concat([position_ids, extended_position_ids], dim=-1)]
         else:
             seq_len = max(seq_lens) if padding_to is None else padding_to
             res['attention_mask'] = torch.tril(torch.ones(
@@ -1704,13 +1706,13 @@ def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[in
                 continue
             if self.use_megatron and not self.padding_free and key == 'attention_mask':
                 continue
-            if padding_to is not None and not (self.padding_free and key in {'position_ids', 'real_position_ids'}
+            if padding_to is not None and not (self.padding_free and key == 'position_ids'
                                                and self.sequence_parallel_size > 1):
                 padding_len = padding_to - seq_lens[0]
                 if padding_len > 0:
                     res[key][0] = F.pad(res[key][0], (0, padding_len) if padding_right else (padding_len, 0),
                                         'constant', pad_value)
-            if key == 'real_position_ids':
+            if key == 'position_ids' and res[key][0].ndim == 3:
                 res[key] = torch.concat(res[key], dim=-1)
             else:
                 res[key] = self._pad_sequence(res[key], pad_value)
@@ -1951,3 +1953,53 @@ def _flash_attention_forward(*args, **kwargs):
             yield
         finally:
             modeling_module._flash_attention_forward = _origin_flash_attention_forward
+
+    @staticmethod
+    def _get_inputs_embeds_hf(inputs_embeds, inputs, visual, processor, config):
+        input_ids = inputs['input_ids']
+        pixel_values = inputs.get('pixel_values')
+        pixel_values_videos = inputs.get('pixel_values_videos')
+        image_grid_thw = inputs.get('image_grid_thw')
+        video_grid_thw = inputs.get('video_grid_thw')
+        dtype = visual.dtype
+        if pixel_values is None and pixel_values_videos is None:  # plain-text
+            images = [Image.new('RGB', (32, 32), (0, 0, 0))]
+            media_inputs = processor.image_processor(images=images, return_tensors='pt')
+            media_inputs = to_device(media_inputs, input_ids.device)
+            pixel_values = media_inputs['pixel_values'].type(dtype)
+            image_embeds = visual(pixel_values, grid_thw=media_inputs['image_grid_thw'])
+            inputs_embeds = inputs_embeds + image_embeds.mean() * 0.
+        else:
+            if pixel_values is None:
+                pixel_values_mixed = pixel_values_videos
+                grid_thw = video_grid_thw
+            elif pixel_values_videos is None:
+                pixel_values_mixed = pixel_values
+                grid_thw = image_grid_thw
+            else:
+                pixel_values_mixed = torch.concat([pixel_values, pixel_values_videos], dim=0)
+                grid_thw = torch.concat([image_grid_thw, video_grid_thw], dim=0)
+            pixel_values_mixed = pixel_values_mixed.type(dtype)
+            mixed_embeds = visual(pixel_values_mixed, grid_thw=grid_thw)
+            if pixel_values is None:
+                image_embeds = None
+                video_embeds = mixed_embeds
+            elif pixel_values_videos is None:
+                image_embeds = mixed_embeds
+                video_embeds = None
+            else:
+                merge_length = processor.image_processor.merge_size**2
+                image_tokens = (image_grid_thw.prod(dim=-1) // merge_length).sum()
+                image_embeds = mixed_embeds[:image_tokens]
+                video_embeds = mixed_embeds[image_tokens:]
+
+            if image_embeds is not None:
+                image_mask = (input_ids == config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
+                image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
+                inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
+
+            if video_embeds is not None:
+                video_mask = (input_ids == config.video_token_id).unsqueeze(-1).expand_as(inputs_embeds)
+                video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
+                inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
+        return inputs_embeds
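The new _get_inputs_embeds_hf helper runs the vision tower once over the concatenated image and video pixel values, splits the output at the image-token count (grid_thw.prod(dim=-1) // merge_size**2), and splices each part into the text embeddings at the placeholder-token positions via masked_scatter. A toy, self-contained illustration of that splicing step (the token id and tensor shapes below are made up for the example):

import torch

IMAGE_TOKEN_ID = 9                                # hypothetical placeholder id
input_ids = torch.tensor([[1, 9, 9, 2]])          # two image placeholder tokens
inputs_embeds = torch.zeros(1, 4, 8)              # (batch, seq_len, hidden)
image_embeds = torch.randn(2, 8)                  # one vision feature per placeholder

# Expand the boolean placeholder mask over the hidden dim, then scatter the
# vision features into those positions, in order.
image_mask = (input_ids == IMAGE_TOKEN_ID).unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds.to(inputs_embeds.dtype))
assert torch.allclose(inputs_embeds[0, 1:3], image_embeds)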

swift/llm/template/template/glm.py

Lines changed: 47 additions & 52 deletions
@@ -4,8 +4,7 @@
 
 import torch
 
-from swift.llm import to_device
-from swift.utils import is_deepspeed_enabled
+from swift.llm import get_packed_seq_params
 from ..base import Template
 from ..constant import LLMTemplateType, MLLMTemplateType
 from ..register import TemplateMeta, register_template
@@ -234,57 +233,8 @@ def _post_encode(self, model, inputs: Dict[str, Any]) -> Dict[str, Any]:
         if not self.is_training:
             return inputs
         input_ids = inputs['input_ids']
-        pixel_values = inputs.get('pixel_values')
-        pixel_values_videos = inputs.get('pixel_values_videos')
-        image_grid_thw = inputs.get('image_grid_thw')
-        video_grid_thw = inputs.get('video_grid_thw')
-
         inputs_embeds = model.get_input_embeddings()(input_ids)
-        dtype = model.visual.dtype
-        if pixel_values is None and pixel_values_videos is None:  # plain-text
-            if is_deepspeed_enabled():
-                from PIL import Image
-                images = [Image.new('RGB', (32, 32), (0, 0, 0))]
-                media_inputs = self.processor.image_processor(images=images, return_tensors='pt')
-                device = input_ids.device
-                media_inputs = to_device(media_inputs, device)
-                pixel_values = media_inputs['pixel_values'].type(dtype)
-                image_embeds = model.visual(pixel_values, grid_thw=media_inputs['image_grid_thw'])
-                inputs_embeds += image_embeds.mean() * 0.
-        else:
-            if pixel_values is None:
-                pixel_values_mixed = pixel_values_videos
-                grid_thw = video_grid_thw
-            elif pixel_values_videos is None:
-                pixel_values_mixed = pixel_values
-                grid_thw = image_grid_thw
-            else:
-                pixel_values_mixed = torch.concat([pixel_values, pixel_values_videos], dim=0)
-                grid_thw = torch.concat([image_grid_thw, video_grid_thw], dim=0)
-            pixel_values_mixed = pixel_values_mixed.type(dtype)
-            mixed_embeds = model.visual(pixel_values_mixed, grid_thw=grid_thw)
-            if pixel_values is None:
-                image_embeds = None
-                video_embeds = mixed_embeds
-            elif pixel_values_videos is None:
-                image_embeds = mixed_embeds
-                video_embeds = None
-            else:
-                merge_length = self.processor.image_processor.merge_size**2
-                image_tokens = (image_grid_thw.prod(dim=-1) // merge_length).sum()
-                image_embeds = mixed_embeds[:image_tokens]
-                video_embeds = mixed_embeds[image_tokens:]
-
-            if image_embeds is not None:
-                image_mask = (input_ids == model.config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
-                image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
-                inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
-
-            if video_embeds is not None:
-                video_mask = (input_ids == model.config.video_token_id).unsqueeze(-1).expand_as(inputs_embeds)
-                video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
-                inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
-
+        inputs_embeds = self._get_inputs_embeds_hf(inputs_embeds, inputs, model.visual, self.processor, model.config)
         return {'inputs_embeds': inputs_embeds}
 
 
@@ -314,6 +264,8 @@ def _jinja_encode(self, inputs: StdTemplateInputs):
 
 class GLM4_5VTemplate(Template):
     placeholder_tokens = ['<|image|>']
+    support_padding_free = True  # https://github.com/huggingface/transformers/issues/39685
+    use_model = True
 
     def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
                     inputs: StdTemplateInputs) -> List[Context]:
@@ -348,6 +300,49 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
         encoded['input_ids'] = input_ids
         return encoded
 
+    def packing_row(self, row: List[Dict[str, Any]]) -> Dict[str, Any]:
+        position_ids = []
+        for r in row:
+            r = r.copy()
+            r['input_ids'] = torch.tensor(r['input_ids'])[None]
+            position_ids.append(self._get_position_ids(r))
+        packed = super().packing_row(row)
+        packed['position_ids'] = torch.concat(position_ids, dim=-1)
+        return packed
+
+    def _get_position_ids(self, inputs: Dict[str, Any]):
+        base_model = self.get_base_model(self.model)
+        position_ids, _ = base_model.model.get_rope_index(
+            inputs['input_ids'],
+            inputs.get('image_grid_thw'),
+            inputs.get('video_grid_thw'),
+            attention_mask=inputs.get('attention_mask'))
+        text_position_ids = torch.arange(inputs['input_ids'].shape[-1])
+        return torch.concat([text_position_ids[None, None], position_ids], dim=0)
+
+    def forward_context(self, model, inputs):
+        position_ids = inputs['position_ids']
+        inputs['position_ids'] = position_ids[1:]
+        inputs['text_position_ids'] = position_ids[0]
+        # https://github.com/huggingface/transformers/pull/40194
+        inputs.update(get_packed_seq_params(inputs['text_position_ids']))
+        return super().forward_context(model, inputs)
+
+    def _post_encode(self, model, inputs: Dict[str, Any]) -> Dict[str, Any]:
+        if not self.is_training:
+            return inputs
+        input_ids = inputs['input_ids']
+        base_model = self.get_base_model(model)
+        inputs_embeds = base_model.model.language_model.embed_tokens(input_ids)
+        inputs_embeds = self._get_inputs_embeds_hf(inputs_embeds, inputs, model.visual, self.processor, model.config)
+        return {'inputs_embeds': inputs_embeds}
+
+    def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
+        res = super()._data_collator(batch, padding_to=padding_to)
+        if not self.padding_free and self.is_training:
+            res['position_ids'] = self._get_position_ids(res)
+        return res
+
 
 register_template(GLM4_0414TemplateMeta(MLLMTemplateType.glm4_5v, template_cls=GLM4_5VTemplate))
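For packing, GLM4_5VTemplate.packing_row stacks a plain text position row on top of the 3-row mrope position_ids returned by get_rope_index; forward_context later splits them again (position_ids[1:] for mrope, position_ids[0] as text_position_ids) and passes get_packed_seq_params(text_position_ids) to the model, which is why pre_forward_hook in base.py now forwards cu_seq_lens_q, cu_seq_lens_k, max_length_q and max_length_k. A hedged sketch of how such packed-sequence params can be derived from text position ids (the real swift.llm.get_packed_seq_params may differ in details):

import torch


def packed_seq_params_sketch(text_position_ids: torch.Tensor) -> dict:
    # Hypothetical re-derivation: text position ids restart at 0 for every
    # packed sample, so the zeros mark sample boundaries; from them we build
    # the cumulative sequence lengths that flash-attention varlen kernels expect.
    pos = text_position_ids.flatten()
    starts = torch.nonzero(pos == 0).flatten()
    cu_seq_lens = torch.cat([starts, torch.tensor([pos.numel()])]).to(torch.int32)
    max_length = int((cu_seq_lens[1:] - cu_seq_lens[:-1]).max())
    return {
        'cu_seq_lens_q': cu_seq_lens,
        'cu_seq_lens_k': cu_seq_lens,
        'max_length_q': max_length,
        'max_length_k': max_length,
    }


# Two packed samples of lengths 3 and 2:
print(packed_seq_params_sketch(torch.tensor([0, 1, 2, 0, 1])))
# -> cu_seq_lens_q = tensor([0, 3, 5], dtype=torch.int32), max_length_q = 3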
