Skip to content

Commit 433ef99

Browse files
authored
refine sampling code (#1184)
1 parent a34f689 commit 433ef99

File tree

1 file changed

+20
-20
lines changed

1 file changed

+20
-20
lines changed

paddlenlp/ops/faster_transformer/transformer/faster_transformer.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -793,16 +793,16 @@ def sample(self,
793793
model_inputs = self.prepare_inputs_for_generation(input_ids,
794794
**model_kwargs)
795795

796-
if self._decoding_strategy == "sampling" and (top_p == 1.0 and
797-
top_k > 0):
798-
top_p = 0.0
799-
elif self._decoding_strategy == "sampling" and (top_p != 1.0 and
800-
top_k == 0):
801-
top_k = 0
802-
else:
803-
raise ValueError(
804-
"Only topk sampling or topp sampling are supported. " \
805-
"Topk sampling and topp sampling cannot be both applied. ")
796+
if self._decoding_strategy == "sampling":
797+
if top_p == 1.0 and top_k > 0:
798+
top_p = 0.0
799+
elif top_p <= 0.0 and top_k == 0:
800+
raise ValueError(
801+
"Topk sampling or topp sampling must be applied. " \
802+
"Topk sampling and topp sampling cannot be both applied. ")
803+
elif (top_p > 0.0 and top_p < 1.0) and top_k > 0:
804+
raise ValueError(
805+
"Topk sampling and topp sampling cannot be both applied. ")
806806

807807
return self.forward(
808808
model_inputs=model_inputs,
@@ -956,16 +956,16 @@ def sample(self,
956956
model_inputs = self.prepare_inputs_for_generation(input_ids,
957957
**model_kwargs)
958958

959-
if self._decoding_strategy == "sampling" and (top_p == 1.0 and
960-
top_k > 0):
961-
top_p = 0.0
962-
elif self._decoding_strategy == "sampling" and (top_p != 1.0 and
963-
top_k == 0):
964-
top_k = 0
965-
else:
966-
raise ValueError(
967-
"Only topk sampling or topp sampling are supported. " \
968-
"Topk sampling and topp sampling cannot be both applied. ")
959+
if self._decoding_strategy == "sampling":
960+
if top_p == 1.0 and top_k > 0:
961+
top_p = 0.0
962+
elif top_p <= 0.0 and top_k == 0:
963+
raise ValueError(
964+
"Topk sampling or topp sampling must be applied. " \
965+
"Topk sampling and topp sampling cannot be both applied. ")
966+
elif (top_p > 0.0 and top_p < 1.0) and top_k > 0:
967+
raise ValueError(
968+
"Topk sampling and topp sampling cannot be both applied. ")
969969

970970
return self.forward(
971971
model_inputs=model_inputs,

0 commit comments

Comments
 (0)