2727import pytest
2828
2929@pytest .mark .parametrize ('f' , [nlp .model .NCEDense , nlp .model .SparseNCEDense ])
30- def test_nce_loss (f ):
30+ @pytest .mark .parametrize ('cls_dtype' , ['float32' , 'int32' ])
31+ @pytest .mark .parametrize ('count_dtype' , ['float32' , 'int32' ])
32+ def test_nce_loss (f , cls_dtype , count_dtype ):
3133 ctx = mx .cpu ()
3234 batch_size = 2
3335 num_sampled = 3
@@ -40,9 +42,9 @@ def test_nce_loss(f):
4042 trainer = mx .gluon .Trainer (model .collect_params (), 'sgd' )
4143 x = mx .nd .ones ((batch_size , num_hidden ))
4244 y = mx .nd .ones ((batch_size ,))
43- sampled_cls = mx .nd .ones ((num_sampled ,))
44- sampled_cls_cnt = mx .nd .ones ((num_sampled ,))
45- true_cls_cnt = mx .nd .ones ((batch_size ,))
45+ sampled_cls = mx .nd .ones ((num_sampled ,), dtype = cls_dtype )
46+ sampled_cls_cnt = mx .nd .ones ((num_sampled ,), dtype = count_dtype )
47+ true_cls_cnt = mx .nd .ones ((batch_size ,), dtype = count_dtype )
4648 samples = (sampled_cls , sampled_cls_cnt , true_cls_cnt )
4749 with mx .autograd .record ():
4850 pred , new_y = model (x , samples , y )
@@ -53,7 +55,9 @@ def test_nce_loss(f):
5355 mx .nd .waitall ()
5456
5557@pytest .mark .parametrize ('f' , [nlp .model .ISDense , nlp .model .SparseISDense ])
56- def test_is_softmax_loss (f ):
58+ @pytest .mark .parametrize ('cls_dtype' , ['float32' , 'int32' ])
59+ @pytest .mark .parametrize ('count_dtype' , ['float32' , 'int32' ])
60+ def test_is_softmax_loss (f , cls_dtype , count_dtype ):
5761 ctx = mx .cpu ()
5862 batch_size = 2
5963 num_sampled = 3
@@ -66,9 +70,9 @@ def test_is_softmax_loss(f):
6670 trainer = mx .gluon .Trainer (model .collect_params (), 'sgd' )
6771 x = mx .nd .ones ((batch_size , num_hidden ))
6872 y = mx .nd .ones ((batch_size ,))
69- sampled_cls = mx .nd .ones ((num_sampled ,))
70- sampled_cls_cnt = mx .nd .ones ((num_sampled ,))
71- true_cls_cnt = mx .nd .ones ((batch_size ,))
73+ sampled_cls = mx .nd .ones ((num_sampled ,), dtype = cls_dtype )
74+ sampled_cls_cnt = mx .nd .ones ((num_sampled ,), dtype = count_dtype )
75+ true_cls_cnt = mx .nd .ones ((batch_size ,), dtype = count_dtype )
7276 samples = (sampled_cls , sampled_cls_cnt , true_cls_cnt )
7377 with mx .autograd .record ():
7478 pred , new_y = model (x , samples , y )
0 commit comments