File tree Expand file tree Collapse file tree 2 files changed +16
-11
lines changed
python/paddle/fluid/tests/unittests Expand file tree Collapse file tree 2 files changed +16
-11
lines changed Original file line number Diff line number Diff line change @@ -13,7 +13,11 @@ See the License for the specific language governing permissions and
13
13
limitations under the License. */
14
14
#pragma once
15
15
16
+ #include < algorithm>
17
+ #include < iostream>
18
+ #include < iterator>
16
19
#include < random>
20
+ #include < sstream>
17
21
#include < vector>
18
22
#include " paddle/fluid/framework/lod_tensor.h"
19
23
#include " paddle/fluid/framework/op_registry.h"
@@ -34,17 +38,17 @@ class SamplingIdKernel : public framework::OpKernel<T> {
34
38
std::vector<T> ins_vector;
35
39
framework::TensorToVector (*input, context.device_context (), &ins_vector);
36
40
37
- std::vector<int > ids (batch_size);
41
+ std::vector<T > ids (batch_size);
38
42
for (size_t i = 0 ; i < batch_size; ++i) {
39
43
double r = this ->get_rand ();
40
- int id = width - 1 ;
44
+ int idx = width - 1 ;
41
45
for (int j = 0 ; j < width; ++j) {
42
46
if ((r -= ins_vector[i * width + j]) < 0 ) {
43
- id = j;
47
+ idx = j;
44
48
break ;
45
49
}
46
50
}
47
- ids[i] = id ;
51
+ ids[i] = ins_vector[i * width + idx] ;
48
52
}
49
53
50
54
std::vector<int64_t > out_dim;
Original file line number Diff line number Diff line change @@ -25,17 +25,18 @@ def setUp(self):
25
25
self .op_type = "sampling_id"
26
26
self .use_mkldnn = False
27
27
self .init_kernel_type ()
28
- X = np .random .random ((3 , 4 )).astype ('float32' )
29
- self .inputs = {"X" : X }
30
- Y = np .random .random (3 ).astype ('float32' )
31
- self .outputs = {'Out' : Y }
28
+ self . X = np .random .random ((8 , 4 )).astype ('float32' )
29
+ self .inputs = {"X" : self . X }
30
+ self . Y = np .random .random (8 ).astype ('float32' )
31
+ self .outputs = {'Out' : self . Y }
32
32
self .attrs = {'use_mkldnn' : self .use_mkldnn }
33
33
34
34
def test_check_output (self ):
35
- self .check_output ( )
35
+ self .check_output_customized ( self . verify_output )
36
36
37
- def test_check_grad (self ):
38
- self .check_grad (['X' ], 'Out' )
37
+ def verify_output (self , outs ):
38
+ out = np .array (outs [0 ])
39
+ self .assertEqual (len (out ), len (self .Y ))
39
40
40
41
def init_kernel_type (self ):
41
42
pass
You can’t perform that action at this time.
0 commit comments