 
 import paddle
 import paddle.distributed as dist
-import paddle.nn as nn
 import paddle.nn.functional as F
 from paddle import Tensor
-from paddle.common_ops_import import convert_dtype
 from paddle.utils import map_structure
 
-from ..transformers.model_outputs import ModelOutput
+from ..transformers.model_outputs import CausalLMOutputWithPast, ModelOutput
 from ..transformers.utils import get_scale_by_dtype
 from ..utils.log import logger
 from ..utils.masking_utils import _expand_2d_mask, _make_causal_mask
@@ -493,61 +491,38 @@ def expand_inputs_for_generation(input_ids, expand_size, attention_mask=None, **
 
     @staticmethod
     def update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False):
-        # Update the model inputs during generation.
-        # Note that If `token_type_ids` and `attention_mask` in `model_kwargs`
-        # and they contain pad value, the result vectors updated by this method
-        # may be different from expected. In this case, you need to rewrite the
-        # method.
+        """
+        Updates model kwargs for generation.
+
+        Args:
+            outputs (Any): Model outputs.
+            model_kwargs (dict): Current model kwargs.
+            is_encoder_decoder (bool): Whether using an encoder-decoder architecture.
 
+        Returns:
+            dict: Updated model kwargs.
+        """
         # update cache
         if isinstance(outputs, tuple) and len(outputs) > 1 and not isinstance(outputs[1], paddle.Tensor):
-            model_kwargs["cache"] = outputs[1]
             model_kwargs["past_key_values"] = outputs[1]
 
-        if isinstance(outputs, ModelOutput) and "past_key_values" in outputs:
-            model_kwargs["cache"] = outputs.past_key_values
+        if isinstance(outputs, CausalLMOutputWithPast) and "past_key_values" in outputs:
             model_kwargs["past_key_values"] = outputs.past_key_values
 
         # update token_type_ids with last value
         if "token_type_ids" in model_kwargs and model_kwargs["token_type_ids"] is not None:
             token_type_ids = model_kwargs["token_type_ids"]
             model_kwargs["token_type_ids"] = paddle.concat([token_type_ids, token_type_ids[:, -1:]], axis=-1)
-
-        # update position_ids
-        if "position_ids" in model_kwargs and model_kwargs["position_ids"] is not None:
-            position_ids = model_kwargs["position_ids"]
-            model_kwargs["position_ids"] = paddle.concat([position_ids, position_ids[..., -1:] + 1], axis=-1)
-
-        # update attention_mask
-        if not is_encoder_decoder and "attention_mask" in model_kwargs:
+        if not is_encoder_decoder and model_kwargs.get("attention_mask", None) is not None:
+            # update attention mask
             attention_mask = model_kwargs["attention_mask"]
-            # nn.Pad2D don't support the data type `bool`
-            if convert_dtype(attention_mask.dtype) == "bool":
-                attention_mask = paddle.cast(attention_mask, "int64")
-            if len(attention_mask.shape) == 4:
-                cur_device = paddle.get_device()
-                if cur_device.split(":")[0] == "npu":
-                    attention_mask = nn.Pad2D([0, 0, 0, 1], mode="constant")(attention_mask)
-                    attention_mask = nn.Pad2D([0, 1, 0, 0], value=0)(attention_mask)
-                else:
-                    attention_mask = nn.Pad2D([0, 0, 0, 1], mode="replicate")(attention_mask)
-                    attention_mask = nn.Pad2D([0, 1, 0, 0], value=get_scale_by_dtype(return_positive=False))(
-                        attention_mask
-                    )
-
-                dtype = convert_dtype(attention_mask.dtype)
-                if "int" in dtype:
-                    attention_mask[:, :, -1, -1] = 1
-                elif "float" in dtype:
-                    attention_mask[:, :, -1, -1] = 0.0
-                else:
-                    raise ValueError("The data type of input `attention_mask` must " "be bool, int or float")
-            else:
-                attention_mask = paddle.concat(
-                    [attention_mask, paddle.ones([attention_mask.shape[0], 1], dtype="int64")], axis=-1
-                )
-            model_kwargs["attention_mask"] = attention_mask
-
+            model_kwargs["attention_mask"] = paddle.concat(
+                [
+                    attention_mask,
+                    paddle.ones([attention_mask.shape[0], 1], dtype=attention_mask.dtype),
+                ],
+                axis=-1,
+            )
         # update role_ids
         if "role_ids" in model_kwargs and model_kwargs["role_ids"] is not None:
             role_ids = model_kwargs["role_ids"]
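For reviewers, a minimal sketch of the per-step bookkeeping this hunk now performs. The batch size, sequence length, and dtypes are made up for illustration; the logic mirrors the two `paddle.concat` updates above:

```python
import paddle

# Hypothetical kwargs for a batch of 2 prompts of length 4.
model_kwargs = {
    "attention_mask": paddle.ones([2, 4], dtype="int64"),
    "token_type_ids": paddle.zeros([2, 4], dtype="int64"),
}

# After one decode step, the mask gains a ones-column (the new token is
# always attendable) in the mask's own dtype, with no Pad2D special cases.
attention_mask = model_kwargs["attention_mask"]
model_kwargs["attention_mask"] = paddle.concat(
    [attention_mask, paddle.ones([attention_mask.shape[0], 1], dtype=attention_mask.dtype)],
    axis=-1,
)

# token_type_ids simply repeats its last value for the new position.
token_type_ids = model_kwargs["token_type_ids"]
model_kwargs["token_type_ids"] = paddle.concat([token_type_ids, token_type_ids[:, -1:]], axis=-1)

print(model_kwargs["attention_mask"].shape)  # [2, 5]
```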
@@ -611,11 +586,63 @@ def get_decoder_start_token_id(self, decoder_start_token_id=None, bos_token_id=N
             "`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation."
         )
 
-    def prepare_inputs_for_generation(self, input_ids, **kwargs):
-        # Implement in subclasses for custom behavior to prepare inputs in the
-        # generate method.
+    def prepare_inputs_for_generation(
+        self,
+        input_ids,
+        use_cache=True,
+        past_key_values=None,
+        inputs_embeds=None,
+        **kwargs,
+    ):
+        """Prepares model inputs for generation in PaddlePaddle models.
+
+        Args:
+            input_ids (paddle.Tensor):
+                The input token IDs with shape [batch_size, sequence_length].
+            use_cache (bool, optional):
+                Whether to use cached key-value states for faster generation.
+                Defaults to True.
+            past_key_values (Optional[Tuple[paddle.Tensor]]):
+                Cached past key-value states from previous generation steps.
+                If provided, input_ids is truncated to keep only the last token.
+            inputs_embeds (Optional[paddle.Tensor]):
+                Precomputed embeddings to use instead of token IDs.
+                Only used in the first generation step, when past_key_values is None.
+            **kwargs:
+                Additional keyword arguments, including:
+                - attention_mask (paddle.Tensor): Attention mask tensor
+
+        Returns:
+            Dict[str, Union[paddle.Tensor, bool, Dict]]:
+                A dictionary containing:
+                - "input_ids" or "inputs_embeds": The main input tensors
+                - "past_key_values": The cached key-value states
+                - "use_cache": Flag indicating whether to use caching
+                - "attention_mask": The attention mask tensor (if provided)
+                - "return_dict": Always set to True for a consistent output format
+
+        """
+        if past_key_values:
+            input_ids = input_ids[:, -1:]
+
+        attention_mask = kwargs.get("attention_mask", None)
+
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        if inputs_embeds is not None and past_key_values is None:
+            model_inputs = {"inputs_embeds": inputs_embeds}
+        else:
+            model_inputs = {"input_ids": input_ids}
+
+        model_inputs.update(
+            {
+                "past_key_values": past_key_values,
+                "use_cache": use_cache,
+                "attention_mask": attention_mask,
+                "return_dict": True,
+            }
+        )
 
-        return {"input_ids": input_ids}
+        return model_inputs
 
     def adjust_logits_during_generation(self, logits):
         # Implement in subclasses for custom behavior to adjust the logits in
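To see how the two reworked hooks cooperate, here is a minimal greedy-decoding sketch. `model` and `greedy_decode` are hypothetical: any causal LM exposing this mixin's interface and returning a `CausalLMOutputWithPast` would do, and the loop structure is an assumption for illustration rather than the library's actual `generate` implementation:

```python
import paddle

def greedy_decode(model, input_ids, max_new_tokens=8):
    # `model` is a hypothetical decoder-only LM using this mixin.
    model_kwargs = {"attention_mask": paddle.ones_like(input_ids)}
    for _ in range(max_new_tokens):
        # Truncates input_ids to the last token once past_key_values is set,
        # and routes attention_mask / use_cache / return_dict into the call.
        model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)
        outputs = model(**model_inputs)  # assumed to be a CausalLMOutputWithPast
        next_token = paddle.argmax(outputs.logits[:, -1, :], axis=-1, keepdim=True)
        input_ids = paddle.concat([input_ids, next_token], axis=-1)
        # Carries past_key_values forward and grows attention_mask by one column.
        model_kwargs = model.update_model_kwargs_for_generation(outputs, model_kwargs)
    return input_ids
```

The design choice is worth noting: `prepare_inputs_for_generation` now returns `return_dict=True` unconditionally, so the `isinstance(outputs, CausalLMOutputWithPast)` branch in `update_model_kwargs_for_generation`, rather than the tuple branch, is the path exercised by models that honor the flag.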