
Commit 2a8d138

[LLM]support QWenVL second part (PaddlePaddle#7808)
* add qwenvl second part
* improve code & add comments
* add qwenvl test
* update qwenvl tests
* remove print
1 parent 4a9c766 commit 2a8d138

File tree: 4 files changed, +268 −11 lines

llm/predictor.py (10 additions, 5 deletions)
```diff
@@ -1395,11 +1395,16 @@ def create_predictor(
             )
             model.eval()
         elif "qwen" in config.architectures[0].lower():
-            from paddlenlp.experimental.transformers import (
-                QWenForCausalLMInferenceModel,
-            )
-
-            model = QWenForCausalLMInferenceModel.from_pretrained(
+            if model_args.model_type == "qwen-img2txt":
+                # we use qwen for img2txt.
+                from paddlenlp.experimental.transformers import (
+                    QWenForQWenVLInferenceModel as QWenInferenceModel,
+                )
+            else:
+                from paddlenlp.experimental.transformers import (
+                    QWenForCausalLMInferenceModel as QWenInferenceModel,
+                )
+            model = QWenInferenceModel.from_pretrained(
                 predictor_args.model_name_or_path,
                 config=config,
                 dtype=predictor_args.dtype,
```
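For orientation, here is a minimal sketch of what this branch resolves to at serve time when `model_type` is `qwen-img2txt` (the checkpoint path is a placeholder; the loading pattern mirrors the new test below):

```python
import paddle

from paddlenlp.experimental.transformers import QWenForQWenVLInferenceModel
from paddlenlp.transformers import AutoConfig

model_path = "./checkpoints/qwen-vl"  # placeholder path for illustration
config = AutoConfig.from_pretrained(model_path)

paddle.set_default_dtype("float16")
model = QWenForQWenVLInferenceModel.from_pretrained(model_path, config=config, dtype="float16")
model.eval()
```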

paddlenlp/experimental/transformers/qwen/modeling.py (141 additions, 5 deletions)
```diff
@@ -39,7 +39,7 @@
 )
 from paddlenlp.transformers.qwen.modeling import QWenLMHead, QWenPretrainingCriterion

-__all__ = ["QWenForCausalLMInferenceModel"]
+__all__ = ["QWenForCausalLMInferenceModel", "QWenForQWenVLInferenceModel"]


 class FusedQWenRMSNorm(nn.Layer):
```
```diff
@@ -244,6 +244,19 @@ def remove_padding(self, input_ids, seq_lens_this_time):
         )
         return ids_remove_padding, padding_offset, cum_offsets

+    # Unlike prepare_input_ids_for_generation in paddlenlp/transformers/generation/utils.py,
+    # this generates fake input_ids matching the length of inputs_embeds.
+    @staticmethod
+    def prepare_input_ids_for_generation(bos_token_id, encoder_output=None):
+        batch_size = 1
+        seq_len = 1
+        if bos_token_id is None:
+            raise ValueError("`bos_token_id` should be defined when no `input_ids` are provided.")
+        if encoder_output is not None:
+            batch_size = encoder_output.shape[0]
+            seq_len = encoder_output.shape[1]
+        return paddle.full([batch_size, seq_len], bos_token_id, dtype="int64")
+
     def forward(
         self,
         input_ids=None,
```
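As a quick illustration, assuming toy shapes and the `bos_token_id=1` default added to the config below, the helper simply materializes a `[batch, seq]` tensor of bos tokens:

```python
import paddle

inputs_embeds = paddle.randn([2, 7, 16])  # stand-in for image-derived embeddings

# Equivalent to prepare_input_ids_for_generation(1, inputs_embeds): downstream
# bookkeeping that expects input_ids still sees the right batch and length.
fake_ids = paddle.full([inputs_embeds.shape[0], inputs_embeds.shape[1]], 1, dtype="int64")
print(fake_ids.shape)  # [2, 7]
```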
```diff
@@ -270,17 +283,21 @@ def forward(
         elif input_ids is None and inputs_embeds is None:
             raise ValueError("You have to specify either input_ids or inputs_embeds")

+        # Generate fake input_ids from inputs_embeds. This typically happens in an
+        # img2txt multimodal model on the first entry into this forward function.
+        if input_ids is None and inputs_embeds is not None:
+            input_ids = self.prepare_input_ids_for_generation(self.config.bos_token_id, inputs_embeds)
+        if inputs_embeds is not None:
+            batch, seq_len, hidden_dim = inputs_embeds.shape
+            inputs_embeds = inputs_embeds.reshape([batch * seq_len, hidden_dim])
+
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         )
         use_cache = use_cache if use_cache is not None else self.config.use_cache
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

-        if inputs_embeds is not None:
-            batch, seq_len, hidden_dim = inputs_embeds.shape
-            inputs_embeds = inputs_embeds.reshape([batch * seq_len, hidden_dim])
-
         if past_key_values is None:
             past_key_values = tuple([None] * self.config.num_hidden_layers)
```
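The ordering here is the point of the change: the fake input_ids must be built while inputs_embeds is still `[batch, seq, hidden]`, so the existing flatten was moved up to run right after that call. A toy sketch of why:

```python
import paddle

inputs_embeds = paddle.randn([2, 7, 16])  # [batch, seq, hidden]
# Read batch/seq now, while the tensor is still 3-D...
batch, seq = inputs_embeds.shape[0], inputs_embeds.shape[1]
# ...because after flattening, shape[1] is the hidden size, not seq.
inputs_embeds = inputs_embeds.reshape([batch * seq, 16])
print(inputs_embeds.shape)  # [14, 16]
```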

```diff
@@ -502,3 +519,122 @@ def set_state_dict(self, state_dict):
         lm_head_weight = paddle.to_tensor(state_dict["lm_head.weight"], dtype=self.lm_head.weight.dtype)
         self.lm_head.weight.set_value(lm_head_weight)
         self.qwen.set_state_dict({k: state_dict[k] for k in state_dict.keys()})
+
+
+class QWenForQWenVLInferenceModel(QWenForCausalLMInferenceModel):
+    """
+    This class is almost identical to QWenForCausalLMInferenceModel.
+    It is used only for QWenVL's second (text-generation) part.
+    """
+
+    # This function implements QWenVL's second part and is only used for QWenVL.
+    @paddle.no_grad()
+    def generate_text_with_image_features(
+        self,
+        input_ids: paddle.Tensor,
+        image_features: paddle.Tensor,
+        img_pos: paddle.Tensor,
+        attention_mask: paddle.Tensor,
+        position_ids=None,
+        penalty_score=None,
+        frequency_score=None,
+        presence_score=None,
+        min_length=None,
+        max_length=None,
+        temperature=None,
+        top_p=None,
+        eos_token_id=None,
+        seq_len_encoder=None,
+        seq_len_decoder=None,
+        step_idx=None,
+        stop_flags=None,
+        tgt_ids=None,
+        tgt_pos=None,
+        tgt_generation_mask=None,
+        pre_ids=None,
+        stop_nums=None,
+        cache_kvs=[],
+        inputs_embeds=None,
+        **generate_kwargs
+    ) -> paddle.Tensor:
+        # Embed the text tokens, then overwrite the image-placeholder span with
+        # the precomputed visual features (the scatter runs in float32).
+        inputs_embeds = self.qwen.wte(input_ids)
+        inputs_embeds_dtype = inputs_embeds.dtype
+        if inputs_embeds_dtype != paddle.float32:
+            inputs_embeds = paddle.cast(inputs_embeds, paddle.float32)
+            image_features = paddle.cast(image_features, paddle.float32)
+
+        for idx, (i, image_start_idx, image_end_idx) in enumerate(img_pos):
+            index = paddle.arange(image_start_idx + 1, image_end_idx).unsqueeze(-1)
+            inputs_embeds[i] = paddle.scatter(inputs_embeds[i], index, image_features[idx])
+
+        if inputs_embeds_dtype != paddle.float32:
+            inputs_embeds = paddle.cast(inputs_embeds, inputs_embeds_dtype)
+
+        outputs = self.generate(
+            inputs_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            penalty_score=penalty_score,
+            frequency_score=frequency_score,
+            presence_score=presence_score,
+            min_length=min_length,
+            max_length=max_length,
+            temperature=temperature,
+            top_p=top_p,
+            eos_token_id=eos_token_id,
+            seq_len_encoder=seq_len_encoder,
+            seq_len_decoder=seq_len_decoder,
+            step_idx=step_idx,
+            stop_flags=stop_flags,
+            tgt_ids=tgt_ids,
+            tgt_pos=tgt_pos,
+            tgt_generation_mask=tgt_generation_mask,
+            pre_ids=pre_ids,
+            stop_nums=stop_nums,
+            cache_kvs=cache_kvs,
+        )
+        return outputs
+
+    # Overrides the to_static function in generation_utils.py so the exported
+    # static graph takes image_features and img_pos as inputs.
+    def to_static(self, output_path: str, config: dict):
+        dtype = config.get("dtype", paddle.get_default_dtype())
+        cache_kvs_shapes = self.get_cache_kvs_shape(self.config, max_length=config.get("max_length", None))
+        input_spec = [
+            paddle.static.InputSpec(shape=[None, None], dtype="int64", name="input_ids"),  # input_ids
+            paddle.static.InputSpec(
+                shape=[None, None, None], dtype="float32", name="image_features"
+            ),  # image_features
+            paddle.static.InputSpec(shape=[None, 3], dtype="int64", name="img_pos"),  # img_pos
+            paddle.static.InputSpec(shape=[None, None], dtype=dtype, name="attention_mask"),  # attention_mask
+            paddle.static.InputSpec(shape=[None, None], dtype="int64", name="position_ids"),  # position_ids
+            paddle.static.InputSpec(shape=[None, 1], dtype="float32", name="penalty_score"),  # penalty_score
+            paddle.static.InputSpec(shape=[None, 1], dtype="float32", name="frequency_score"),  # frequency_score
+            paddle.static.InputSpec(shape=[None, 1], dtype="float32", name="presence_score"),  # presence_score
+            paddle.static.InputSpec(shape=[None, 1], dtype="int64", name="min_length"),  # min_decode_length
+            paddle.static.InputSpec(shape=[None, 1], dtype="int64", name="max_length"),  # max_decode_length
+            paddle.static.InputSpec(shape=[None, 1], dtype="float32", name="temperature"),  # temperature
+            paddle.static.InputSpec(shape=[None, 1], dtype="float32", name="top_p"),  # top_p
+            paddle.static.InputSpec(shape=[None], dtype="int64", name="eos_token_id"),  # eos_token_id
+            paddle.static.InputSpec(shape=[None, 1], dtype="int32", name="seq_len_encoder"),  # seq_len_encoder
+            paddle.static.InputSpec(shape=[None, 1], dtype="int32", name="seq_len_decoder"),  # seq_len_decoder
+            paddle.static.InputSpec(shape=[None, 1], dtype="int64", name="step_idx"),  # step_idx
+            paddle.static.InputSpec(shape=[None, 1], dtype="bool", name="stop_flags"),  # stop_flags
+            paddle.static.InputSpec(shape=[None, 1], dtype="int64", name="tgt_ids"),  # tgt_ids
+            paddle.static.InputSpec(shape=[None, 1], dtype="int64", name="tgt_pos"),  # tgt_pos
+            paddle.static.InputSpec(
+                shape=[None, 1, 1, None], dtype=dtype, name="tgt_generation_mask"
+            ),  # tgt_generation_mask
+            paddle.static.InputSpec(shape=[None, None], dtype="int64", name="pre_ids"),  # pre_ids
+            paddle.static.InputSpec(shape=[1], dtype="int64", name="stop_nums"),  # stop_nums
+            [
+                paddle.static.InputSpec(
+                    shape=shape,
+                    dtype=dtype,
+                    name="cache_kvs_{}".format(i),
+                )
+                for i, shape in enumerate(cache_kvs_shapes)
+            ],  # cache_kvs
+        ]
+
+        model = paddle.jit.to_static(self.generate_text_with_image_features, input_spec=input_spec)
+        paddle.jit.save(model, output_path, skip_prune_program=True)
```
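To make the feature-splicing concrete, here is a minimal sketch (toy shapes) of the img_pos loop above: each row `[batch_idx, start, end]` marks a placeholder span, and rows `start+1` through `end-1` of that sample's embeddings are overwritten with the visual features:

```python
import paddle

hidden = 8
inputs_embeds = paddle.zeros([1, 10, hidden])           # [batch, seq, hidden]
image_features = paddle.ones([1, 4, hidden])            # 4 visual tokens
img_pos = paddle.to_tensor([[0, 2, 7]], dtype="int64")  # fill rows 3..6 of sample 0

for idx, (i, start, end) in enumerate(img_pos):
    index = paddle.arange(start + 1, end).unsqueeze(-1)  # [[3], [4], [5], [6]]
    inputs_embeds[i] = paddle.scatter(inputs_embeds[i], index, image_features[idx])

print(float(inputs_embeds[0, 3:7].sum()))  # 32.0: the span now holds the features
```

With the overridden to_static, export then reduces to something like `model.to_static("./static/qwen_vl", {"dtype": "float16", "max_length": 1024})`, where the path and max length are placeholders.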

paddlenlp/transformers/qwen/configuration.py (10 additions, 1 deletion)
```diff
@@ -47,6 +47,9 @@ def __init__(
         tensor_parallel_output=True,
         no_bias=True,
         tie_word_embeddings=False,
+        pad_token_id=0,
+        bos_token_id=1,
+        eos_token_id=2,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -72,4 +75,10 @@ def __init__(
         self.use_fused_rope = use_fused_rope
         self.no_bias = no_bias

-        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
+        super().__init__(
+            pad_token_id=pad_token_id,
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            tie_word_embeddings=tie_word_embeddings,
+            **kwargs,
+        )
```
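These defaults matter because the fake-input_ids path above reads `config.bos_token_id`. A quick check of the behavior, assuming QWenConfig is exposed under paddlenlp.transformers:

```python
from paddlenlp.transformers import QWenConfig  # assumed public import path

config = QWenConfig()
# With the new defaults, the special token ids are populated even when a
# checkpoint's config.json omits them.
print(config.pad_token_id, config.bos_token_id, config.eos_token_id)  # 0 1 2
```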

tests/llm/test_predictor.py (107 additions, 0 deletions)
```diff
@@ -20,7 +20,9 @@
 import pytest
 from parameterized import parameterized_class

+from paddlenlp.experimental.transformers import QWenForQWenVLInferenceModel
 from paddlenlp.transformers import (  # ChatGLMForCausalLM,
+    AutoConfig,
     AutoTokenizer,
     BloomForCausalLM,
     ChatGLMForCausalLM,
```
```diff
@@ -325,3 +327,108 @@ def test_predictor(self):

         self.assertGreaterEqual(full_match / len(result_0), 0.25)
         self.assertGreaterEqual(count / len(result_0), 0.4)
+
+
+class QWenVLTest(LLMTest, unittest.TestCase):
+    config_path: str = "./tests/fixtures/llm/predictor.yaml"
+    model_name_or_path: str = "__internal_testing__/tiny-fused-qwen"
+    model_class = QWenForCausalLM
+
+    def setUp(self) -> None:
+        super().setUp()
+        paddle.set_default_dtype("float32")
+        self.model_class.from_pretrained(self.model_name_or_path, dtype="float16").save_pretrained(self.output_dir)
+        AutoTokenizer.from_pretrained(self.model_name_or_path).save_pretrained(self.output_dir)
+
+    def test_forward(self):
+        self.disable_static()
+        config = AutoConfig.from_pretrained(self.output_dir)
+        config.quant_type = None
+        config.weight_only_quant_bits = None
+
+        paddle.set_default_dtype("float16")
+        model = QWenForQWenVLInferenceModel.from_pretrained(self.output_dir, config=config, dtype="float16")
+
+        batch = 1
+        seq = 31
+        max_len = 50
+        dtype = "float16"
+        input_ids = paddle.randint(0, 100, [batch, seq], dtype="int64")
+        image_features = paddle.randn([batch, 16, config.hidden_size], dtype="float16")
+        tgt_generation_mask = paddle.full([batch, 1, 1, max_len], 1, dtype=dtype)
+        img_pos = paddle.to_tensor([[0, 4, 21]], dtype="int64")
+        attention_mask = paddle.full([batch, 1, max_len, max_len], 0, dtype=dtype)
+        attention_mask[:, 0, :seq, :seq] = paddle.tril(paddle.ones(shape=(seq, seq), dtype=dtype))
+        position_ids = paddle.full([batch, seq], 0, dtype="int64")
+        for i in range(batch):
+            position_ids[i, :] = paddle.to_tensor([i for i in range(seq)], dtype="int64")
+
+        inputs = [
+            input_ids,  # input_ids
+            image_features,  # image_features
+            img_pos,  # img_pos
+            attention_mask,  # attention_mask
+            position_ids,  # position_ids
+            paddle.full([batch, 1], 1.0, dtype="float32"),  # penalty_score
+            paddle.full([batch, 1], 0.0, dtype="float32"),  # frequency_score
+            paddle.full([batch, 1], 0.0, dtype="float32"),  # presence_score
+            paddle.full([batch, 1], 1, dtype="int64"),  # min_length
+            paddle.full([batch, 1], max_len - seq, dtype="int64"),  # max_length
+            paddle.full([batch, 1], 1.0, dtype="float32"),  # temperature
+            paddle.full([batch, 1], 0.0, dtype="float32"),  # top_p
+            paddle.full([1], 151643, dtype="int64"),  # eos_token_id
+            paddle.full([batch, 1], seq, dtype="int32"),  # seq_len_encoder
+            paddle.full([batch, 1], seq, dtype="int32"),  # seq_len_decoder
+            paddle.full([batch, 1], 0, dtype="int64"),  # step_idx
+            paddle.full([batch, 1], False, dtype="bool"),  # stop_flags
+            paddle.full([batch, 1], -123, dtype="int64"),  # tgt_ids, can be initialized arbitrarily
+            paddle.full([batch, 1], seq - 1, dtype="int64"),  # tgt_pos
+            tgt_generation_mask,  # tgt_generation_mask
+            paddle.full([batch, max_len], -100, dtype="int64"),  # pre_ids, can be initialized arbitrarily
+            paddle.full([1], batch, dtype="int64"),  # stop_nums, should equal batch
+        ]
+        for i in range(config.num_hidden_layers):
+            tmp = paddle.rand(shape=[2, batch, 1, max_len, 64], dtype=dtype)
+            inputs.append(tmp)
+
+        model.eval()
+        model.generate_text_with_image_features(
+            input_ids=inputs[0],
+            image_features=inputs[1],
+            img_pos=inputs[2],
+            attention_mask=inputs[3],
+            position_ids=inputs[4],
+            penalty_score=inputs[5],
+            frequency_score=inputs[6],
+            presence_score=inputs[7],
+            min_length=inputs[8],
+            max_length=inputs[9],
+            temperature=inputs[10],
+            top_p=inputs[11],
+            eos_token_id=inputs[12],
+            seq_len_encoder=inputs[13],
+            seq_len_decoder=inputs[14],
+            step_idx=inputs[15],
+            stop_flags=inputs[16],
+            tgt_ids=inputs[17],
+            tgt_pos=inputs[18],
+            tgt_generation_mask=inputs[19],
+            pre_ids=inputs[20],
+            stop_nums=inputs[21],
+            cache_kvs=inputs[22:],
+        )
+
+    def test_export(self):
+        self.disable_static()
+        config = load_test_config(self.config_path, "inference-to-static")
+        config["model_name_or_path"] = self.model_name_or_path
+        config["output_path"] = self.output_dir
+        config["dtype"] = "float16"
+        config["inference_model"] = True
+        config["model_prefix"] = "qwen"
+        config["model_type"] = "qwen-img2txt"
+
+        with argv_context_guard(config):
+            from export_model import main
+
+            main()
```
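One detail worth noting is the shape of the per-layer cache buffers the test appends: K and V are stacked in dim 0, and the tiny checkpoint is assumed to use a single attention head with head_dim 64. A standalone sketch of that layout:

```python
import paddle

batch, max_len, num_layers = 1, 50, 2  # toy dims; num_layers comes from the config
# One buffer per layer: [2, batch, num_heads, max_len, head_dim].
cache_kvs = [
    paddle.rand(shape=[2, batch, 1, max_len, 64], dtype="float16")
    for _ in range(num_layers)
]
print(cache_kvs[0].shape)  # [2, 1, 1, 50, 64]
```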
