Skip to content

Commit cc58a23

Browse files
authored
Fix diversity rate bug (#1477)
* update perf * fix doc and constrains for FasterGeneration * update readme * fix diversity rate bug
1 parent 4e59ce0 commit cc58a23

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

examples/faster/faster_generation/samples/unimo_sample.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def postprocess_response(token_ids, tokenizer):
4141
add_start_token_for_decoding=True,
4242
return_tensors=True,
4343
is_split_into_words=False)
44-
model.eval()
44+
4545
outputs, _ = model.generate(
4646
input_ids=inputs_ids['input_ids'],
4747
token_type_ids=inputs_ids['token_type_ids'],

paddlenlp/ops/faster_transformer/transformer/decoding.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1381,7 +1381,7 @@ def forward(self,
13811381
_bos_id=bos_token_id,
13821382
_eos_id=eos_token_id,
13831383
_max_out_len=max_out_len,
1384-
_diversity_rate=diversity_rate,
1384+
_diversity_rate=-diversity_rate,
13851385
_unk_id=self._unk_id,
13861386
_mask_id=self._mask_id,
13871387
_temperature=temperature,
@@ -1625,7 +1625,7 @@ def forward(self,
16251625
self.linear_weight, self.linear_bias, self.pos_emb,
16261626
decoding_strategy, beam_size, top_k, top_p, self._n_head,
16271627
int(self._d_model / self._n_head), self._num_decoder_layers,
1628-
bos_token_id, eos_token_id, max_out_len, diversity_rate, rel_len,
1628+
bos_token_id, eos_token_id, max_out_len, -diversity_rate, rel_len,
16291629
alpha, early_stopping)
16301630

16311631
ids = finalize(
@@ -1877,7 +1877,7 @@ def forward(self,
18771877
self.linear_bias, self.pos_emb, trg_word, decoding_strategy,
18781878
beam_size, top_k, top_p, self._n_head,
18791879
int(self._d_model / self._n_head), self._num_decoder_layers,
1880-
bos_token_id, eos_token_id, max_out_len, diversity_rate, rel_len,
1880+
bos_token_id, eos_token_id, max_out_len, -diversity_rate, rel_len,
18811881
alpha, temperature, early_stopping, self._hidden_act)
18821882

18831883
ids = finalize(

0 commit comments

Comments
 (0)