Skip to content

Commit 405b64c

Browse files
authored
fix chatglmv1/2 beam search bug (#7017)
* fix chatglmv1/2 beam search bug
* fix group beam search bug
1 parent 97b22f3 commit 405b64c

File tree

3 files changed

+9
-2
lines changed

3 files changed

+9
-2
lines changed

paddlenlp/generation/utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1497,6 +1497,7 @@ def _post_process_(outputs, input_ids, cur_len, origin_len, scores, unfinished_f
14971497

14981498
def reorder_cache(self, cache, beam_idx):
14991499
cache = map_structure(lambda x: paddle.index_select(x, beam_idx), cache)
1500+
return cache
15001501

15011502
def beam_search(
15021503
self,
@@ -1626,7 +1627,7 @@ def beam_search(
16261627
cache_name = "cache" if "cache" in model_kwargs else "past_key_values"
16271628
if model_kwargs[cache_name] is not None:
16281629
# reorder the cache
1629-
self.reorder_cache(model_kwargs[cache_name], beam_idx)
1630+
model_kwargs[cache_name] = self.reorder_cache(model_kwargs[cache_name], beam_idx)
16301631

16311632
pred_ids, scores = beam_scorer.finalize(
16321633
input_ids,
@@ -1774,7 +1775,7 @@ def group_beam_search(
17741775
cache_name = "cache" if "cache" in model_kwargs else "past_key_values"
17751776
if model_kwargs[cache_name] is not None:
17761777
# reorder the cache
1777-
self.reorder_cache(model_kwargs[cache_name], beam_idx)
1778+
model_kwargs[cache_name] = self.reorder_cache(model_kwargs[cache_name], reordering_indices)
17781779

17791780
pred_ids, scores = beam_scorer.finalize(
17801781
input_ids,

paddlenlp/transformers/chatglm/modeling.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from paddle import Tensor
2626
from paddle.distributed import fleet
2727
from paddle.distributed.fleet.utils import recompute
28+
from paddle.utils import map_structure
2829

2930
from ...utils.env import CONFIG_NAME
3031
from ...utils.log import logger
@@ -842,6 +843,10 @@ def prepare_inputs_for_generation(
842843
"attention_mask": attention_mask,
843844
}
844845

846+
def reorder_cache(self, cache: paddle.Tensor, beam_idx):
847+
cache = map_structure(lambda x: paddle.index_select(x, beam_idx, axis=1), cache)
848+
return cache
849+
845850
def update_model_kwargs_for_generation(
846851
self,
847852
outputs,

paddlenlp/transformers/chatglm_v2/modeling.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -768,6 +768,7 @@ def __init__(self, config: ChatGLMv2Config):
768768

769769
def reorder_cache(self, cache: paddle.Tensor, beam_idx):
770770
cache = map_structure(lambda x: paddle.index_select(x, beam_idx, axis=1), cache)
771+
return cache
771772

772773
def update_model_kwargs_for_generation(
773774
self,

Comments (0)