File tree Expand file tree Collapse file tree 2 files changed +14
-4
lines changed
python/paddle/fluid/tests/unittests Expand file tree Collapse file tree 2 files changed +14
-4
lines changed Original file line number Diff line number Diff line change @@ -53,7 +53,7 @@ class SamplingIdKernel : public framework::OpKernel<T> {
53
53
static_cast <T>(context.Attr <float >(" min" )),
54
54
static_cast <T>(context.Attr <float >(" max" )));
55
55
56
- std::vector<T > ids (batch_size);
56
+ std::vector<int64_t > ids (batch_size);
57
57
for (int i = 0 ; i < batch_size; ++i) {
58
58
T r = dist (engine);
59
59
int idx = width - 1 ;
@@ -63,7 +63,7 @@ class SamplingIdKernel : public framework::OpKernel<T> {
63
63
break ;
64
64
}
65
65
}
66
- ids[i] = ins_vector[ idx] ;
66
+ ids[i] = int64_t ( idx) ;
67
67
}
68
68
69
69
std::vector<int64_t > out_dim;
Original file line number Diff line number Diff line change @@ -25,9 +25,9 @@ def setUp(self):
25
25
self .op_type = "sampling_id"
26
26
self .use_mkldnn = False
27
27
self .init_kernel_type ()
28
- self .X = np .random .random ((8 , 4 )).astype ('float32' )
28
+ self .X = np .random .random ((100 , 10 )).astype ('float32' )
29
29
self .inputs = {"X" : self .X }
30
- self .Y = np .random .random (8 ).astype ('float32 ' )
30
+ self .Y = np .random .random (100 ).astype ('int64 ' )
31
31
self .outputs = {'Out' : self .Y }
32
32
self .attrs = {'max' : 1.0 , 'min' : 0.0 , 'seed' : 1 }
33
33
@@ -36,6 +36,16 @@ def test_check_output(self):
36
36
y1 = self .out
37
37
self .check_output_customized (self .verify_output )
38
38
y2 = self .out
39
+
40
+ # check dtype
41
+ assert y1 .dtype == np .int64
42
+ assert y2 .dtype == np .int64
43
+
44
+ # check output is index ids of inputs
45
+ inputs_ids = np .arange (self .X .shape [1 ])
46
+ assert np .isin (y1 , inputs_ids ).all ()
47
+ assert np .isin (y2 , inputs_ids ).all ()
48
+
39
49
self .assertTrue (np .array_equal (y1 , y2 ))
40
50
self .assertEqual (len (y1 ), len (self .Y ))
41
51
You can’t perform that action at this time.
0 commit comments