11
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
- """UT for air_topp_sampling kernel"""
14
+ """UT for air_top_p_sampling kernel"""
15
15
16
16
import subprocess
17
17
import unittest
@@ -36,19 +36,19 @@ def setUp(self):
36
36
release_idx = output .index ("release" ) + 1
37
37
self .nvcc_cuda_version = float (output [release_idx ].split ("," )[0 ])
38
38
39
- def test_air_topp_sampling (self ):
39
+ def test_air_top_p_sampling (self ):
40
40
"""
41
- Check air_topp_sampling output with paddle.tensor.top_p_sampling.
41
+ Check air_top_p_sampling output with paddle.tensor.top_p_sampling.
42
42
"""
43
43
if self .nvcc_cuda_version < 12.0 :
44
- self .skipTest ("air_topp_sampling only support cu12+" )
44
+ self .skipTest ("air_top_p_sampling only support cu12+" )
45
45
bsz = 8
46
46
vocab_size = 103424
47
47
x = paddle .randn ([bsz , vocab_size ])
48
48
x = paddle .nn .functional .softmax (x )
49
49
x = paddle .cast (x , "float32" )
50
50
top_ps = paddle .to_tensor (np .random .uniform (0 , 1 , [bsz ]).astype (np .float32 ))
51
- _ , next_tokens = fastdeploy .model_executor .ops .gpu .air_topp_sampling (
51
+ _ , next_tokens = fastdeploy .model_executor .ops .gpu .air_top_p_sampling (
52
52
x .cuda (), top_ps .cuda (), None , None , seed = 0 , k = 1 , mode = "truncated"
53
53
)
54
54
print (next_tokens )
0 commit comments