 )
 from paddlenlp.transformers.qwen.modeling import QWenLMHead, QWenPretrainingCriterion
 
-__all__ = ["QWenForCausalLMInferenceModel"]
+__all__ = ["QWenForCausalLMInferenceModel", "QWenForQWenVLInferenceModel"]
 
 
 class FusedQWenRMSNorm(nn.Layer):
@@ -244,6 +244,19 @@ def remove_padding(self, input_ids, seq_lens_this_time):
         )
         return ids_remove_padding, padding_offset, cum_offsets
 
+    # This function differs slightly from prepare_input_ids_for_generation in paddlenlp/transformers/generation/utils.py:
+    # it generates fake input_ids whose shape matches the given inputs_embeds.
+    @staticmethod
+    def prepare_input_ids_for_generation(bos_token_id, encoder_output=None):
+        batch_size = 1
+        seq_len = 1
+        if bos_token_id is None:
+            raise ValueError("`bos_token_id` should be defined when no `input_ids` are provided.")
+        if encoder_output is not None:
+            batch_size = encoder_output.shape[0]
+            seq_len = encoder_output.shape[1]
+        return paddle.full([batch_size, seq_len], bos_token_id, dtype="int64")
+
     def forward(
         self,
         input_ids=None,
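For illustration, a minimal sketch of what this helper returns; the batch size, sequence length, hidden size, and bos id below are all made up, not values from the diff:

import paddle

# Hypothetical first-stage output: embeddings for 2 sequences of length 16.
inputs_embeds = paddle.randn([2, 16, 4096])

# prepare_input_ids_for_generation(bos_token_id=151643, encoder_output=inputs_embeds)
# would reduce to the following (151643 is a made-up bos id):
fake_input_ids = paddle.full([2, 16], 151643, dtype="int64")
assert fake_input_ids.shape == [2, 16]

The fake ids give downstream generation code a token tensor with the right batch and length, while the real model inputs remain the embeddings.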
@@ -270,17 +283,21 @@ def forward(
         elif input_ids is None and inputs_embeds is None:
             raise ValueError("You have to specify either input_ids or inputs_embeds")
 
+        # Generate fake input_ids according to inputs_embeds.
+        # This usually happens in img2txt multimodal models on the first call to this forward function.
+        if input_ids is None and inputs_embeds is not None:
+            input_ids = self.prepare_input_ids_for_generation(self.config.bos_token_id, inputs_embeds)
+        if inputs_embeds is not None:
+            batch, seq_len, hidden_dim = inputs_embeds.shape
+            inputs_embeds = inputs_embeds.reshape([batch * seq_len, hidden_dim])
+
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         )
         use_cache = use_cache if use_cache is not None else self.config.use_cache
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-        if inputs_embeds is not None:
-            batch, seq_len, hidden_dim = inputs_embeds.shape
-            inputs_embeds = inputs_embeds.reshape([batch * seq_len, hidden_dim])
-
         if past_key_values is None:
             past_key_values = tuple([None] * self.config.num_hidden_layers)
 
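A hedged sketch of the new path through forward when only inputs_embeds is supplied (all shapes and the bos id are invented):

import paddle

bos_token_id = 151643  # made up; the real code uses self.config.bos_token_id
inputs_embeds = paddle.randn([2, 16, 4096])

# 1) Synthesize fake ids matching the embeddings' batch and length.
input_ids = paddle.full(inputs_embeds.shape[:2], bos_token_id, dtype="int64")

# 2) Flatten to [batch * seq_len, hidden_dim]; the diff moves this reshape up
#    from below the config-default block so all inputs_embeds handling is grouped.
batch, seq_len, hidden_dim = inputs_embeds.shape
inputs_embeds = inputs_embeds.reshape([batch * seq_len, hidden_dim])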
@@ -502,3 +519,122 @@ def set_state_dict(self, state_dict):
         lm_head_weight = paddle.to_tensor(state_dict["lm_head.weight"], dtype=self.lm_head.weight.dtype)
         self.lm_head.weight.set_value(lm_head_weight)
         self.qwen.set_state_dict({k: state_dict[k] for k in state_dict.keys()})
+
+
+class QWenForQWenVLInferenceModel(QWenForCausalLMInferenceModel):
+    """
+    This class is almost identical to QWenForCausalLMInferenceModel;
+    it is used only for the second (language-model) stage of QWen-VL.
+    """
+
+    # This function implements QWen-VL's second stage and is used only for QWen-VL.
+    @paddle.no_grad()
+    def generate_text_with_image_features(
+        self,
+        input_ids: paddle.Tensor,
+        image_features: paddle.Tensor,
+        img_pos: paddle.Tensor,
+        attention_mask: paddle.Tensor,
+        position_ids=None,
+        penalty_score=None,
+        frequency_score=None,
+        presence_score=None,
+        min_length=None,
+        max_length=None,
+        temperature=None,
+        top_p=None,
+        eos_token_id=None,
+        seq_len_encoder=None,
+        seq_len_decoder=None,
+        step_idx=None,
+        stop_flags=None,
+        tgt_ids=None,
+        tgt_pos=None,
+        tgt_generation_mask=None,
+        pre_ids=None,
+        stop_nums=None,
+        cache_kvs=[],
+        inputs_embeds=None,
+        **generate_kwargs
+    ) -> paddle.Tensor:
+        inputs_embeds = self.qwen.wte(input_ids)
+        inputs_embeds_dtype = inputs_embeds.dtype
+        if inputs_embeds_dtype != paddle.float32:
+            inputs_embeds = paddle.cast(inputs_embeds, paddle.float32)
+            image_features = paddle.cast(image_features, paddle.float32)
+
+        for idx, (i, image_start_idx, image_end_idx) in enumerate(img_pos):
+            index = paddle.arange(image_start_idx + 1, image_end_idx).unsqueeze(-1)
+            inputs_embeds[i] = paddle.scatter(inputs_embeds[i], index, image_features[idx])
+
+        if inputs_embeds_dtype != paddle.float32:
+            inputs_embeds = paddle.cast(inputs_embeds, inputs_embeds_dtype)
+
+        outputs = self.generate(
+            inputs_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            penalty_score=penalty_score,
+            frequency_score=frequency_score,
+            presence_score=presence_score,
+            min_length=min_length,
+            max_length=max_length,
+            temperature=temperature,
+            top_p=top_p,
+            eos_token_id=eos_token_id,
+            seq_len_encoder=seq_len_encoder,
+            seq_len_decoder=seq_len_decoder,
+            step_idx=step_idx,
+            stop_flags=stop_flags,
+            tgt_ids=tgt_ids,
+            tgt_pos=tgt_pos,
+            tgt_generation_mask=tgt_generation_mask,
+            pre_ids=pre_ids,
+            stop_nums=stop_nums,
+            cache_kvs=cache_kvs,
+        )
+        return outputs
+
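To picture what the scatter loop in generate_text_with_image_features does, here is a toy sketch with made-up shapes (the sequence length, hidden size, and image span are all hypothetical):

import paddle

seq_len, hidden = 8, 4
inputs_embeds_i = paddle.zeros([seq_len, hidden])   # embeddings of one sample's tokens
image_features_i = paddle.ones([3, hidden])         # one image's features (3 slots here)
image_start_idx, image_end_idx = 2, 6               # positions of the image marker tokens

# Rows 3..5, i.e. the slots strictly between the two markers, are overwritten
# with the image features; the text embeddings elsewhere are left untouched.
index = paddle.arange(image_start_idx + 1, image_end_idx).unsqueeze(-1)
out = paddle.scatter(inputs_embeds_i, index, image_features_i)

The remainder of the new class follows: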
+    # Overrides the to_static function from generation_utils.py.
+    def to_static(self, output_path: str, config: dict):
+        dtype = config.get("dtype", paddle.get_default_dtype())
+        cache_kvs_shapes = self.get_cache_kvs_shape(self.config, max_length=config.get("max_length", None))
+        input_spec = [
+            paddle.static.InputSpec(shape=[None, None], dtype="int64", name="input_ids"),  # input_ids
+            paddle.static.InputSpec(
+                shape=[None, None, None], dtype="float32", name="image_features"
+            ),  # image_features
+            paddle.static.InputSpec(shape=[None, 3], dtype="int64", name="img_pos"),  # img_pos
+            paddle.static.InputSpec(shape=[None, None], dtype=dtype, name="attention_mask"),  # attention_mask
+            paddle.static.InputSpec(shape=[None, None], dtype="int64", name="position_ids"),  # position_ids
+            paddle.static.InputSpec(shape=[None, 1], dtype="float32", name="penalty_score"),  # penalty_score
+            paddle.static.InputSpec(shape=[None, 1], dtype="float32", name="frequency_score"),  # frequency_score
+            paddle.static.InputSpec(shape=[None, 1], dtype="float32", name="presence_score"),  # presence_score
+            paddle.static.InputSpec(shape=[None, 1], dtype="int64", name="min_length"),  # min_decode_length
+            paddle.static.InputSpec(shape=[None, 1], dtype="int64", name="max_length"),  # max_decode_length
+            paddle.static.InputSpec(shape=[None, 1], dtype="float32", name="temperature"),  # temperature
+            paddle.static.InputSpec(shape=[None, 1], dtype="float32", name="top_p"),  # top_p
+            paddle.static.InputSpec(shape=[None], dtype="int64", name="eos_token_id"),  # eos_token_id
+            paddle.static.InputSpec(shape=[None, 1], dtype="int32", name="seq_len_encoder"),  # seq_len_encoder
+            paddle.static.InputSpec(shape=[None, 1], dtype="int32", name="seq_len_decoder"),  # seq_len_decoder
+            paddle.static.InputSpec(shape=[None, 1], dtype="int64", name="step_idx"),  # step_idx
+            paddle.static.InputSpec(shape=[None, 1], dtype="bool", name="stop_flags"),  # stop_flags
+            paddle.static.InputSpec(shape=[None, 1], dtype="int64", name="tgt_ids"),  # tgt_ids
+            paddle.static.InputSpec(shape=[None, 1], dtype="int64", name="tgt_pos"),  # tgt_pos
+            paddle.static.InputSpec(
+                shape=[None, 1, 1, None], dtype=dtype, name="tgt_generation_mask"
+            ),  # tgt_generation_mask
+            paddle.static.InputSpec(shape=[None, None], dtype="int64", name="pre_ids"),  # pre_ids
+            paddle.static.InputSpec(shape=[1], dtype="int64", name="stop_nums"),  # stop_nums
+            [
+                paddle.static.InputSpec(
+                    shape=shape,
+                    dtype=dtype,
+                    name="cache_kvs_{}".format(i),
+                )
+                for i, shape in enumerate(cache_kvs_shapes)
+            ],  # cache_kvs
+        ]
+
+        model = paddle.jit.to_static(self.generate_text_with_image_features, input_spec=input_spec)
+        paddle.jit.save(model, output_path, skip_prune_program=True)
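For completeness, a hypothetical export call; the output path and config values are placeholders, though "dtype" and "max_length" are the keys to_static actually reads above:

# model is an initialized QWenForQWenVLInferenceModel (construction omitted)
export_config = {"dtype": "float16", "max_length": 2048}  # placeholder values
model.to_static("./inference/qwen_vl", export_config)
# paddle.jit.save then writes the static program (.pdmodel) and weights (.pdiparams) under that prefix.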