@@ -793,16 +793,16 @@ def sample(self,
793
793
model_inputs = self .prepare_inputs_for_generation (input_ids ,
794
794
** model_kwargs )
795
795
796
- if self ._decoding_strategy == "sampling" and ( top_p == 1.0 and
797
- top_k > 0 ) :
798
- top_p = 0.0
799
- elif self . _decoding_strategy == "sampling" and ( top_p != 1 .0 and
800
- top_k == 0 ):
801
- top_k = 0
802
- else :
803
- raise ValueError (
804
- "Only topk sampling or topp sampling are supported. " \
805
- "Topk sampling and topp sampling cannot be both applied. " )
796
+ if self ._decoding_strategy == "sampling" :
797
+ if top_p == 1.0 and top_k > 0 :
798
+ top_p = 0.0
799
+ elif top_p <= 0 .0 and top_k == 0 :
800
+ raise ValueError (
801
+ "Topk sampling or topp sampling must be applied. " \
802
+ "Topk sampling and topp sampling cannot be both applied. " )
803
+ elif ( top_p > 0.0 and top_p < 1.0 ) and top_k > 0 :
804
+ raise ValueError (
805
+ "Topk sampling and topp sampling cannot be both applied. " )
806
806
807
807
return self .forward (
808
808
model_inputs = model_inputs ,
@@ -956,16 +956,16 @@ def sample(self,
956
956
model_inputs = self .prepare_inputs_for_generation (input_ids ,
957
957
** model_kwargs )
958
958
959
- if self ._decoding_strategy == "sampling" and ( top_p == 1.0 and
960
- top_k > 0 ) :
961
- top_p = 0.0
962
- elif self . _decoding_strategy == "sampling" and ( top_p != 1 .0 and
963
- top_k == 0 ):
964
- top_k = 0
965
- else :
966
- raise ValueError (
967
- "Only topk sampling or topp sampling are supported. " \
968
- "Topk sampling and topp sampling cannot be both applied. " )
959
+ if self ._decoding_strategy == "sampling" :
960
+ if top_p == 1.0 and top_k > 0 :
961
+ top_p = 0.0
962
+ elif top_p <= 0 .0 and top_k == 0 :
963
+ raise ValueError (
964
+ "Topk sampling or topp sampling must be applied. " \
965
+ "Topk sampling and topp sampling cannot be both applied. " )
966
+ elif ( top_p > 0.0 and top_p < 1.0 ) and top_k > 0 :
967
+ raise ValueError (
968
+ "Topk sampling and topp sampling cannot be both applied. " )
969
969
970
970
return self .forward (
971
971
model_inputs = model_inputs ,
0 commit comments