Commit da336a3

fix batch_infer pad_token & florence (#2725)
1 parent 07e16a9 commit da336a3

6 files changed, +33 -24 lines

swift/llm/infer/infer_engine/utils.py

Lines changed: 1 addition & 2 deletions
@@ -153,6 +153,5 @@ def prepare_generation_config(model_generation_config: GenerationConfig, request
 
     if generation_config.eos_token_id is None:
         generation_config.eos_token_id = tokenizer.eos_token_id
-    if generation_config.pad_token_id is None:
-        generation_config.pad_token_id = tokenizer.pad_token_id
+    generation_config.pad_token_id = tokenizer.pad_token_id
     return generation_config
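
The apparent intent: batched inference pads prompts to a common length, and generate() relies on pad_token_id to handle that padding, so the value must match the pad token the tokenizer actually used. Taking it unconditionally from the tokenizer avoids a stale or mismatched value left in the model's generation config. A minimal standalone sketch of the requirement (gpt2 is illustrative only, not part of this commit):

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('gpt2')
model = AutoModelForCausalLM.from_pretrained('gpt2')
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token  # gpt2 ships without a pad token
tokenizer.padding_side = 'left'  # left-pad so generation continues from real tokens
batch = tokenizer(['hi', 'a much longer prompt'], padding=True, return_tensors='pt')
out = model.generate(**batch, max_new_tokens=8, pad_token_id=tokenizer.pad_token_id)
print(tokenizer.batch_decode(out, skip_special_tokens=True))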

swift/llm/model/model/microsoft.py

Lines changed: 7 additions & 2 deletions
@@ -3,6 +3,8 @@
 from types import MethodType
 from typing import Any, Dict
 
+from transformers import AutoConfig
+
 from swift.llm import TemplateType
 from swift.utils import get_env_args
 from ..constant import LLMModelType, MLLMModelType
@@ -55,9 +57,12 @@ def get_model_tokenizer_florence(model_dir: str,
                                  model_kwargs: Dict[str, Any],
                                  load_model: bool = True,
                                  **kwargs):
+    model_config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
+    model_config.vision_config.model_type = 'davit'  # fix merge-lora
+    if model_kwargs['device_map'] == 'auto':
+        model_kwargs['device_map'] = 'cuda:0'
+    kwargs['model_config'] = model_config
     with ignore_check_imports():
-        if model_kwargs['device_map'] == 'auto':
-            model_kwargs['device_map'] = 'cuda:0'
         model, processor = get_model_tokenizer_multimodal(model_dir, model_info, model_kwargs, load_model, **kwargs)
 
     if model is not None:
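
The hunk above patches the config before model construction: vision_config.model_type is reset to 'davit' (per its comment, to fix merge-lora reloads) and the patched config is passed down via kwargs['model_config']. A hedged sketch of the same pattern against plain transformers (the checkpoint id is illustrative, not taken from this commit):

from transformers import AutoConfig, AutoModelForCausalLM

model_dir = 'microsoft/Florence-2-base-ft'  # illustrative checkpoint
config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
config.vision_config.model_type = 'davit'  # assumption: restores the type dropped by merge-lora
model = AutoModelForCausalLM.from_pretrained(model_dir, config=config, trust_remote_code=True)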

swift/llm/template/base.py

Lines changed: 3 additions & 4 deletions
@@ -607,10 +607,9 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
         for key, _slice in zip(['prompt', 'answer'],
                                [slice(0, total_len - answer_len),
                                 slice(total_len - answer_len, total_len)]):
-            res_context_list, loss_scale_list = self._simplify_context_list(res_context_list[_slice],
-                                                                            loss_scale_list[_slice], inputs)
-            input_ids, labels, loss_scale, tokenizer_kwargs = self._encode_context_list(
-                res_context_list, loss_scale_list)
+            context_list, loss_scale = self._simplify_context_list(res_context_list[_slice],
+                                                                   loss_scale_list[_slice], inputs)
+            input_ids, labels, loss_scale, tokenizer_kwargs = self._encode_context_list(context_list, loss_scale)
             encoded[f'{key}_input_ids'] = input_ids
             if key == 'answer':
                 encoded['labels'] = labels
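
This rename fixes a real bug rather than style: the old code rebound res_context_list and loss_scale_list on the first ('prompt') iteration, so the second ('answer') slice was taken from lists that had already been truncated. A self-contained illustration of the rebinding problem:

items = ['p1', 'p2', 'a1', 'a2']
slices = [slice(0, 2), slice(2, 4)]

# Buggy shape: rebinding the source name corrupts the second pass.
src = list(items)
for _slice in slices:
    src = src[_slice]  # pass 1: ['p1', 'p2']; pass 2 slices the shrunken list -> []

# Fixed shape: leave the source intact and bind each slice to a fresh name.
for _slice in slices:
    chunk = items[_slice]  # ['p1', 'p2'] then ['a1', 'a2'], as intended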

swift/llm/template/template/microsoft.py

Lines changed: 17 additions & 13 deletions
@@ -13,8 +13,8 @@
 
 
 class FlorenceTemplate(Template):
-    # loss_scale = 'last_round'
-    # skip_prompt = False
+    # If it's an encoder-decoder architecture, the default settings are
+    # loss_scale: 'last_round' and skip_prompt: False.
     is_encoder_decoder = True
 
     @staticmethod
@@ -51,28 +51,32 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
         labels = encoded['labels']
         if labels is not None:
             labels = [0] + labels
-        pixel_values = processor.image_processor(
-            images, return_tensors='pt')['pixel_values'].to(self.config.torch_dtype)
-        encoded = {
-            'input_ids': input_ids,
-            'labels': labels,
-            'pixel_values': pixel_values,
-        }
+        if images:
+            pixel_values = processor.image_processor(
+                images, return_tensors='pt')['pixel_values'].to(self.config.torch_dtype)
+            encoded['pixel_values'] = pixel_values
+        encoded['input_ids'] = input_ids
+        encoded['labels'] = labels
         return encoded
 
     def _post_encode(self, model: nn.Module, inputs: Dict[str, Any]) -> Dict[str, Any]:
         inputs_embeds = model.get_input_embeddings()(inputs['input_ids'])
-        image_features = model._encode_image(inputs['pixel_values'])
-        inputs_embeds, _ = model._merge_input_ids_with_image_features(image_features, inputs_embeds)
+        pixel_values = inputs.get('pixel_values')
+        if pixel_values is not None:
+            image_features = model._encode_image(pixel_values)
+            inputs_embeds, inputs['attention_mask'] = model._merge_input_ids_with_image_features(
+                image_features, inputs_embeds)
         return {'inputs_embeds': inputs_embeds}
 
     def decode(self, generate_ids: List[int], **kwargs) -> Any:
         response = super().decode(generate_ids, **kwargs)
         template_inputs = kwargs.get('template_inputs')
         images = template_inputs.images
+        image_size = None
+        if images:
+            image_size = (images[0].width, images[0].height)
         return json.dumps(
-            self.processor.post_process_generation(
-                response, task=template_inputs.query, image_size=(images[0].width, images[0].height)))
+            self.processor.post_process_generation(response, task=template_inputs.query, image_size=image_size))
 
 
 register_template(
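
Reduced to a standalone shape, the guarded merge in _post_encode looks like the sketch below; _encode_image and _merge_input_ids_with_image_features are Florence-2's remote-code private methods as used in the hunk, everything else is illustrative:

from typing import Optional

import torch
import torch.nn as nn

def build_inputs_embeds(model: nn.Module, input_ids: torch.Tensor,
                        pixel_values: Optional[torch.Tensor]) -> torch.Tensor:
    # Text-only batches skip the image branch entirely.
    inputs_embeds = model.get_input_embeddings()(input_ids)
    if pixel_values is not None:
        image_features = model._encode_image(pixel_values)  # Florence-2 private API
        # The real template also keeps the merged attention_mask; dropped here.
        inputs_embeds, _ = model._merge_input_ids_with_image_features(
            image_features, inputs_embeds)
    return inputs_embeds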

swift/trainers/rlhf_trainer/rlhf_mixin.py

Lines changed: 1 addition & 1 deletion
@@ -85,7 +85,7 @@ def __init__(self,
         if self.ref_model is not None:
             disable_dropout_in_model(self.ref_model)
 
-        self.is_encoder_decoder = args.is_encoder_decoder
+        self.is_encoder_decoder = kwargs['template'].is_encoder_decoder
         self.aux_loss_enabled = getattr(model.config, 'output_router_logits', False)
         self._peft_has_been_casted_to_bf16 = False
         self.generate_during_eval = getattr(args, 'generate_during_eval', False)
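
The trainer now reads the flag from the template object, which declares the architecture directly (FlorenceTemplate sets is_encoder_decoder = True above), instead of from training args that may not carry it. A minimal illustration of the design choice, with simplified stand-in classes:

class Template:
    is_encoder_decoder = False

class FlorenceTemplate(Template):
    is_encoder_decoder = True  # mirrors the template file in this commit

def resolve_flag(template: Template) -> bool:
    # Prefer the template's declaration; args may omit or default the field.
    return template.is_encoder_decoder

assert resolve_flag(FlorenceTemplate()) is True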

tests/test_align/test_template/test_vision.py

Lines changed: 4 additions & 2 deletions
@@ -108,6 +108,8 @@ def test_llava_hf():
 
 def test_florence():
     pt_engine = PtEngine('AI-ModelScope/Florence-2-base-ft')
+    _infer_model(pt_engine, messages=[{'role': 'user', 'content': 'who are you?'}], images=[])
+
     _infer_model(
         pt_engine,
         messages=[{
@@ -265,7 +267,7 @@ def test_molmoe():
     # test_pixtral()
     # test_llama_vision()
     # test_llava_hf()
-    # test_florence()
+    test_florence()
     # test_glm_edge_v()
     #
     # test_phi3_vision()
@@ -276,4 +278,4 @@ def test_molmoe():
     # test_qvq()
     # test_mplug_owl2()
     # test_molmo()
-    test_molmoe()
+    # test_molmoe()
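
The new first call covers the text-only path that the template changes enable. Outside the test harness, the equivalent direct call would look roughly like this (hedged: assumes ms-swift's documented InferRequest/RequestConfig interface; _infer_model is just the test helper around it):

from swift.llm import InferRequest, PtEngine, RequestConfig

engine = PtEngine('AI-ModelScope/Florence-2-base-ft')
request = InferRequest(messages=[{'role': 'user', 'content': 'who are you?'}])
resp_list = engine.infer([request], RequestConfig(max_tokens=128, temperature=0))
print(resp_list[0].choices[0].message.content)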
